Skip to content

Commit 4e622dc

Browse files
authored
Cast initial_log_std parameter to float in PyTorch (#297)
1 parent 6cc871b commit 4e622dc

File tree

7 files changed

+257
-22
lines changed

7 files changed

+257
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
99
- Allow `None` type spaces and samples/values in spaces utilities
1010

1111
### Fixed
12+
- Cast model instantiator's `initial_log_std` parameter to `float` in PyTorch
1213
- Fix common property overwriting (e.g. `clip_actions`) in shared models composed of different mixin types
1314

1415
## [1.4.1] - 2025-01-27

skrl/utils/model_instantiators/torch/gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
102102
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
103103
104104
{networks}
105-
self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})
105+
self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={float(initial_log_std)}), requires_grad={not fixed_log_std})
106106
107107
def compute(self, inputs, role=""):
108108
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))

skrl/utils/model_instantiators/torch/multivariate_gaussian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(self, observation_space, action_space, device, clip_actions,
9797
MultivariateGaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std)
9898
9999
{networks}
100-
self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={initial_log_std}), requires_grad={not fixed_log_std})
100+
self.log_std_parameter = nn.Parameter(torch.full(size=({output["size"]},), fill_value={float(initial_log_std)}), requires_grad={not fixed_log_std})
101101
102102
def compute(self, inputs, role=""):
103103
states = unflatten_tensorized_space(self.observation_space, inputs.get("states"))

tests/jax/test_jax_model_instantiators.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@
33
import yaml
44
from gymnasium import spaces
55

6-
from skrl.utils.model_instantiators.jax import categorical_model, deterministic_model, gaussian_model
6+
from skrl.utils.model_instantiators.jax import (
7+
categorical_model,
8+
deterministic_model,
9+
gaussian_model,
10+
multicategorical_model,
11+
)
712
from skrl.utils.spaces.jax import flatten_tensorized_space, sample_space
813

914

@@ -87,10 +92,42 @@ def test_categorical_model(capsys, device):
8792
)
8893
model.init_state_dict("model")
8994

90-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
95+
output = model.act(
96+
{
97+
"states": flatten_tensorized_space(
98+
sample_space(observation_space, batch_size=10, backend="native", device=device)
99+
)
100+
}
101+
)
91102
assert output[0].shape == (10, 1)
92103

93104

105+
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
106+
def test_multicategorical_model(capsys, device):
107+
# observation
108+
action_space = spaces.MultiDiscrete([2, 3])
109+
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
110+
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
111+
model = multicategorical_model(
112+
observation_space=observation_space,
113+
action_space=action_space,
114+
device=device,
115+
unnormalized_log_prob=True,
116+
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
117+
output="ACTIONS",
118+
)
119+
model.init_state_dict("model")
120+
121+
output = model.act(
122+
{
123+
"states": flatten_tensorized_space(
124+
sample_space(observation_space, batch_size=10, backend="native", device=device)
125+
)
126+
}
127+
)
128+
assert output[0].shape == (10, 2)
129+
130+
94131
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
95132
def test_deterministic_model(capsys, device):
96133
# observation
@@ -107,7 +144,13 @@ def test_deterministic_model(capsys, device):
107144
)
108145
model.init_state_dict("model")
109146

110-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
147+
output = model.act(
148+
{
149+
"states": flatten_tensorized_space(
150+
sample_space(observation_space, batch_size=10, backend="native", device=device)
151+
)
152+
}
153+
)
111154
assert output[0].shape == (10, 2)
112155

113156

@@ -131,5 +174,11 @@ def test_gaussian_model(capsys, device):
131174
)
132175
model.init_state_dict("model")
133176

134-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "jax", device))})
177+
output = model.act(
178+
{
179+
"states": flatten_tensorized_space(
180+
sample_space(observation_space, batch_size=10, backend="native", device=device)
181+
)
182+
}
183+
)
135184
assert output[0].shape == (10, 2)

tests/jax/test_jax_model_instantiators_definition.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,13 @@
99
import jax.numpy as jnp
1010
import numpy as np
1111

12-
from skrl.utils.model_instantiators.jax import Shape, categorical_model, deterministic_model, gaussian_model
12+
from skrl.utils.model_instantiators.jax import (
13+
Shape,
14+
categorical_model,
15+
deterministic_model,
16+
gaussian_model,
17+
multicategorical_model,
18+
)
1319
from skrl.utils.model_instantiators.jax.common import _generate_modules, _get_activation_function, _parse_input
1420

