Skip to content

Commit 043001b

Browse files
authored
polish(zc): change PD config name (#749)
* add action * change entry
1 parent e9a978e commit 043001b

10 files changed

+11
-9
lines changed

ding/policy/plan_diffuser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def _init_learn(self) -> None:
178178
self.step_start_update_target = self._cfg.learn.step_start_update_target
179179
self.target_weight = self._cfg.learn.target_weight
180180
self.value_step = self._cfg.learn.value_step
181-
self.use_target = True
181+
self.use_target = False
182182
self.horizon = self._cfg.model.diffuser_model_cfg.horizon
183183
self.include_returns = self._cfg.learn.include_returns
184184

ding/utils/data/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,11 +1090,13 @@ def __getitem__(self, idx, eps=1e-4):
10901090
'trajectories': trajectories,
10911091
'returns': returns,
10921092
'done': done,
1093+
'action': actions,
10931094
}
10941095
else:
10951096
batch = {
10961097
'trajectories': trajectories,
10971098
'done': done,
1099+
'action': actions,
10981100
}
10991101

11001102
batch.update(self.get_conditions(observations))

dizoo/d4rl/config/antmaze_umaze_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=37,
3030
dim=32,

dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=23,
3030
dim=32,

dizoo/d4rl/config/halfcheetah_medium_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=23,
3030
dim=32,

dizoo/d4rl/config/hopper_medium_expert_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=14,
3030
dim=32,

dizoo/d4rl/config/hopper_medium_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=14,
3030
dim=32,

dizoo/d4rl/config/walker2d_medium_expert_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=23,
3030
dim=32,

dizoo/d4rl/config/walker2d_medium_pd_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
model=dict(
2525
diffuser_model='GaussianDiffusion',
2626
diffuser_model_cfg=dict(
27-
model='TemporalUnet',
27+
model='DiffusionUNet1d',
2828
model_cfg=dict(
2929
transition_dim=23,
3030
dim=32,

dizoo/d4rl/entry/d4rl_pd_main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ def train(args):
1616

1717
parser = argparse.ArgumentParser()
1818
parser.add_argument('--seed', '-s', type=int, default=10)
19-
parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py')
19+
parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py')
2020
args = parser.parse_args()
2121
train(args)

0 commit comments

Comments
 (0)