Commit 089c20f

Convert StochasticPolicy to use PolicyInput type

I also reworked significant parts of CNNModule because there were some bugs and it was somewhat hard to use.

1 parent 89c71d4 commit 089c20f

28 files changed, +543 -543 lines changed

docs/user/implement_algo.md

Lines changed: 4 additions & 1 deletion
@@ -195,6 +195,7 @@ import numpy as np
 
 from garage.samplers import LocalSampler
 from garage.np import discount_cumsum
+from garage.torch import PolicyMode, PolicyInput
 
 class SimpleVPG:
 
@@ -220,7 +221,9 @@ class SimpleVPG:
             returns = torch.Tensor(returns_numpy.copy())
             obs = torch.Tensor(path['observations'])
             actions = torch.Tensor(path['actions'])
-            dist = self.policy(obs)[0]
+            policy_input = PolicyInput(PolicyMode.FULL, obs,
+                                       lengths=[len(path)])
+            dist = self.policy(policy_input)[0]
             log_likelihoods = dist.log_prob(actions)
             loss = (-log_likelihoods * returns).mean()
             loss.backward()

examples/torch/tutorial_vpg.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@
 from garage.experiment.deterministic import set_seed
 from garage.np import discount_cumsum
 from garage.sampler import LocalSampler
+from garage.torch import PolicyInput, PolicyMode
 from garage.torch.policies import GaussianMLPPolicy
 from garage.trainer import Trainer
 
@@ -62,7 +63,10 @@ def _train_once(self, samples):
             returns = torch.Tensor(returns_numpy.copy())
             obs = torch.Tensor(path['observations'])
             actions = torch.Tensor(path['actions'])
-            dist = self.policy(obs)[0]
+            policy_input = PolicyInput(PolicyMode.FULL,
+                                       obs,
+                                       lengths=[len(path)])
+            dist = self.policy(policy_input)[0]
             log_likelihoods = dist.log_prob(actions)
             loss = (-log_likelihoods * returns).mean()
             loss.backward()
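For reference, the same loss computation can be written as a small standalone helper. This is a hypothetical sketch, not code from the commit: `vpg_loss_for_path` and its arguments are made up for illustration, the episode length is passed as an explicit integer tensor (the form that the new `PolicyInput.__post_init__` validation in `src/garage/torch/_dtypes.py` below checks for), and the `[0]` index assumes the policy returns a `(distribution, info)` pair as in the tutorial above.

```python
import torch

from garage.np import discount_cumsum
from garage.torch import PolicyInput, PolicyMode


def vpg_loss_for_path(policy, path, discount):
    """Per-episode vanilla policy gradient loss using the PolicyInput API."""
    returns_numpy = discount_cumsum(path['rewards'], discount)
    returns = torch.Tensor(returns_numpy.copy())
    obs = torch.Tensor(path['observations'])
    actions = torch.Tensor(path['actions'])

    # FULL mode: obs holds one complete episode, so lengths has one entry.
    policy_input = PolicyInput(PolicyMode.FULL, obs,
                               lengths=torch.tensor([len(obs)]))
    dist = policy(policy_input)[0]

    log_likelihoods = dist.log_prob(actions)
    return (-log_likelihoods * returns).mean()
```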

setup.cfg

Lines changed: 2 additions & 2 deletions
@@ -6,12 +6,12 @@ per-file-ignores =
     # See https://gitlab.com/pycqa/flake8/-/issues/494
     #
     # errors on valid property docstrings
-    src/garage/*:D403
+    src/garage/*:D403,R1720
     # unit tests don't need docstrings
     tests/garage/*:D, F401, F811
     # interferes with idiomatic `from torch.nn import functional as F`
     examples/torch/*:N812
-    src/garage/torch/*:N812,D403
+    src/garage/torch/*:N812,D403,R1720
     tests/garage/torch/*:N812,D
 
     # Docstring style checks

src/garage/torch/__init__.py

Lines changed: 14 additions & 9 deletions
@@ -1,19 +1,24 @@
 """PyTorch-backed modules and algorithms."""
 # yapf: disable
-from garage.torch._functions import (compute_advantages, dict_np_to_torch,
+from garage.torch._dtypes import (PolicyInput, PolicyMode,
+                                  ShuffledOptimizationNotSupported)
+from garage.torch._functions import (as_torch, as_torch_dict,
+                                     compute_advantages, expand_var,
                                      filter_valids, flatten_batch,
                                      flatten_to_single_vector, global_device,
-                                     NonLinearity, np_to_torch, pad_to_last,
-                                     prefer_gpu, product_of_gaussians,
-                                     set_gpu_mode, soft_update_model,
-                                     torch_to_np, TransposeImage,
-                                     update_module_params)
+                                     NonLinearity, output_height_2d,
+                                     output_width_2d, pad_to_last, prefer_gpu,
+                                     product_of_gaussians, set_gpu_mode,
+                                     soft_update_model, torch_to_np,
+                                     TransposeImage, update_module_params)
 
 # yapf: enable
 __all__ = [
-    'compute_advantages', 'dict_np_to_torch', 'filter_valids', 'flatten_batch',
-    'global_device', 'np_to_torch', 'pad_to_last', 'prefer_gpu',
+    'compute_advantages', 'as_torch_dict', 'filter_valids', 'flatten_batch',
+    'global_device', 'as_torch', 'pad_to_last', 'prefer_gpu',
     'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np',
     'update_module_params', 'NonLinearity', 'flatten_to_single_vector',
-    'TransposeImage'
+    'TransposeImage', 'PolicyMode', 'PolicyInput',
+    'ShuffledOptimizationNotSupported', 'output_width_2d', 'output_height_2d',
+    'expand_var'
 ]
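For downstream code, the practical effect of this file is which names can be imported from `garage.torch`. A quick sketch based on the updated `__all__`:

```python
from garage.torch import (PolicyInput, PolicyMode,
                          ShuffledOptimizationNotSupported,
                          as_torch, as_torch_dict, expand_var,
                          output_height_2d, output_width_2d)
```

Code that previously imported `np_to_torch` or `dict_np_to_torch` now needs the `as_torch` / `as_torch_dict` names, as the bc.py, ddpg.py, and dqn.py diffs below show.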

src/garage/torch/_dtypes.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+"""Data structures used in garage.torch."""
+from dataclasses import dataclass
+import enum
+
+import torch
+from torch import nn
+
+
+class ShuffledOptimizationNotSupported(ValueError):
+    """Raised by recurrent policies if they're passed a shuffled batch."""
+
+
+class PolicyMode(enum.IntEnum):
+    """Defines what mode a PolicyInput is being used in.
+
+    See :class:`PolicyInput` for detailed documentation.
+
+    """
+    # Policy is being used to run a rollout.
+    # observations contains the last observations, and all_observations
+    # contains partial episodes batched using lengths.
+    ROLLOUT = 0
+    # Policy is being used to do an optimization with timesteps from different
+    # episodes. Recurrent policies must raise
+    # ShuffledOptimizationNotSupported if they encounter this mode.
+    SHUFFLED = 1
+    # Policy is being used to do an optimization on complete episodes.
+    FULL = 2
+
+
+@dataclass
+class PolicyInput:
+    r"""The (differentiable) input to all pytorch policies.
+
+    Args:
+        mode (PolicyMode): The mode this batch is being used in. Determines the
+            shape of observations.
+        observations (torch.Tensor): A torch tensor containing flattened
+            observations in a batch. Stateless policies should always operate
+            on this input. Shape depends on the mode:
+            * If `mode == ROLLOUT`, has shape :math:`(V, O)` (where V is the
+              vectorization level).
+            * If `mode == SHUFFLED`, has shape :math:`(B, O)` (where B is the
+              mini-batch size).
+            * If mode == FULL, has shape :math:`(N \bullet [T], O)` (where N
+              is the number of episodes, and T is the episode lengths).
+        lengths (torch.Tensor or None): Integer tensor containing the lengths
+            of each episode. Only has a value if `mode == FULL`.
+
+    """
+
+    mode: PolicyMode
+    observations: torch.Tensor
+    lengths: torch.Tensor = None
+
+    def __post_init__(self):
+        """Check that lengths is consistent with the rest of the fields.
+
+        Raises:
+            ValueError: If lengths is not consistent with another field.
+
+        """
+        if self.mode == PolicyMode.FULL:
+            if self.lengths is None:
+                raise ValueError(
+                    'lengths is None, but must be a torch.Tensor when '
+                    'mode == PolicyMode.FULL')
+            assert self.lengths is not None
+            if self.lengths.dtype not in (torch.uint8, torch.int8, torch.int16,
+                                          torch.int32, torch.int64):
+                raise ValueError(
+                    f'lengths has dtype {self.lengths.dtype}, but must have '
+                    f'an integer dtype')
+            total_size = sum(self.lengths)
+            if self.observations.shape[0] != total_size:
+                raise ValueError(
+                    f'observations has batch size '
+                    f'{self.observations.shape[0]}, but must have batch '
+                    f'size {total_size} to match lengths')
+            assert self.observations.shape[0] == total_size
+        elif self.lengths is not None:
+            raise ValueError(
+                f'lengths has value {self.lengths}, but must be None '
+                f'when mode == {self.mode}')
+
+    def to_packed_sequence(self):
+        """Turn full observations into a torch.nn.utils.rnn.PackedSequence.
+
+        Raises:
+            ShuffledOptimizationNotSupported: If called when `mode != FULL`
+
+        Returns:
+            torch.nn.utils.rnn.PackedSequence: The sequence of flattened
+                observations.
+
+        """
+        if self.mode != PolicyMode.FULL:
+            raise ShuffledOptimizationNotSupported(
+                f'mode has value {self.mode} but must have mode '
+                f'{PolicyMode.FULL} to use to_packed_sequence')
+        sequence = []
+        start = 0
+        for length in self.lengths:
+            stop = start + length
+            sequence.append(self.observations[start:stop])
+            start = stop
+        pack_sequence = nn.utils.rnn.pack_sequence
+        return pack_sequence(sequence, enforce_sorted=False)
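A minimal usage sketch of the new types, based only on the code above (the shapes and lengths are made up for illustration):

```python
import torch

from garage.torch import PolicyInput, PolicyMode

# Two episodes of lengths 3 and 2 with 4-dimensional observations,
# concatenated along the batch axis, as PolicyMode.FULL expects.
obs = torch.zeros(5, 4)
lengths = torch.tensor([3, 2], dtype=torch.int64)
full_input = PolicyInput(PolicyMode.FULL, obs, lengths=lengths)

# Recurrent policies can consume the episodes as a PackedSequence.
packed = full_input.to_packed_sequence()
print(packed.batch_sizes)  # tensor([2, 2, 1])

# lengths is only valid in FULL mode; __post_init__ rejects it otherwise.
try:
    PolicyInput(PolicyMode.SHUFFLED, obs, lengths=lengths)
except ValueError as err:
    print(err)
```

Stateless policies can keep operating directly on `observations` in any mode; recurrent policies can call `to_packed_sequence()` in FULL mode and raise `ShuffledOptimizationNotSupported` when handed a SHUFFLED batch.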

src/garage/torch/_functions.py

Lines changed: 93 additions & 4 deletions
@@ -11,6 +11,8 @@
 """
 import copy
 import dataclasses
+import math
+import warnings
 
 import akro
 import torch
@@ -134,7 +136,7 @@ def filter_valids(tensor, valids):
     return [tensor[i][:valid] for i, valid in enumerate(valids)]
 
 
-def np_to_torch(array):
+def as_torch(array):
     """Numpy arrays to PyTorch tensors.
 
     Args:
@@ -144,10 +146,10 @@ def np_to_torch(array):
         torch.Tensor: float tensor on the global device.
 
     """
-    return torch.from_numpy(array).float().to(global_device())
+    return torch.as_tensor(array).float().to(global_device())
 
 
-def dict_np_to_torch(array_dict):
+def as_torch_dict(array_dict):
     """Convert a dict whose values are numpy arrays to PyTorch tensors.
 
     Modifies array_dict in place.
@@ -160,7 +162,7 @@ def dict_np_to_torch(array_dict):
 
     """
     for key, value in array_dict.items():
-        array_dict[key] = np_to_torch(value)
+        array_dict[key] = as_torch(value)
     return array_dict
 
 
@@ -401,3 +403,90 @@ def step(self, action):
         env_step = super().step(action)
         obs = env_step.observation.transpose(2, 0, 1)
         return dataclasses.replace(env_step, observation=obs)
+
+
+def output_height_2d(layer, height):
+    """Compute the output height of a torch.nn.Conv2d, assuming NCHW format.
+
+    This requires knowing the input height. Because NCHW format makes this very
+    easy to mix up, this is a separate function from output_width_2d.
+
+    It also works on torch.nn.MaxPool2d.
+
+    This function implements the formula described in the torch.nn.Conv2d
+    documentation:
+    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+
+    Args:
+        layer (torch.nn.Conv2d): The layer to compute output size for.
+        height (int): The height of the input image.
+
+    Returns:
+        int: The height of the output image.
+
+    """
+    assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d))
+    return math.floor((height + 2 * layer.padding[0] - layer.dilation[0] *
+                       (layer.kernel_size[0] - 1) - 1) / layer.stride[0] + 1)
+
+
+def output_width_2d(layer, width):
+    """Compute the output width of a torch.nn.Conv2d, assuming NCHW format.
+
+    This requires knowing the input width. Because NCHW format makes this very
+    easy to mix up, this is a separate function from output_height_2d.
+
+    It also works on torch.nn.MaxPool2d.
+
+    This function implements the formula described in the torch.nn.Conv2d
+    documentation:
+    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+
+    Args:
+        layer (torch.nn.Conv2d): The layer to compute output size for.
+        width (int): The width of the input image.
+
+    Returns:
+        int: The width of the output image.
+
+    """
+    assert isinstance(layer, (torch.nn.Conv2d, torch.nn.MaxPool2d))
+    return math.floor((width + 2 * layer.padding[1] - layer.dilation[1] *
+                       (layer.kernel_size[1] - 1) - 1) / layer.stride[1] + 1)
+
+
+def expand_var(name, item, n_expected, reference):
+    """Expand a variable to an expected length.
+
+    This is used to handle arguments to primitives that can all be reasonably
+    set to the same value, or multiple different values.
+
+    Args:
+        name (str): Name of variable being expanded.
+        item (any): Element being expanded.
+        n_expected (int): Number of elements expected.
+        reference (str): Source of n_expected.
+
+    Returns:
+        list: List of references to item or item itself.
+
+    Raises:
+        ValueError: If the variable is a sequence but length of the variable
+            is not 1 or n_expected.
+
+    """
+    if n_expected == 1:
+        warnings.warn(
+            f'Providing a {reference} of length 1 prevents {name} from '
+            f'being expanded.')
+    if isinstance(item, (list, tuple)):
+        if len(item) == n_expected:
+            return item
+        elif len(item) == 1:
+            return list(item) * n_expected
+        else:
+            raise ValueError(
+                f'{name} is length {len(item)} but should be length '
+                f'{n_expected} to match {reference}')
+    else:
+        return [item] * n_expected
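Two of the new helpers are easy to show in isolation. The sketch below assumes only the code above; the layer sizes and the `'hidden_channels'` reference string are made-up examples:

```python
from torch import nn

from garage.torch import expand_var, output_height_2d, output_width_2d

# Track the spatial size of a 32x32 input through two conv layers. Conv2d
# stores kernel_size/stride/padding/dilation as tuples, which is what
# output_height_2d and output_width_2d index into.
conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=2)
conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2)
h = output_height_2d(conv1, 32)  # floor((32 - 4 - 1) / 2 + 1) == 14
w = output_width_2d(conv1, 32)   # 14
h = output_height_2d(conv2, h)   # 6
w = output_width_2d(conv2, w)    # 6

# expand_var turns a scalar hyperparameter into one value per primitive,
# or validates a sequence that was passed explicitly.
kernel_sizes = expand_var('kernel_sizes', 3, 2, 'hidden_channels')  # [3, 3]
strides = expand_var('strides', (1, 2), 2, 'hidden_channels')       # (1, 2)
```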

src/garage/torch/algos/bc.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 from garage.np.algos.rl_algorithm import RLAlgorithm
 from garage.np.policies import Policy
 from garage.sampler import RaySampler
-from garage.torch import np_to_torch
+from garage.torch import as_torch
 
 # yapf: enable
 
@@ -126,8 +126,8 @@ def _train_once(self, trainer, epoch):
         minibatches = np.array_split(indices, self._minibatches_per_epoch)
         losses = []
         for minibatch in minibatches:
-            observations = np_to_torch(batch.observations[minibatch])
-            actions = np_to_torch(batch.actions[minibatch])
+            observations = as_torch(batch.observations[minibatch])
+            actions = as_torch(batch.actions[minibatch])
             self._optimizer.zero_grad()
             loss = self._compute_loss(observations, actions)
             loss.backward()
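The rename is not purely cosmetic: `as_torch` now goes through `torch.as_tensor` rather than `torch.from_numpy`, so it accepts existing tensors and plain Python sequences as well as numpy arrays, and still returns a float tensor on the global device. A small sketch (the batch contents are made up):

```python
import numpy as np
import torch

from garage.torch import as_torch, as_torch_dict

a = as_torch(np.arange(4))   # tensor([0., 1., 2., 3.])
b = as_torch(torch.ones(3))  # already a tensor; from_numpy would have failed

# as_torch_dict converts every value of a transition dict in place.
batch = as_torch_dict({'observations': np.zeros((2, 4)),
                       'rewards': np.array([1.0, 0.5])})
```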

src/garage/torch/algos/ddpg.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
                     obtain_evaluation_episodes)
 from garage.np.algos import RLAlgorithm
 from garage.sampler import FragmentWorker, LocalSampler
-from garage.torch import dict_np_to_torch, torch_to_np
+from garage.torch import as_torch_dict, torch_to_np
 
 # yapf: enable
 
@@ -229,7 +229,7 @@ def optimize_policy(self, samples_data):
             qval: Q-value predicted by the Q-network.
 
         """
-        transitions = dict_np_to_torch(samples_data)
+        transitions = as_torch_dict(samples_data)
 
         observations = transitions['observations']
         rewards = transitions['rewards'].reshape(-1, 1)

src/garage/torch/algos/dqn.py

Lines changed: 6 additions & 6 deletions
@@ -11,7 +11,7 @@
 from garage._functions import obtain_evaluation_episodes
 from garage.np.algos import RLAlgorithm
 from garage.sampler import FragmentWorker
-from garage.torch import global_device, np_to_torch
+from garage.torch import as_torch, global_device
 
 
 class DQN(RLAlgorithm):
@@ -240,12 +240,12 @@ def _optimize_qf(self, timesteps):
             qval: Q-value predicted by the Q-network.
 
         """
-        observations = np_to_torch(timesteps.observations)
-        rewards = np_to_torch(timesteps.rewards).reshape(-1, 1)
+        observations = as_torch(timesteps.observations)
+        rewards = as_torch(timesteps.rewards).reshape(-1, 1)
         rewards *= self._reward_scale
-        actions = np_to_torch(timesteps.actions)
-        next_observations = np_to_torch(timesteps.next_observations)
-        terminals = np_to_torch(timesteps.terminals).reshape(-1, 1)
+        actions = as_torch(timesteps.actions)
+        next_observations = as_torch(timesteps.next_observations)
+        terminals = as_torch(timesteps.terminals).reshape(-1, 1)
 
         next_inputs = next_observations
         inputs = observations
