Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unnecessary SDE resampling in PPO update #1933

Merged
merged 9 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
python -m pip install --upgrade pip
# cpu version of pytorch
pip install torch==2.1.0 --index-url https://download.pytorch.org/whl/cpu
pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu
# Install Atari Roms
pip install autorom
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ lint:
# see https://www.flake8rules.com/
ruff check ${LINT_PATHS} --select=E9,F63,F7,F82 --output-format=full
# exit-zero treats all errors as warnings.
ruff check ${LINT_PATHS} --exit-zero
ruff check ${LINT_PATHS} --exit-zero --output-format=concise

format:
# Sort imports
Expand Down
6 changes: 5 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a3 (WIP)
Release 2.4.0a4 (WIP)
--------------------------

Breaking Changes:
Expand All @@ -19,6 +19,7 @@ Bug Fixes:
- Cast type in compute gae method to avoid error when using torch compile (@amjames)
- ``CallbackList`` now sets the ``.parent`` attribute of child callbacks to its own ``.parent``. (will-maclean)
- Fixed error when loading a model that has ``net_arch`` manually set to ``None`` (@jak3122)
- Set requirement numpy<2.0 until PyTorch is compatible (https://github.com/pytorch/pytorch/issues/107302)

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand All @@ -35,6 +36,8 @@ Deprecations:
Others:
^^^^^^^
- Fixed various typos (@cschindlbeck)
- Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -1664,3 +1667,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"numpy>=1.20",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
# For saving models
"cloudpickle",
Expand Down
4 changes: 0 additions & 4 deletions stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,6 @@ def train(self) -> None:
# Convert discrete action from float to long
actions = rollout_data.actions.long().flatten()

# Re-sample the noise matrix because the log_std has changed
if self.use_sde:
self.policy.reset_noise(self.batch_size)

values, log_prob, entropy = self.policy.evaluate_actions(rollout_data.observations, actions)
values = values.flatten()
# Normalize advantage
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.4.0a3
2.4.0a4
4 changes: 2 additions & 2 deletions tests/test_save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,11 +791,11 @@ def test_cast_lr_schedule(tmp_path):
# Note: for recent version of numpy, np.float64 is a subclass of float
# so we need to use type here
# assert isinstance(model.lr_schedule(1.0), float)
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert type(model.lr_schedule(1.0)) is float
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))
model.save(tmp_path / "ppo.zip")
model = PPO.load(tmp_path / "ppo.zip")
assert type(model.lr_schedule(1.0)) is float # noqa: E721
assert type(model.lr_schedule(1.0)) is float
assert np.allclose(model.lr_schedule(0.5), 0.5 * np.sin(1.0))


Expand Down
Loading