Skip to content

Commit

Permalink
improve code style
Browse files Browse the repository at this point in the history
  • Loading branch information
muchvo committed May 5, 2024
1 parent 5be00f5 commit 960ef39
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
6 changes: 5 additions & 1 deletion omnisafe/algorithms/off_policy/crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,11 @@ def _init_model(self) -> None:
self.mean_policy = MeanPolicy(self._actor_critic.actor)

self.model, self.model_trainer = create_model_and_trainer(
self._cfgs, self.dim_state, self.dim_action, self.normalizer, self._device,
self._cfgs,
self.dim_state,
self.dim_action,
self.normalizer,
self._device,
)

def _init_log(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions omnisafe/common/offline/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__( # pylint: disable=too-many-branches
# Load data from local .npz file
try:
data = np.load(dataset_name)
except (FileNotFoundError, IsADirectoryError, ValueError, OSError) as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down Expand Up @@ -284,7 +284,7 @@ def __init__( # pylint: disable=too-many-branches, super-init-not-called
# Load data from local .npz file
try:
data = np.load(dataset_name)
except (FileNotFoundError, IsADirectoryError, ValueError, OSError) as e:
except (ValueError, OSError) as e:
raise ValueError(f'Failed to load data from {dataset_name}') from e

else:
Expand Down
27 changes: 23 additions & 4 deletions omnisafe/envs/classic_control/envs_from_crabs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,12 @@ class SafeInvertedPendulumEnv(InvertedPendulumEnv, SafeEnv):
episode_unsafe = False

def __init__(
self, threshold=0.2, task='upright', random_reset=False, violation_penalty=10, **kwargs,
self,
threshold=0.2,
task='upright',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
self.threshold = threshold
Expand All @@ -82,7 +87,11 @@ def __init__(
self.violation_penalty = violation_penalty
super().__init__(**kwargs)
EzPickle.__init__(
self, threshold=threshold, task=task, random_reset=random_reset, **kwargs,
self,
threshold=threshold,
task=task,
random_reset=random_reset,
**kwargs,
) # deepcopy calls `get_state`

def reset_model(self):
Expand Down Expand Up @@ -144,7 +153,12 @@ class SafeInvertedPendulumSwingEnv(SafeInvertedPendulumEnv):
"""Safe Inverted Pendulum Swing Environment."""

def __init__(
self, threshold=1.5, task='swing', random_reset=False, violation_penalty=10, **kwargs,
self,
threshold=1.5,
task='swing',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
super().__init__(threshold=threshold, task=task, **kwargs)
Expand All @@ -154,7 +168,12 @@ class SafeInvertedPendulumMoveEnv(SafeInvertedPendulumEnv):
"""Safe Inverted Pendulum Move Environment."""

def __init__(
self, threshold=0.2, task='move', random_reset=False, violation_penalty=10, **kwargs,
self,
threshold=0.2,
task='move',
random_reset=False,
violation_penalty=10,
**kwargs,
) -> None:
"""Initialize the environment."""
super().__init__(threshold=threshold, task=task, **kwargs)
Expand Down

0 comments on commit 960ef39

Please sign in to comment.