Commit

fix(nyz): fix many unittest bugs
PaParaZz1 committed Dec 9, 2024
1 parent 571229e commit 765b8fb
Showing 6 changed files with 26 additions and 24 deletions.
8 changes: 5 additions & 3 deletions ding/envs/env/ding_env_wrapper.py
@@ -71,9 +71,11 @@ def __init__(
self._observation_space = self._env.observation_space
self._action_space = self._env.action_space
self._action_space.seed(0) # default seed
self._reward_space = gym.spaces.Box(
low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
)
try:
low, high = self._env.reward_range
except: # for compatibility with gymnasium high-version API
low, high = -1, 1
self._reward_space = gym.spaces.Box(low=low, high=high, shape=(1, ), dtype=np.float32)
self._init_flag = True
else:
assert 'env_id' in self._cfg
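The hunk above guards against newer gymnasium environments that no longer expose reward_range, falling back to a nominal (-1, 1) range. A minimal sketch of the same pattern outside the wrapper, assuming gym and numpy are available; DummyEnv and build_reward_space are illustrative names, not DI-engine code:

import gym
import numpy as np


class DummyEnv:
    """Stand-in for a gymnasium env that no longer exposes ``reward_range``."""
    pass


def build_reward_space(env) -> gym.spaces.Box:
    try:
        low, high = env.reward_range
    except Exception:  # e.g. AttributeError on recent gymnasium releases
        low, high = -1, 1
    return gym.spaces.Box(low=low, high=high, shape=(1, ), dtype=np.float32)


print(build_reward_space(DummyEnv()))  # Box(-1.0, 1.0, (1,), float32)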
3 changes: 1 addition & 2 deletions ding/envs/env/tests/test_ding_env_wrapper.py
@@ -74,7 +74,6 @@ def test_cartpole_pendulum(self, env_id):
def test_cartpole_pendulum_gymnasium(self, env_id):
env = gymnasium.make(env_id)
ding_env = DingEnvWrapper(env=env)
print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
cfg = EasyDict(dict(
collector_env_num=16,
evaluator_env_num=3,
@@ -142,7 +141,7 @@ def test_atari(self, atari_env_id):
# assert isinstance(action, np.ndarray)
assert action.shape == (1, )

@pytest.mark.unittest
@pytest.mark.envtest
@pytest.mark.parametrize('lun_bip_env_id', ['LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3'])
def test_lunarlander_bipedalwalker(self, lun_bip_env_id):
env_cfg = EasyDict(
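Retagging the Box2D-based test from unittest to envtest keeps it out of the default unit-test run; it is only collected when the marker is selected explicitly. A small sketch of that selection mechanism follows; the command line is a plausible invocation, not part of this commit:

import pytest


@pytest.mark.envtest  # heavy environment test, skipped by the plain unit-test run
def test_lunarlander_bipedalwalker():
    pass  # placeholder body; the real test builds DingEnvWrapper environments


# Collect only environment tests:
#   pytest -m envtest ding/envs/env/tests/test_ding_env_wrapper.py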
11 changes: 5 additions & 6 deletions ding/framework/middleware/tests/test_logger.py
@@ -1,12 +1,10 @@
from os import path
import os
import copy
from functools import partial
from easydict import EasyDict
from collections import deque
import pytest
import shutil
import wandb
import h5py
import torch.nn as nn
from unittest.mock import MagicMock
from unittest.mock import Mock, patch
@@ -207,7 +205,6 @@ def test_wandb_online_logger():
env = TheEnvClass()
ctx = OnlineRLContext()
ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
model = TheModelClass()
wandb.init(config=cfg, anonymous="must")

def mock_metric_logger(data, step):
@@ -233,15 +230,17 @@ def mock_metric_logger(data, step):
]
assert set(data.keys()) <= set(metric_list)

def mock_gradient_logger(input_model, log, log_freq, log_graph):
def mock_gradient_logger(input_model, model, log, log_freq, log_graph):
assert input_model == model

def test_wandb_online_logger_metric():
model = TheModelClass()
with patch.object(wandb, 'log', new=mock_metric_logger):
wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

def test_wandb_online_logger_gradient():
with patch.object(wandb, 'watch', new=mock_gradient_logger):
model = TheModelClass()
with patch.object(wandb, 'watch', new=partial(mock_gradient_logger, model=model)):
wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)

test_wandb_online_logger_metric()
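The gradient-logger mock now receives the expected model bound through functools.partial, so the fake wandb.watch can verify which model the logger passed in. A hedged sketch of the same patching pattern, with illustrative names (TinyModel, fake_watch) rather than the test's own:

from functools import partial
from unittest.mock import patch

import wandb


class TinyModel:
    pass


def fake_watch(input_model, model, log=None, log_freq=100, log_graph=False):
    # ``model`` is pre-bound via partial; ``input_model`` is whatever the caller passed.
    assert input_model is model


expected = TinyModel()
with patch.object(wandb, 'watch', new=partial(fake_watch, model=expected)):
    wandb.watch(expected, log='all')  # routed to fake_watch, assertion passes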
5 changes: 3 additions & 2 deletions ding/model/template/hpt.py
@@ -3,7 +3,7 @@
import torch
import torch.nn as nn
from ding.model.common.head import DuelingHead
from ding.utils.registry_factory import MODEL_REGISTRY
from ding.utils import MODEL_REGISTRY, squeeze


@MODEL_REGISTRY.register('hpt')
@@ -36,9 +36,10 @@ def __init__(self, state_dim: int, action_dim: int):
"""
super(HPT, self).__init__()
# Initialise Policy Stem
self.policy_stem = PolicyStem()
self.policy_stem = PolicyStem(state_dim, 128)
self.policy_stem.init_cross_attn()

action_dim = squeeze(action_dim)
# Dueling Head, input is 16*128, output is action dimension
self.head = DuelingHead(hidden_size=16 * 128, output_size=action_dim)

15 changes: 5 additions & 10 deletions ding/model/template/tests/test_hpt.py
@@ -31,22 +31,17 @@ def test_hpt(self, obs_shape, act_shape):
inputs = torch.randn(B, *obs_shape)
state_dim = obs_shape[0]

if isinstance(act_shape, int):
action_dim = act_shape
else:
action_dim = len(act_shape)

model = HPT(state_dim=state_dim, action_dim=action_dim)
model = HPT(state_dim=state_dim, action_dim=act_shape)
outputs = model(inputs)

assert isinstance(outputs, torch.Tensor)
assert isinstance(outputs, dict)

if isinstance(act_shape, int):
assert outputs.shape == (B, act_shape)
assert outputs['logit'].shape == (B, act_shape)
elif len(act_shape) == 1:
assert outputs.shape == (B, *act_shape)
assert outputs['logit'].shape == (B, *act_shape)
else:
for i, s in enumerate(act_shape):
assert outputs[i].shape == (B, s)
assert outputs['logit'][i].shape == (B, s)

self.output_check(model, outputs)
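With squeeze applied inside HPT and the head output returned as a dict, the test now passes act_shape straight through and reads logits from outputs['logit']. A hedged usage sketch; the import path follows the file layout above and the shapes are illustrative:

import torch
from ding.model.template.hpt import HPT

B, state_dim, action_dim = 4, 8, 6
model = HPT(state_dim=state_dim, action_dim=action_dim)
outputs = model(torch.randn(B, state_dim))
assert isinstance(outputs, dict)
assert outputs['logit'].shape == (B, action_dim)  # discrete logits from the DuelingHead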
8 changes: 7 additions & 1 deletion ding/torch_utils/optimizer_helper.py
@@ -489,7 +489,13 @@ def _state_init(self, p, momentum, centered):
"""

state = self.state[p]
state['step'] = 0
if torch.__version__ < "1.12.0":
state['step'] = 0
# TODO
# wait torch upgrad to 1.4, 1.3.1 didn't support memory format state['step'] = 0
else:
state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
if self.defaults['capturable'] else torch.tensor(0.)
state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
if momentum:
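Newer torch optimizer conventions store the step counter as a tensor (a device tensor when capturable is set), while older releases used a plain int; the guard above mirrors that split. A minimal sketch of the same initialization, following the diff's string-based version check; init_step_state is an illustrative helper, not the optimizer's method:

import torch


def init_step_state(p: torch.Tensor, capturable: bool = False):
    # Mirrors the branch above: plain int on old torch, tensor on newer releases.
    if torch.__version__ < "1.12.0":
        return 0
    return torch.zeros((1, ), dtype=torch.float, device=p.device) if capturable \
        else torch.tensor(0.)


print(init_step_state(torch.zeros(3)))  # tensor(0.) on torch >= 1.12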
