7
7
categorical_model ,
8
8
deterministic_model ,
9
9
gaussian_model ,
10
+ multicategorical_model ,
10
11
multivariate_gaussian_model ,
11
12
shared_model ,
12
13
)
@@ -91,12 +92,44 @@ def test_categorical_model(capsys, device):
91
92
network = yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
92
93
output = "ACTIONS" ,
93
94
)
94
- model .to (device = device )
95
+ model .to (device = model . device )
95
96
96
- output = model .act ({"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))})
97
+ output = model .act (
98
+ {
99
+ "states" : flatten_tensorized_space (
100
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
101
+ )
102
+ }
103
+ )
97
104
assert output [0 ].shape == (10 , 1 )
98
105
99
106
107
+ @pytest .mark .parametrize ("device" , [None , "cpu" , "cuda:0" ])
108
+ def test_multicategorical_model (capsys , device ):
109
+ # observation
110
+ action_space = spaces .MultiDiscrete ([2 , 3 ])
111
+ for observation_space_type in [spaces .Box , spaces .Tuple , spaces .Dict ]:
112
+ observation_space = NETWORK_SPEC_OBSERVATION [observation_space_type ][1 ]
113
+ model = multicategorical_model (
114
+ observation_space = observation_space ,
115
+ action_space = action_space ,
116
+ device = device ,
117
+ unnormalized_log_prob = True ,
118
+ network = yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
119
+ output = "ACTIONS" ,
120
+ )
121
+ model .to (device = model .device )
122
+
123
+ output = model .act (
124
+ {
125
+ "states" : flatten_tensorized_space (
126
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
127
+ )
128
+ }
129
+ )
130
+ assert output [0 ].shape == (10 , 2 )
131
+
132
+
100
133
@pytest .mark .parametrize ("device" , [None , "cpu" , "cuda:0" ])
101
134
def test_deterministic_model (capsys , device ):
102
135
# observation
@@ -111,9 +144,15 @@ def test_deterministic_model(capsys, device):
111
144
network = yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
112
145
output = "ACTIONS" ,
113
146
)
114
- model .to (device = device )
147
+ model .to (device = model . device )
115
148
116
- output = model .act ({"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))})
149
+ output = model .act (
150
+ {
151
+ "states" : flatten_tensorized_space (
152
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
153
+ )
154
+ }
155
+ )
117
156
assert output [0 ].shape == (10 , 2 )
118
157
119
158
@@ -135,9 +174,15 @@ def test_gaussian_model(capsys, device):
135
174
network = yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
136
175
output = "ACTIONS" ,
137
176
)
138
- model .to (device = device )
177
+ model .to (device = model . device )
139
178
140
- output = model .act ({"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))})
179
+ output = model .act (
180
+ {
181
+ "states" : flatten_tensorized_space (
182
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
183
+ )
184
+ }
185
+ )
141
186
assert output [0 ].shape == (10 , 2 )
142
187
143
188
@@ -159,9 +204,15 @@ def test_multivariate_gaussian_model(capsys, device):
159
204
network = yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
160
205
output = "ACTIONS" ,
161
206
)
162
- model .to (device = device )
207
+ model .to (device = model . device )
163
208
164
- output = model .act ({"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))})
209
+ output = model .act (
210
+ {
211
+ "states" : flatten_tensorized_space (
212
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
213
+ )
214
+ }
215
+ )
165
216
assert output [0 ].shape == (10 , 2 )
166
217
167
218
@@ -196,9 +247,13 @@ def test_shared_gaussian_deterministic_model(capsys, device, single_forward_pass
196
247
],
197
248
single_forward_pass = single_forward_pass ,
198
249
)
199
- model .to (device = device )
250
+ model .to (device = model . device )
200
251
201
- inputs = {"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))}
252
+ inputs = {
253
+ "states" : flatten_tensorized_space (
254
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
255
+ )
256
+ }
202
257
output = model .act (inputs , role = "role_0" )
203
258
assert output [0 ].shape == (10 , 2 )
204
259
output = model .act (inputs , role = "role_1" )
@@ -236,9 +291,13 @@ def test_shared_multivariate_gaussian_deterministic_model(capsys, device, single
236
291
],
237
292
single_forward_pass = single_forward_pass ,
238
293
)
239
- model .to (device = device )
294
+ model .to (device = model . device )
240
295
241
- inputs = {"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))}
296
+ inputs = {
297
+ "states" : flatten_tensorized_space (
298
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
299
+ )
300
+ }
242
301
output = model .act (inputs , role = "role_0" )
243
302
assert output [0 ].shape == (10 , 2 )
244
303
output = model .act (inputs , role = "role_1" )
@@ -249,7 +308,7 @@ def test_shared_multivariate_gaussian_deterministic_model(capsys, device, single
249
308
@pytest .mark .parametrize ("device" , [None , "cpu" , "cuda:0" ])
250
309
def test_shared_categorical_deterministic_model (capsys , device , single_forward_pass ):
251
310
# observation
252
- action_space = spaces .Box ( low = - 1 , high = 1 , shape = ( 2 ,) )
311
+ action_space = spaces .Discrete ( 2 )
253
312
for observation_space_type in [spaces .Box , spaces .Tuple , spaces .Dict ]:
254
313
observation_space = NETWORK_SPEC_OBSERVATION [observation_space_type ][1 ]
255
314
model = shared_model (
@@ -272,10 +331,54 @@ def test_shared_categorical_deterministic_model(capsys, device, single_forward_p
272
331
],
273
332
single_forward_pass = single_forward_pass ,
274
333
)
275
- model .to (device = device )
334
+ model .to (device = model . device )
276
335
277
- inputs = {"states" : flatten_tensorized_space (sample_space (observation_space , 10 , "torch" , device ))}
336
+ inputs = {
337
+ "states" : flatten_tensorized_space (
338
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
339
+ )
340
+ }
278
341
output = model .act (inputs , role = "role_0" )
279
342
assert output [0 ].shape == (10 , 1 )
280
343
output = model .act (inputs , role = "role_1" )
281
344
assert output [0 ].shape == (10 , 1 )
345
+
346
+
347
+ @pytest .mark .parametrize ("single_forward_pass" , [True , False ])
348
+ @pytest .mark .parametrize ("device" , [None , "cpu" , "cuda:0" ])
349
+ def test_shared_multicategorical_deterministic_model (capsys , device , single_forward_pass ):
350
+ # observation
351
+ action_space = spaces .MultiDiscrete ([2 , 3 ])
352
+ for observation_space_type in [spaces .Box , spaces .Tuple , spaces .Dict ]:
353
+ observation_space = NETWORK_SPEC_OBSERVATION [observation_space_type ][1 ]
354
+ model = shared_model (
355
+ observation_space = observation_space ,
356
+ action_space = action_space ,
357
+ device = device ,
358
+ structure = ["MultiCategoricalMixin" , "DeterministicMixin" ],
359
+ roles = ["role_0" , "role_1" ],
360
+ parameters = [
361
+ {
362
+ "unnormalized_log_prob" : True ,
363
+ "network" : yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
364
+ "output" : "ACTIONS" ,
365
+ },
366
+ {
367
+ "clip_actions" : False ,
368
+ "network" : yaml .safe_load (NETWORK_SPEC_OBSERVATION [observation_space_type ][0 ])["network" ],
369
+ "output" : "ONE" ,
370
+ },
371
+ ],
372
+ single_forward_pass = single_forward_pass ,
373
+ )
374
+ model .to (device = model .device )
375
+
376
+ inputs = {
377
+ "states" : flatten_tensorized_space (
378
+ sample_space (observation_space , batch_size = 10 , backend = "native" , device = device )
379
+ )
380
+ }
381
+ output = model .act (inputs , role = "role_0" )
382
+ assert output [0 ].shape == (10 , 2 )
383
+ output = model .act (inputs , role = "role_1" )
384
+ assert output [0 ].shape == (10 , 1 )
0 commit comments