
Commit 017274f

WIP torch optimizer refactor
1 parent 92646ed commit 017274f

17 files changed (+189, -228 lines)


benchmarks/src/garage_benchmarks/experiments/algos/ppo_garage_pytorch.py

Lines changed: 9 additions & 9 deletions
@@ -6,7 +6,7 @@
 from garage.experiment import deterministic
 from garage.sampler import RaySampler
 from garage.torch.algos import PPO as PyTorch_PPO
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
 from garage.torch.policies import GaussianMLPPolicy as PyTorch_GMP
 from garage.torch.value_functions import GaussianMLPValueFunction
 from garage.trainer import Trainer
@@ -45,15 +45,15 @@ def ppo_garage_pytorch(ctxt, env_id, seed):
                          hidden_nonlinearity=torch.tanh,
                          output_nonlinearity=None)
 
-    policy_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-                                        policy,
-                                        max_optimization_epochs=10,
-                                        minibatch_size=64)
+    policy_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+                                          policy,
+                                          max_optimization_epochs=10,
+                                          minibatch_size=64)
 
-    vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=2.5e-4)),
-                                    value_function,
-                                    max_optimization_epochs=10,
-                                    minibatch_size=64)
+    vf_optimizer = MinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)),
+                                      value_function,
+                                      max_optimization_epochs=10,
+                                      minibatch_size=64)
 
     sampler = RaySampler(agents=policy,
                          envs=env,

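Note: the change above is purely the rename from OptimizerWrapper to MinibatchOptimizer; the constructor still takes an (optimizer class, kwargs) tuple, the module to optimize, and the minibatching settings. A rough sketch of what that tuple convention amounts to (an illustration only, not the garage implementation):

import torch

def make_inner_optimizer(optimizer_spec, module):
    """Illustrative helper: unpack an (optimizer_cls, kwargs) spec."""
    optimizer_cls, optimizer_kwargs = optimizer_spec
    # e.g. torch.optim.Adam(module.parameters(), lr=2.5e-4)
    return optimizer_cls(module.parameters(), **optimizer_kwargs)

adam = make_inner_optimizer((torch.optim.Adam, dict(lr=2.5e-4)),
                            torch.nn.Linear(4, 2))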
src/garage/torch/__init__.py

Lines changed: 18 additions & 17 deletions
@@ -1,23 +1,24 @@
 """PyTorch-backed modules and algorithms."""
 # yapf: disable
-from garage.torch._dtypes import (ObservationBatch, ObservationOrder,
-                                  ShuffledOptimizationNotSupported,
-                                  observation_batch_to_packed_sequence)
-from garage.torch._functions import (compute_advantages, dict_np_to_torch,
-                                     filter_valids, flatten_batch,
-                                     flatten_to_single_vector, global_device,
-                                     NonLinearity, np_to_torch, pad_to_last,
-                                     prefer_gpu, product_of_gaussians,
-                                     set_gpu_mode, soft_update_model,
-                                     torch_to_np, TransposeImage,
-                                     update_module_params)
+from garage.torch._dtypes import (observation_batch_to_packed_sequence,
+                                  ObservationBatch, ObservationOrder,
+                                  ShuffledOptimizationNotSupported)
+from garage.torch._functions import (as_tensor, compute_advantages,
+                                     dict_np_to_torch, filter_valids,
+                                     flatten_batch, flatten_to_single_vector,
+                                     global_device, NonLinearity, np_to_torch,
+                                     pad_to_last, prefer_gpu,
+                                     product_of_gaussians, set_gpu_mode,
+                                     soft_update_model, torch_to_np,
+                                     TransposeImage, update_module_params)
 
 # yapf: enable
 __all__ = [
-    'compute_advantages', 'dict_np_to_torch', 'filter_valids', 'flatten_batch',
-    'global_device', 'np_to_torch', 'pad_to_last', 'prefer_gpu',
-    'product_of_gaussians', 'set_gpu_mode', 'soft_update_model', 'torch_to_np',
-    'update_module_params', 'NonLinearity', 'flatten_to_single_vector',
-    'TransposeImage', 'ObservationBatch', 'ObservationOrder',
-    'ShuffledOptimizationNotSupported', 'observation_batch_to_packed_sequence'
+    'as_tensor', 'compute_advantages', 'dict_np_to_torch', 'filter_valids',
+    'flatten_batch', 'global_device', 'np_to_torch', 'pad_to_last',
+    'prefer_gpu', 'product_of_gaussians', 'set_gpu_mode', 'soft_update_model',
+    'torch_to_np', 'update_module_params', 'NonLinearity',
+    'flatten_to_single_vector', 'TransposeImage', 'ObservationBatch',
+    'ObservationOrder', 'ShuffledOptimizationNotSupported',
+    'observation_batch_to_packed_sequence'
 ]

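Note: besides re-sorting the imports, this hunk adds the new as_tensor helper (defined in _functions.py below) to the package's public API. A minimal usage sketch, assuming the device has been configured with set_gpu_mode:

from garage.torch import as_tensor, set_gpu_mode

set_gpu_mode(False)  # keep global_device() on the CPU for this sketch
returns = as_tensor([1.0, 0.99, 0.98])  # float32 tensor on global_device()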
src/garage/torch/_dtypes.py

Lines changed: 2 additions & 2 deletions
@@ -79,12 +79,12 @@ def __new__(cls, observations, order, lengths=None):
                     f'lengths has dtype {self.lengths.dtype}, but must have '
                     f'an integer dtype')
             total_size = sum(self.lengths)
-            if self.observations.shape[0] != total_size:
+            if self.shape[0] != total_size:
                 raise ValueError(
                     f'observations has batch size '
                     f'{self.observations.shape[0]}, but must have batch '
                     f'size {total_size} to match lengths')
-            assert self.observations.shape[0] == total_size
+            assert self.shape[0] == total_size
         elif self.lengths is not None:
             raise ValueError(
                 f'lengths has value {self.lengths}, but must be None '

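Note: the check now reads self.shape[0] instead of self.observations.shape[0], which is consistent with ObservationBatch being a tensor subclass that wraps the observations directly. The invariant enforced is that the batch dimension equals the total of the per-episode lengths; a standalone sketch of that check (not the garage class):

import torch

def check_batch_matches_lengths(observations, lengths):
    """Raise if the batch dimension disagrees with sum(lengths)."""
    total_size = int(sum(lengths))
    if observations.shape[0] != total_size:
        raise ValueError(f'observations has batch size '
                         f'{observations.shape[0]}, but must have batch '
                         f'size {total_size} to match lengths')

check_batch_matches_lengths(torch.zeros(5, 3), torch.tensor([2, 3]))  # passes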
src/garage/torch/_functions.py

Lines changed: 19 additions & 6 deletions
@@ -92,13 +92,13 @@ def discount_cumsum(x, discount):
                               discount,
                               dtype=torch.float,
                               device=x.device)
+    discount_x[0] = 1.0
     filter = torch.cumprod(discount_x, dim=0)
-    returns = F.conv1d(x, filter, stride=1)
-    assert returns.shape == (len(x), )
-    from garage.np import discount_cumsum as np_discout_cumsum
-    import numpy as np
-    expected = np_discout_cumsum(torch_to_np(x), discount)
-    assert np.array_equal(expected, torch_to_np(returns))
+    pad = len(x) - 1
+    # minibatch of 1, with 1 channel
+    filter = filter.reshape(1, 1, -1)
+    returns = F.conv1d(x.reshape(1, 1, -1), filter, stride=1, padding=pad)
+    returns = returns[0, 0, pad:]
     return returns
 
 
@@ -372,6 +372,19 @@ def product_of_gaussians(mus, sigmas_squared):
     return mu, sigma_squared
 
 
+def as_tensor(data):
+    """Convert a list to a PyTorch tensor
+
+    Args:
+        data (list): Data to convert to tensor
+
+    Returns:
+        torch.Tensor: A float tensor
+
+    """
+    return torch.as_tensor(data, dtype=torch.float32, device=global_device())
+
+
 # pylint: disable=W0223
 class NonLinearity(nn.Module):
     """Wrapper class for non linear function or module.

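Note: the rewritten discount_cumsum drops the in-function NumPy cross-check, forces the first filter entry to 1 so the filter becomes [1, discount, discount**2, ...], reshapes signal and filter to conv1d's (batch, channel, length) layout, and left-pads by len(x) - 1 so that, after slicing off the padded prefix, position i holds sum_k discount**k * x[i + k]. A self-contained sketch of the same idea (an illustration, not the garage function), checked against hand-computed values:

import torch
import torch.nn.functional as F

def discounted_cumsum_conv(x, discount):
    """Discounted cumulative sum of a 1-D tensor via conv1d."""
    n = len(x)
    # Filter [1, discount, discount**2, ...]: cumprod with a forced leading 1.
    discount_x = torch.full((n, ), discount, dtype=x.dtype, device=x.device)
    discount_x[0] = 1.0
    filt = torch.cumprod(discount_x, dim=0).reshape(1, 1, -1)
    pad = n - 1
    # conv1d is a cross-correlation, so with left padding of n - 1 the
    # output at offset pad + i is sum_k discount**k * x[i + k].
    out = F.conv1d(x.reshape(1, 1, -1), filt, stride=1, padding=pad)
    return out[0, 0, pad:]

x = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0 + 0.5 * 2.0 + 0.25 * 3.0, 2.0 + 0.5 * 3.0, 3.0])
assert torch.allclose(discounted_cumsum_conv(x, 0.5), expected)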
src/garage/torch/algos/maml_ppo.py

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 from garage import _Default
 from garage.torch.algos import PPO
 from garage.torch.algos.maml import MAML
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
 
 
 class MAMLPPO(MAML):
@@ -70,10 +70,10 @@ def __init__(self,
                  meta_evaluator=None,
                  evaluate_every_n_epochs=1):
 
-        policy_optimizer = OptimizerWrapper(
+        policy_optimizer = MinibatchOptimizer(
             (torch.optim.Adam, dict(lr=inner_lr)), policy)
-        vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-                                        value_function)
+        vf_optimizer = MinibatchOptimizer(
+            (torch.optim.Adam, dict(lr=inner_lr)), value_function)
 
         inner_algo = PPO(env.spec,
                          policy,

src/garage/torch/algos/maml_trpo.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 from garage.torch.algos import VPG
 from garage.torch.algos.maml import MAML
 from garage.torch.optimizers import (ConjugateGradientOptimizer,
-                                     OptimizerWrapper)
+                                     MinibatchOptimizer)
 
 
 class MAMLTRPO(MAML):
@@ -71,10 +71,10 @@ def __init__(self,
                  meta_evaluator=None,
                  evaluate_every_n_epochs=1):
 
-        policy_optimizer = OptimizerWrapper(
+        policy_optimizer = MinibatchOptimizer(
             (torch.optim.Adam, dict(lr=inner_lr)), policy)
-        vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-                                        value_function)
+        vf_optimizer = MinibatchOptimizer(
+            (torch.optim.Adam, dict(lr=inner_lr)), value_function)
 
         inner_algo = VPG(env.spec,
                          policy,

src/garage/torch/algos/maml_vpg.py

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 from garage import _Default
 from garage.torch.algos import VPG
 from garage.torch.algos.maml import MAML
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
 
 
 class MAMLVPG(MAML):
@@ -66,10 +66,10 @@ def __init__(self,
                  num_grad_updates=1,
                  meta_evaluator=None,
                  evaluate_every_n_epochs=1):
-        policy_optimizer = OptimizerWrapper(
+        policy_optimizer = MinibatchOptimizer(
             (torch.optim.Adam, dict(lr=inner_lr)), policy)
-        vf_optimizer = OptimizerWrapper((torch.optim.Adam, dict(lr=inner_lr)),
-                                        value_function)
+        vf_optimizer = MinibatchOptimizer(
+            (torch.optim.Adam, dict(lr=inner_lr)), value_function)
 
         inner_algo = VPG(env.spec,
                          policy,

src/garage/torch/algos/ppo.py

Lines changed: 5 additions & 5 deletions
@@ -2,7 +2,7 @@
 import torch
 
 from garage.torch.algos import VPG
-from garage.torch.optimizers import OptimizerWrapper
+from garage.torch.optimizers import MinibatchOptimizer
 
 
 class PPO(VPG):
@@ -14,9 +14,9 @@ class PPO(VPG):
         value_function (garage.torch.value_functions.ValueFunction): The value
             function.
         sampler (garage.sampler.Sampler): Sampler.
-        policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
+        policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
            for policy.
-        vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
+        vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
            value function.
         lr_clip_range (float): The limit on the likelihood ratio between
            policies.
@@ -63,13 +63,13 @@ def __init__(self,
                  entropy_method='no_entropy'):
 
         if policy_optimizer is None:
-            policy_optimizer = OptimizerWrapper(
+            policy_optimizer = MinibatchOptimizer(
                 (torch.optim.Adam, dict(lr=2.5e-4)),
                 policy,
                 max_optimization_epochs=10,
                 minibatch_size=64)
         if vf_optimizer is None:
-            vf_optimizer = OptimizerWrapper(
+            vf_optimizer = MinibatchOptimizer(
                 (torch.optim.Adam, dict(lr=2.5e-4)),
                 value_function,
                 max_optimization_epochs=10,

src/garage/torch/algos/trpo.py

Lines changed: 18 additions & 15 deletions
@@ -3,7 +3,7 @@
 
 from garage.torch.algos import VPG
 from garage.torch.optimizers import (ConjugateGradientOptimizer,
-                                     OptimizerWrapper)
+                                     MinibatchOptimizer)
 
 
 class TRPO(VPG):
@@ -15,9 +15,9 @@ class TRPO(VPG):
         value_function (garage.torch.value_functions.ValueFunction): The value
             function.
         sampler (garage.sampler.Sampler): Sampler.
-        policy_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer
+        policy_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer
            for policy.
-        vf_optimizer (garage.torch.optimizer.OptimizerWrapper): Optimizer for
+        vf_optimizer (garage.torch.optimizer.MinibatchOptimizer): Optimizer for
            value function.
         num_train_per_epoch (int): Number of train_once calls per epoch.
         discount (float): Discount.
@@ -61,11 +61,11 @@ def __init__(self,
                  entropy_method='no_entropy'):
 
         if policy_optimizer is None:
-            policy_optimizer = OptimizerWrapper(
+            policy_optimizer = MinibatchOptimizer(
                 (ConjugateGradientOptimizer, dict(max_constraint_value=0.01)),
                 policy)
         if vf_optimizer is None:
-            vf_optimizer = OptimizerWrapper(
+            vf_optimizer = MinibatchOptimizer(
                 (torch.optim.Adam, dict(lr=2.5e-4)),
                 value_function,
                 max_optimization_epochs=10,
@@ -116,7 +116,8 @@ def _compute_objective(self, advantages, obs, actions, rewards):
 
         return surrogate
 
-    def _train_policy(self, obs, actions, rewards, advantages):
+    def _train_policy(self, observations, actions, rewards, advantages,
+                      lengths):
         r"""Train the policy.
 
         Args:
@@ -128,17 +129,19 @@ def _train_policy(self, obs, actions, rewards, advantages):
                 with shape :math:`(N, )`.
             advantages (torch.Tensor): Advantage value at each step
                 with shape :math:`(N, )`.
+            lengths (torch.Tensor): Lengths of episodes.
 
         Returns:
             torch.Tensor: Calculated mean scalar value of policy loss (float).
 
         """
-        self._policy_optimizer.zero_grad()
-        loss = self._compute_loss_with_adv(obs, actions, rewards, advantages)
-        loss.backward()
-        self._policy_optimizer.step(
-            f_loss=lambda: self._compute_loss_with_adv(obs, actions, rewards,
-                                                       advantages),
-            f_constraint=lambda: self._compute_kl_constraint(obs))
-
-        return loss
+        data = {
+            'observations': observations,
+            'actions': actions,
+            'rewards': rewards,
+            'advantages': advantages,
+            'lengths': lengths
+        }
+        f_constraint = lambda: self._compute_kl_constraint(observations)
+        return self._policy_optimizer.step(data, self._loss_function,
+                                           f_constraint)

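Note: with this refactor TRPO._train_policy no longer drives zero_grad/backward/step itself; it packs the named tensors into a dict and hands it, together with the algorithm's loss function and a KL-constraint callable, to the optimizer's step(). A rough stand-in for that calling convention, under the assumption that the optimizer forwards the named tensors to the loss function (the real MinibatchOptimizer also handles minibatching, multiple epochs, and constraint-aware optimizers such as ConjugateGradientOptimizer):

import torch

class SketchMinibatchOptimizer:
    """Illustrative stand-in for the step(data, loss_function, constraint)
    pattern used above; not the garage implementation."""

    def __init__(self, optimizer_spec, module):
        optimizer_cls, optimizer_kwargs = optimizer_spec
        self._inner = optimizer_cls(module.parameters(), **optimizer_kwargs)

    def step(self, data, loss_function, constraint_function=None):
        # Single full-batch update; constraint_function is unused here,
        # whereas a constrained optimizer would evaluate it during the step.
        self._inner.zero_grad()
        loss = loss_function(**data)
        loss.backward()
        self._inner.step()
        return loss.detach()

policy = torch.nn.Linear(4, 2)
optimizer = SketchMinibatchOptimizer((torch.optim.Adam, dict(lr=2.5e-4)), policy)

def loss_function(observations, actions, rewards, advantages, lengths):
    # Placeholder surrogate loss purely for the sketch.
    return -(policy(observations).sum() * advantages.mean())

data = dict(observations=torch.randn(6, 4), actions=torch.randn(6, 2),
            rewards=torch.randn(6), advantages=torch.randn(6),
            lengths=torch.tensor([3, 3]))
optimizer.step(data, loss_function)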