Pytorch Categorical GRU Policy #2196

Open · wants to merge 3 commits into base: master
Changes from 2 commits
4 changes: 4 additions & 0 deletions src/garage/torch/modules/__init__.py
@@ -14,6 +14,8 @@
from garage.torch.modules.discrete_cnn_module import DiscreteCNNModule
from garage.torch.modules.discrete_dueling_cnn_module import (
DiscreteDuelingCNNModule)
from garage.torch.modules.gru_module import GRUModule
from garage.torch.modules.categorical_gru_module import CategoricalGRUModule
# yapf: enable

__all__ = [
@@ -26,4 +28,6 @@
'GaussianMLPModule',
'GaussianMLPIndependentStdModule',
'GaussianMLPTwoHeadedModule',
'GRUModule',
'CategoricalGRUModule',
]
93 changes: 93 additions & 0 deletions src/garage/torch/modules/categorical_gru_module.py
@@ -0,0 +1,93 @@
"""Categorical GRU Module.

A model represented by a Categorical distribution
which is parameterized by a Gated Recurrent Unit (GRU)
followed by a multilayer perceptron (MLP).
"""
from torch import nn
from torch.distributions import Categorical

from garage.torch import NonLinearity
from garage.torch.modules.gru_module import GRUModule

class CategoricalGRUModule(nn.Module):
"""Categorical GRU Model.
A model represented by a Categorical distribution
which is parameterized by a gated recurrent unit (GRU)
followed by a fully-connected layer.

Args:
input_dim (int): Dimension of the network input.
output_dim (int): Dimension of the network output.
hidden_dim (int): Hidden dimension for GRU cell.
hidden_nonlinearity (callable): Activation function for intermediate
dense layer(s). It should return a torch.Tensor. Set it to
None to maintain a linear activation.
hidden_w_init (callable): Initializer function for the weight
of intermediate dense layer(s). The function should return a
torch.Tensor.
hidden_b_init (callable): Initializer function for the bias
of intermediate dense layer(s). The function should return a
torch.Tensor.
output_nonlinearity (callable): Activation function for output dense
layer. It should return a torch.Tensor. Set it to None to
maintain a linear activation.
output_w_init (callable): Initializer function for the weight
of output dense layer(s). The function should return a
torch.Tensor.
output_b_init (callable): Initializer function for the bias
of output dense layer(s). The function should return a
torch.Tensor.
layer_normalization (bool): Bool for using layer normalization or not.
"""

def __init__(
self,
input_dim,
output_dim,
hidden_dim,
hidden_nonlinearity=nn.Tanh,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
output_nonlinearity=None,
output_w_init=nn.init.xavier_uniform_,
output_b_init=nn.init.zeros_,
layer_normalization=False,
):
super().__init__()

self._gru_module = GRUModule(
input_dim,
hidden_dim,
hidden_nonlinearity,
hidden_w_init,
hidden_b_init,
layer_normalization,
)

        self._linear_layer = nn.Sequential()
        output_layer = nn.Linear(hidden_dim, output_dim)
        output_w_init(output_layer.weight)
        output_b_init(output_layer.bias)
        self._linear_layer.add_module("output", output_layer)
        if output_nonlinearity:
            self._linear_layer.add_module(
                "output_activation", NonLinearity(output_nonlinearity))

def forward(self, *inputs):
"""Forward method.

Args:
*inputs: Input to the module.

Returns:
torch.distributions.Categorical: Policy distribution.

"""
        assert len(inputs) == 1
        gru_output = self._gru_module(inputs[0])
        fc_output = self._linear_layer(gru_output)
        # Prepend a singleton dimension to the logits before building
        # the distribution.
        dist = Categorical(logits=fc_output.unsqueeze(0))
        return dist
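If it helps review, here is a minimal usage sketch of the new module (all sizes are made up for illustration):

```python
import torch

from garage.torch.modules import CategoricalGRUModule

# Hypothetical sizes: 4-dim inputs, 2 actions, 8 hidden units.
module = CategoricalGRUModule(input_dim=4, output_dim=2, hidden_dim=8)

# A batch of 3 sequences, each 5 steps long: (batch, time, input_dim).
obs = torch.rand(3, 5, 4)
dist = module(obs)                  # torch.distributions.Categorical
actions = dist.sample()            # sampled action indices
log_probs = dist.log_prob(actions)  # log-likelihoods of the samples
```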
79 changes: 79 additions & 0 deletions src/garage/torch/modules/gru_module.py
@@ -0,0 +1,79 @@
"""GRU Module."""
import copy

import torch
from torch import nn
from torch.autograd import Variable

from garage.experiment import deterministic
from garage.torch import global_device, NonLinearity


# pytorch v1.6 issue, see https://github.com/pytorch/pytorch/issues/42305
# pylint: disable=abstract-method
# pylint: disable=unused-argument
class GRUModule(nn.Module):
"""Gated Recurrent Unit (GRU) model in pytorch.

Args:
input_dim (int): Dimension of the network input.
hidden_dim (int): Hidden dimension for GRU cell.
hidden_nonlinearity (callable): Activation function for intermediate
dense layer(s). It should return a torch.Tensor. Set it to
None to maintain a linear activation.
hidden_w_init (callable): Initializer function for the weight
of intermediate dense layer(s). The function should return a
torch.Tensor.
hidden_b_init (callable): Initializer function for the bias
of intermediate dense layer(s). The function should return a
torch.Tensor.
layer_normalization (bool): Bool for using layer normalization or not.
"""

def __init__(
self,
input_dim,
hidden_dim,
hidden_nonlinearity=nn.Tanh,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
layer_normalization=False,
):
super().__init__()
self._layers = nn.Sequential()
self.hidden_dim = hidden_dim
self._gru_cell = nn.GRUCell(input_dim, hidden_dim)
hidden_w_init(self._gru_cell.weight_ih)
hidden_w_init(self._gru_cell.weight_hh)
hidden_b_init(self._gru_cell.bias_ih)
hidden_b_init(self._gru_cell.bias_hh)
self.hidden_nonlinearity = NonLinearity(hidden_nonlinearity)

self._layers.add_module("activation", self.hidden_nonlinearity)
        if layer_normalization:
            self._layers.add_module("layer_normalization",
                                    nn.LayerNorm(hidden_dim))

# pylint: disable=arguments-differ
    def forward(self, input_val):
        """Forward method.

        Args:
            input_val (torch.Tensor): Input values with (N, T, input_dim)
                shape, or (T, input_dim) for a single sequence.

        Returns:
            torch.Tensor: Activated hidden state at the final step, with
                (N, hidden_dim) shape.

        """
        if len(input_val.size()) == 2:
            input_val = input_val.unsqueeze(0)
        # Initialize the hidden state to zeros and unroll the GRU cell
        # over the time dimension, keeping only the final hidden state.
        hn = torch.zeros(input_val.size(0),
                         self.hidden_dim,
                         device=global_device())
        for seq in range(input_val.size(1)):
            hn = self._gru_cell(input_val[:, seq, :], hn)
        # Apply the activation (and optional layer normalization) to the
        # final hidden state.
        return self._layers(hn)
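For reference, a short sketch of how the module behaves — it unrolls a single `nn.GRUCell` over the time axis and returns only the activated hidden state of the final step (sizes here are illustrative):

```python
import torch

from garage.torch.modules import GRUModule

gru = GRUModule(input_dim=4, hidden_dim=8)
x = torch.rand(3, 5, 4)          # (batch, time, input_dim)
out = gru(x)                     # final hidden state, shape (3, 8)

# A 2-D input is treated as one sequence of shape (time, input_dim).
single = gru(torch.rand(5, 4))   # shape (1, 8)
```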
2 changes: 2 additions & 0 deletions src/garage/torch/policies/__init__.py
@@ -11,6 +11,7 @@
from garage.torch.policies.policy import Policy
from garage.torch.policies.tanh_gaussian_mlp_policy import (
TanhGaussianMLPPolicy)
from garage.torch.policies.categorical_gru_policy import CategoricalGRUPolicy

__all__ = [
'CategoricalCNNPolicy',
@@ -21,4 +22,5 @@
'Policy',
'TanhGaussianMLPPolicy',
'ContextConditionedPolicy',
'CategoricalGRUPolicy',
]
177 changes: 177 additions & 0 deletions src/garage/torch/policies/categorical_gru_policy.py
@@ -0,0 +1,177 @@
"""CategoricalGRUPolicy."""
import akro
import numpy as np
from torch import nn

from garage.torch.modules import CategoricalGRUModule
from garage.torch.policies.stochastic_policy import StochasticPolicy


class CategoricalGRUPolicy(StochasticPolicy):
"""CategoricalGRUPolicy.

    A policy that contains a GRU and an MLP to make predictions based
    on a categorical distribution.

    It only works with an akro.Discrete action space.

Args:
env_spec (EnvSpec): Environment specification.
hidden_dim (int): Hidden dimension for GRU cell.
hidden_nonlinearity (callable): Activation function for intermediate
dense layer(s). It should return a torch.Tensor. Set it to
None to maintain a linear activation.
hidden_w_init (callable): Initializer function for the weight
of intermediate dense layer(s). The function should return a
torch.Tensor.
hidden_b_init (callable): Initializer function for the bias
of intermediate dense layer(s). The function should return a
torch.Tensor.
output_nonlinearity (callable): Activation function for output dense
layer. It should return a torch.Tensor. Set it to None to
maintain a linear activation.
output_w_init (callable): Initializer function for the weight
of output dense layer(s). The function should return a
torch.Tensor.
output_b_init (callable): Initializer function for the bias
of output dense layer(s). The function should return a
torch.Tensor.
state_include_action (bool): Whether the state includes action.
If True, input dimension will be
(observation dimension + action dimension).
layer_normalization (bool): Bool for using layer normalization or not.
name (str): Name of policy.
"""

def __init__(
self,
env_spec,
hidden_dim=32,
hidden_nonlinearity=nn.Tanh,
hidden_w_init=nn.init.xavier_uniform_,
hidden_b_init=nn.init.zeros_,
output_nonlinearity=None,
output_w_init=nn.init.xavier_uniform_,
output_b_init=nn.init.zeros_,
state_include_action=True,
layer_normalization=False,
name="CategoricalGRUPolicy",
):
if not isinstance(env_spec.action_space, akro.Discrete):
            raise ValueError('CategoricalGRUPolicy only works '
                             'with akro.Discrete action space.')

super().__init__(env_spec, name)
self._env_spec = env_spec
self._obs_dim = env_spec.observation_space.flat_dim
self._action_dim = env_spec.action_space.n

self._hidden_dim = hidden_dim
self._hidden_nonlinearity = hidden_nonlinearity
self._hidden_w_init = hidden_w_init
self._hidden_b_init = hidden_b_init
self._output_nonlinearity = output_nonlinearity
self._output_w_init = output_w_init
self._output_b_init = output_b_init
self._layer_normalization = layer_normalization
self._state_include_action = state_include_action

if state_include_action:
self._input_dim = self._obs_dim + self._action_dim
else:
self._input_dim = self._obs_dim

        self._prev_actions = None
        self._prev_hiddens = None

self._module = CategoricalGRUModule(
input_dim=self._input_dim,
output_dim=self._action_dim,
hidden_dim=self._hidden_dim,
hidden_nonlinearity=self._hidden_nonlinearity,
hidden_w_init=self._hidden_w_init,
hidden_b_init=self._hidden_b_init,
output_nonlinearity=self._output_nonlinearity,
output_w_init=self._output_w_init,
output_b_init=self._output_b_init,
layer_normalization=self._layer_normalization,
)

def forward(self, observations):
"""Compute the action distributions from the observations.

Args:
observations (torch.Tensor): Batch of observations on default
torch device.

Returns:
torch.distributions.Distribution: Batch distribution of actions.
dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
Do not need to be detached, and can be on any device.
"""
dist = self._module(observations)
return dist, {}

def reset(self, do_resets=None):
"""Reset the policy.

Note:
            If `do_resets` is None, it defaults to np.array([True]),
            which implies the policy is not "vectorized", i.e. the
            number of parallel environments for training data sampling
            is 1.

Args:
do_resets (numpy.ndarray): Bool that indicates terminal state(s).

"""
if do_resets is None:
do_resets = [True]
do_resets = np.asarray(do_resets)
if self._prev_actions is None or len(do_resets) != len(
self._prev_actions):
self._prev_actions = np.zeros(
(len(do_resets), self.action_space.flat_dim))
self._prev_hiddens = np.zeros((len(do_resets), self._hidden_dim))

self._prev_actions[do_resets] = 0.

def get_actions(self, observations):
"""Return multiple actions.

Args:
observations (numpy.ndarray): Observations.

Returns:
            numpy.ndarray: Actions given input observations.
dict(numpy.ndarray): Distribution parameters.

"""
if self._state_include_action:
assert self._prev_actions is not None
all_input = np.concatenate([observations, self._prev_actions],
axis=-1)
else:
all_input = observations
prev_actions = self._prev_actions
actions, agent_info = super().get_actions(all_input)
        self._prev_actions = self.action_space.flatten_n(
            [a.item() for a in actions])
if self._state_include_action:
agent_info['prev_action'] = np.copy(prev_actions)
return actions, agent_info

@property
def input_dim(self):
"""int: Dimension of the policy input."""
return self._input_dim

@property
def env_spec(self):
"""Policy environment specification.

Returns:
garage.EnvSpec: Environment specification.

"""
return self._env_spec
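And a hedged sketch of driving the policy in a sampling loop; `env_spec` is assumed to come from an environment with an akro.Discrete action space and a flat Box observation space, and the observation is just a stand-in:

```python
import numpy as np

from garage.torch.policies import CategoricalGRUPolicy

policy = CategoricalGRUPolicy(env_spec, hidden_dim=32)
policy.reset()  # sizes the previous-action buffer for one environment

obs = env_spec.observation_space.sample()  # stand-in observation
actions, agent_info = policy.get_actions(obs[np.newaxis])
# With state_include_action=True (the default), agent_info['prev_action']
# holds the one-hot previous action that was fed to the GRU.
```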

