Skip to content

Commit

Permalink
style: fix pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Gaiejj committed May 16, 2024
1 parent 483e427 commit dd9068f
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ repos:
- id: debug-statements
- id: double-quote-string-fixer
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.2
rev: v0.4.4
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand Down
4 changes: 2 additions & 2 deletions omnisafe/adapter/modelbased_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,8 @@ def rollout( # pylint: disable=too-many-arguments,too-many-locals

update_actor_critic_time = 0.0
update_dynamics_time = 0.0
if use_eval:
eval_time = 0.0

eval_time = 0.0

epoch_steps = 0

Expand Down
3 changes: 2 additions & 1 deletion omnisafe/common/robust_barrier_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def get_cbf_qp_constraints(
mean_pred_batch = torch.unsqueeze(mean_pred_batch, -1).to(self.device)
sigma_pred_batch = torch.unsqueeze(sigma_pred_batch, -1).to(self.device)
if self.env.dynamics_mode == 'Unicycle':

num_cbfs = len(self.env.hazards)
l_p = self.l_p
buffer = 0.1
Expand Down Expand Up @@ -299,6 +298,8 @@ def get_cbf_qp_constraints(
.to(self.device)
)
q = torch.zeros((batch_size, n_u + 1)).to(self.device)
else:
raise NotImplementedError

n_u = action_batch.shape[1]

Expand Down
6 changes: 5 additions & 1 deletion omnisafe/envs/safety_gymnasium_modelbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,8 @@ def get_cost_from_obs_tensor(self, obs: torch.Tensor, is_binary: bool = True) ->
elif len(obs.shape) == 3:
batch_size = obs.shape[0] * obs.shape[1]
hazard_obs = obs[:, :, hazards_key].reshape(batch_size, -1, 2)
else:
raise NotImplementedError
hazards_dist = torch.sqrt(torch.sum(torch.square(hazard_obs), dim=2)).reshape(
batch_size,
-1,
Expand Down Expand Up @@ -497,8 +499,10 @@ def reset(
self.get_lidar_from_coordinate(flat_coordinate_obs)
info['obs_original'] = obs_original
info['goal_met'] = False

obs = torch.as_tensor(flat_coordinate_obs, dtype=torch.float32, device=self._device)
else:
obs = torch.as_tensor(obs_original, dtype=torch.float32, device=self._device)

return obs, info

def set_seed(self, seed: int) -> None:
Expand Down
3 changes: 1 addition & 2 deletions omnisafe/utils/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ def plot_data(
smoothed_x = np.convolve(x, y, 'same') / np.convolve(z, y, 'same')
datum['Costs'] = smoothed_x

if isinstance(data, list):
data_to_plot = pd.concat(data, ignore_index=True)
data_to_plot = pd.concat(data, ignore_index=True)
sns.lineplot(
data=data_to_plot,
x=xaxis,
Expand Down

0 comments on commit dd9068f

Please sign in to comment.