1521

@@ -255,3 +261,40 @@ def test_categorical_model(capsys):
255261
observations = jnp.ones((10, model.num_observations))
256262
output = model.act({"states": observations})
257263
assert output[0].shape == (10, 1)
264+
265+
266+
def test_multicategorical_model(capsys):
267+
device = "cpu"
268+
observation_space = gym.spaces.Box(np.array([-1] * 5), np.array([1] * 5))
269+
action_space = gym.spaces.MultiDiscrete([2, 3])
270+
271+
content = r"""
272+
unnormalized_log_prob: True
273+
network:
274+
- name: net
275+
input: OBSERVATIONS
276+
layers:
277+
- linear: 32
278+
- linear: [32]
279+
- linear: {out_features: 32}
280+
activations: elu
281+
output: ACTIONS
282+
"""
283+
content = yaml.safe_load(content)
284+
# source
285+
model = multicategorical_model(
286+
observation_space=observation_space, action_space=action_space, device=device, return_source=True, **content
287+
)
288+
with capsys.disabled():
289+
print(model)
290+
# instance
291+
model = multicategorical_model(
292+
observation_space=observation_space, action_space=action_space, device=device, return_source=False, **content
293+
)
294+
model.init_state_dict("model")
295+
with capsys.disabled():
296+
print(model)
297+
298+
observations = jnp.ones((10, model.num_observations))
299+
output = model.act({"states": observations})
300+
assert output[0].shape == (10, 2)

tests/torch/test_torch_model_instantiators.py

Lines changed: 118 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
categorical_model,
88
deterministic_model,
99
gaussian_model,
10+
multicategorical_model,
1011
multivariate_gaussian_model,
1112
shared_model,
1213
)
@@ -91,12 +92,44 @@ def test_categorical_model(capsys, device):
9192
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
9293
output="ACTIONS",
9394
)
94-
model.to(device=device)
95+
model.to(device=model.device)
9596

96-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))})
97+
output = model.act(
98+
{
99+
"states": flatten_tensorized_space(
100+
sample_space(observation_space, batch_size=10, backend="native", device=device)
101+
)
102+
}
103+
)
97104
assert output[0].shape == (10, 1)
98105

99106

107+
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
108+
def test_multicategorical_model(capsys, device):
109+
# observation
110+
action_space = spaces.MultiDiscrete([2, 3])
111+
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
112+
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
113+
model = multicategorical_model(
114+
observation_space=observation_space,
115+
action_space=action_space,
116+
device=device,
117+
unnormalized_log_prob=True,
118+
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
119+
output="ACTIONS",
120+
)
121+
model.to(device=model.device)
122+
123+
output = model.act(
124+
{
125+
"states": flatten_tensorized_space(
126+
sample_space(observation_space, batch_size=10, backend="native", device=device)
127+
)
128+
}
129+
)
130+
assert output[0].shape == (10, 2)
131+
132+
100133
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
101134
def test_deterministic_model(capsys, device):
102135
# observation
@@ -111,9 +144,15 @@ def test_deterministic_model(capsys, device):
111144
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
112145
output="ACTIONS",
113146
)
114-
model.to(device=device)
147+
model.to(device=model.device)
115148

116-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))})
149+
output = model.act(
150+
{
151+
"states": flatten_tensorized_space(
152+
sample_space(observation_space, batch_size=10, backend="native", device=device)
153+
)
154+
}
155+
)
117156
assert output[0].shape == (10, 2)
118157

119158

@@ -135,9 +174,15 @@ def test_gaussian_model(capsys, device):
135174
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
136175
output="ACTIONS",
137176
)
138-
model.to(device=device)
177+
model.to(device=model.device)
139178

140-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))})
179+
output = model.act(
180+
{
181+
"states": flatten_tensorized_space(
182+
sample_space(observation_space, batch_size=10, backend="native", device=device)
183+
)
184+
}
185+
)
141186
assert output[0].shape == (10, 2)
142187

143188

@@ -159,9 +204,15 @@ def test_multivariate_gaussian_model(capsys, device):
159204
network=yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
160205
output="ACTIONS",
161206
)
162-
model.to(device=device)
207+
model.to(device=model.device)
163208

