-
Hi @kifuman

I know that there are currently no examples of hyperparameter tuning in the library documentation. It is a pending topic (which I have had in mind for a long long long time 😅🙈) for future releases.
-
Thanks for your quick reply! Is there any chance you can help me out with a small example for DDPG in a gymnasium environment?
-
Hi @kifuman

Attached (and reproduced below) is an example of hyperparameter optimization with Optuna for DDPG in a gymnasium environment. Note that it is necessary to use skrl-v1.1.0.

The script will generate a database file that can be loaded with the Optuna Dashboard for visualization (check the Optuna documentation for information about the Optuna Dashboard) as follows:

optuna-dashboard sqlite:///hyperparameter_optimization.db

Feel free to continue the discussion if you have any questions.

hyperparameter_optimization.zip

import optuna
import logging
import numpy as np
# disable skrl logging
from skrl import logger
logger.setLevel(logging.WARNING)
def objective(trial: optuna.Trial):
    # parameters to optimize
    # https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
    learning_rate = trial.suggest_float("learning_rate", low=1e-5, high=1e-2, log=True)
    discount_factor = trial.suggest_categorical("discount_factor", [0.98, 0.99, 0.999])

    # metrics
    episode_rewards = []
    instantaneous_rewards = []

    # reinforcement learning experiment
    # ---------------------------------
    import gymnasium as gym

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # import the skrl components to build the RL system
    from skrl.agents.torch.ddpg import DDPG, DDPG_DEFAULT_CONFIG
    from skrl.envs.wrappers.torch import wrap_env
    from skrl.memories.torch import RandomMemory
    from skrl.models.torch import DeterministicMixin, Model
    from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise
    from skrl.trainers.torch import StepTrainer
    from skrl.utils import set_seed

    # define models (deterministic models) using mixins
    class Actor(DeterministicMixin, Model):
        def __init__(self, observation_space, action_space, device, clip_actions=False):
            Model.__init__(self, observation_space, action_space, device)
            DeterministicMixin.__init__(self, clip_actions)

            self.linear_layer_1 = nn.Linear(self.num_observations, 400)
            self.linear_layer_2 = nn.Linear(400, 300)
            self.action_layer = nn.Linear(300, self.num_actions)

        def compute(self, inputs, role):
            x = F.relu(self.linear_layer_1(inputs["states"]))
            x = F.relu(self.linear_layer_2(x))
            # Pendulum-v1 action_space is -2 to 2
            return 2 * torch.tanh(self.action_layer(x)), {}

    class Critic(DeterministicMixin, Model):
        def __init__(self, observation_space, action_space, device, clip_actions=False):
            Model.__init__(self, observation_space, action_space, device)
            DeterministicMixin.__init__(self, clip_actions)

            self.linear_layer_1 = nn.Linear(self.num_observations + self.num_actions, 400)
            self.linear_layer_2 = nn.Linear(400, 300)
            self.linear_layer_3 = nn.Linear(300, 1)

        def compute(self, inputs, role):
            x = F.relu(self.linear_layer_1(torch.cat([inputs["states"], inputs["taken_actions"]], dim=1)))
            x = F.relu(self.linear_layer_2(x))
            return self.linear_layer_3(x), {}

    # seed for reproducibility
    set_seed()  # e.g. `set_seed(42)` for fixed seed

    # load and wrap the gymnasium environment.
    # note: the environment version may change depending on the gymnasium version
    try:
        env = gym.make("Pendulum-v1")
    except (gym.error.DeprecatedEnv, gym.error.VersionNotFound) as e:
        env_id = [spec for spec in gym.envs.registry if spec.startswith("Pendulum-v")][0]
        print("Pendulum-v1 not found. Trying {}".format(env_id))
        env = gym.make(env_id)
    env = wrap_env(env)

    device = env.device

    # instantiate a memory as experience replay
    memory = RandomMemory(memory_size=10000, num_envs=env.num_envs, device=device, replacement=False)

    # instantiate the agent's models (function approximators).
    # DDPG requires 4 models, visit its documentation for more details
    # https://skrl.readthedocs.io/en/latest/api/agents/ddpg.html#models
    models = {}
    models["policy"] = Actor(env.observation_space, env.action_space, device)
    models["target_policy"] = Actor(env.observation_space, env.action_space, device)
    models["critic"] = Critic(env.observation_space, env.action_space, device)
    models["target_critic"] = Critic(env.observation_space, env.action_space, device)

    # initialize models' parameters (weights and biases)
    for model in models.values():
        model.init_parameters(method_name="normal_", mean=0.0, std=0.1)

    # configure and instantiate the agent (visit its documentation to see all the options)
    # https://skrl.readthedocs.io/en/latest/api/agents/ddpg.html#configuration-and-hyperparameters
    cfg = DDPG_DEFAULT_CONFIG.copy()
    cfg["exploration"]["noise"] = OrnsteinUhlenbeckNoise(theta=0.15, sigma=0.1, base_scale=1.0, device=device)
    cfg["discount_factor"] = discount_factor
    cfg["batch_size"] = batch_size
    cfg["random_timesteps"] = 100
    cfg["learning_starts"] = 100
    cfg["actor_learning_rate"] = learning_rate
    cfg["critic_learning_rate"] = learning_rate
    # skip logging to TensorBoard and writing checkpoints (in timesteps)
    cfg["experiment"]["write_interval"] = 0
    cfg["experiment"]["checkpoint_interval"] = 0

    agent = DDPG(models=models,
                 memory=memory,
                 cfg=cfg,
                 observation_space=env.observation_space,
                 action_space=env.action_space,
                 device=device)

    # configure and instantiate the RL trainer
    cfg_trainer = {"timesteps": 10000,
                   "headless": True,
                   "disable_progressbar": True,
                   "close_environment_at_exit": False}
    trainer = StepTrainer(cfg=cfg_trainer, env=env, agents=[agent])

    # train the agent
    for timestep in range(cfg_trainer["timesteps"]):
        # training step
        next_states, rewards, terminated, truncated, infos = trainer.train(timestep=timestep)
        # store metrics
        instantaneous_rewards.append(rewards.item())
        if terminated.any() or truncated.any():
            episode_rewards.append(np.sum(instantaneous_rewards))
            instantaneous_rewards = []

    # close the environment
    env.close()
    # ---------------------------------

    return np.mean(episode_rewards)
# https://optuna.readthedocs.io/en/stable/reference/generated/optuna.create_study.html
storage = "sqlite:///hyperparameter_optimization.db"
sampler = optuna.samplers.TPESampler()
direction = "maximize" # maximize episode reward
study = optuna.create_study(storage=storage,
                            sampler=sampler,
                            study_name="optimization",
                            direction=direction,
                            load_if_exists=True)

study.optimize(objective, n_trials=25)

print(f"The best trial obtains a normalized score of {study.best_trial.value}", study.best_trial.params)
-
Hi @Toni-SM, thank you so much for your example. Below I attached a snippet of my code. First I create an environment for training and a SequentialTrainer. The trainer trains the agent on the training environment. For evaluation I create a separate environment (necessary as I need to use a different callback here) to which I assign the trainer. I also change the number of steps for evaluation.

# ....
# this is just a snippet
# more code above
def ddpg_param_objective(trial):
    # ....
    # this is just a snippet
    # more code above

    train_Callback = [trainCallback()]
    env_path = "runs/gem_optuna/CONT-CC-PMSM-v0_Trial_" + str(trial.number)
    env = create_env(directory_path=env_path, callbacks=train_Callback)
    device = env.device

    agent = DDPG(models=models,
                 memory=memory,
                 cfg=cfg,
                 observation_space=env.observation_space,
                 action_space=env.action_space,
                 device=device)

    # configure and instantiate the RL trainer
    cfg_trainer = {"timesteps": nb_training_steps, "headless": True}
    trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])

    # start training
    trainer.train()
    env.close()

    # create eval env with eval_Callback and change trainer parameters for evaluation
    eval_Callback = [evalCallback(ref_i_q, ref_i_d)]
    eval_env = create_env(directory_path=env_path, callbacks=eval_Callback)
    trainer.env = eval_env
    trainer.timesteps = eval_total_steps

    # evaluate the trained agent, cumulative_error will be calculated by the eval_callback
    trainer.eval()
    eval_env.close()

    return cumulative_error

sampler = optuna.samplers.TPESampler()
study = optuna.create_study(direction='minimize',
                            sampler=sampler,
                            study_name='ddpg_hypermaram')
study.optimize(ddpg_param_objective, n_trials=5)

print('Number of finished trials: ', len(study.trials))
print('Best trial:', study.best_trial.number)
trial = study.best_trial
print('Value: ', trial.value)
print('Params: ')
for key, value in trial.params.items():
    print(f' {key}: {value}')

The console output is the following:
Don't mind the values for the hyperparameters or the training length, I just created this for demonstration.

Best regards
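One small addition (not from the original comment): if the study should also be inspectable in the Optuna Dashboard, as in the example above, it can be created with a storage URL and reloaded later. The database file name and study name below are illustrative, and ddpg_param_objective is the function from the snippet above.

import optuna

# persist the study so it can be opened later, e.g. with: optuna-dashboard sqlite:///ddpg_hyperparam.db
storage = "sqlite:///ddpg_hyperparam.db"  # illustrative file name
sampler = optuna.samplers.TPESampler()
study = optuna.create_study(direction="minimize",
                            sampler=sampler,
                            study_name="ddpg_hyperparam",  # illustrative study name
                            storage=storage,
                            load_if_exists=True)
study.optimize(ddpg_param_objective, n_trials=5)

# the same study can be reloaded from the database in another process or notebook
loaded_study = optuna.load_study(study_name="ddpg_hyperparam", storage=storage)
print(loaded_study.best_trial.params)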
-
Hi @kifuman

Sorry for the late reply. You can disable the skrl trainer's environment-closing feature by setting the close_environment_at_exit configuration option to False:

https://skrl.readthedocs.io/en/develop/api/trainers/sequential.html#configuration
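For illustration only (a sketch, not from the original reply), reusing the variables from the snippet above (nb_training_steps, env, agent), the trainer configuration would then look roughly like this:

# keep the environment open after trainer.train() so it can be swapped for the evaluation env
cfg_trainer = {"timesteps": nb_training_steps,
               "headless": True,
               "close_environment_at_exit": False}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=[agent])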
-
Hi @Toni-SM,
thank you for creating this awesome library!
Is there any chance to get support for hyperparameter tuning?
Best regards,
Fabi