164-
output = model.act({"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))})
209+
output = model.act(
210+
{
211+
"states": flatten_tensorized_space(
212+
sample_space(observation_space, batch_size=10, backend="native", device=device)
213+
)
214+
}
215+
)
165216
assert output[0].shape == (10, 2)
166217

167218

@@ -196,9 +247,13 @@ def test_shared_gaussian_deterministic_model(capsys, device, single_forward_pass
196247
],
197248
single_forward_pass=single_forward_pass,
198249
)
199-
model.to(device=device)
250+
model.to(device=model.device)
200251

201-
inputs = {"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))}
252+
inputs = {
253+
"states": flatten_tensorized_space(
254+
sample_space(observation_space, batch_size=10, backend="native", device=device)
255+
)
256+
}
202257
output = model.act(inputs, role="role_0")
203258
assert output[0].shape == (10, 2)
204259
output = model.act(inputs, role="role_1")
@@ -236,9 +291,13 @@ def test_shared_multivariate_gaussian_deterministic_model(capsys, device, single
236291
],
237292
single_forward_pass=single_forward_pass,
238293
)
239-
model.to(device=device)
294+
model.to(device=model.device)
240295

241-
inputs = {"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))}
296+
inputs = {
297+
"states": flatten_tensorized_space(
298+
sample_space(observation_space, batch_size=10, backend="native", device=device)
299+
)
300+
}
242301
output = model.act(inputs, role="role_0")
243302
assert output[0].shape == (10, 2)
244303
output = model.act(inputs, role="role_1")
@@ -249,7 +308,7 @@ def test_shared_multivariate_gaussian_deterministic_model(capsys, device, single
249308
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
250309
def test_shared_categorical_deterministic_model(capsys, device, single_forward_pass):
251310
# observation
252-
action_space = spaces.Box(low=-1, high=1, shape=(2,))
311+
action_space = spaces.Discrete(2)
253312
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
254313
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
255314
model = shared_model(
@@ -272,10 +331,54 @@ def test_shared_categorical_deterministic_model(capsys, device, single_forward_p
272331
],
273332
single_forward_pass=single_forward_pass,
274333
)
275-
model.to(device=device)
334+
model.to(device=model.device)
276335

277-
inputs = {"states": flatten_tensorized_space(sample_space(observation_space, 10, "torch", device))}
336+
inputs = {
337+
"states": flatten_tensorized_space(
338+
sample_space(observation_space, batch_size=10, backend="native", device=device)
339+
)
340+
}
278341
output = model.act(inputs, role="role_0")
279342
assert output[0].shape == (10, 1)
280343
output = model.act(inputs, role="role_1")
281344
assert output[0].shape == (10, 1)
345+
346+
347+
@pytest.mark.parametrize("single_forward_pass", [True, False])
348+
@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
349+
def test_shared_multicategorical_deterministic_model(capsys, device, single_forward_pass):
350+
# observation
351+
action_space = spaces.MultiDiscrete([2, 3])
352+
for observation_space_type in [spaces.Box, spaces.Tuple, spaces.Dict]:
353+
observation_space = NETWORK_SPEC_OBSERVATION[observation_space_type][1]
354+
model = shared_model(
355+
observation_space=observation_space,
356+
action_space=action_space,
357+
device=device,
358+
structure=["MultiCategoricalMixin", "DeterministicMixin"],
359+
roles=["role_0", "role_1"],
360+
parameters=[
361+
{
362+
"unnormalized_log_prob": True,
363+
"network": yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
364+
"output": "ACTIONS",
365+
},
366+
{
367+
"clip_actions": False,
368+
"network": yaml.safe_load(NETWORK_SPEC_OBSERVATION[observation_space_type][0])["network"],
369+
"output": "ONE",
370+
},
371+
],
372+
single_forward_pass=single_forward_pass,
373+
)
374+
model.to(device=model.device)
375+
376+
inputs = {
377+
"states": flatten_tensorized_space(
378+
sample_space(observation_space, batch_size=10, backend="native", device=device)
379+
)
380+
}
381+
output = model.act(inputs, role="role_0")
382+
assert output[0].shape == (10, 2)
383+
output = model.act(inputs, role="role_1")
384+
assert output[0].shape == (10, 1)

0 commit comments

Comments
 (0)