-
Notifications
You must be signed in to change notification settings - Fork 0
/
policies.py
22 lines (20 loc) · 740 Bytes
/
policies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
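# This module only defines the policy aliases that work for PPO
# (MlpPolicy, MlpNormPolicy, MlpOptimPolicy, CnnPolicy, MultiInputPolicy)
# and registers each of them under its string name.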
from stable_baselines3.common.policies import (
ActorCriticCnnPolicy,
ActorCriticPolicy,
ActorCriticPolicyNorm,
ActorCriticPolicyOptim,
MultiInputActorCriticPolicy,
register_policy,
)
# Public short names for the actor-critic policy classes. Each alias is the
# very same class object as the one it points to (no subclassing/wrapping).
MlpPolicy = ActorCriticPolicy
MlpNormPolicy = ActorCriticPolicyNorm
MlpOptimPolicy = ActorCriticPolicyOptim
CnnPolicy = ActorCriticCnnPolicy
MultiInputPolicy = MultiInputActorCriticPolicy

# Register every alias under its own string name, in this fixed order.
# NOTE(review): upstream Stable-Baselines3 later replaced `register_policy`
# with per-algorithm `policy_aliases`; confirm this fork's helper is still
# the intended registration mechanism.
for _policy_name, _policy_class in (
    ("MlpPolicy", MlpPolicy),
    ("MlpNormPolicy", MlpNormPolicy),
    ("MlpOptimPolicy", MlpOptimPolicy),
    ("CnnPolicy", CnnPolicy),
    ("MultiInputPolicy", MultiInputPolicy),
):
    register_policy(_policy_name, _policy_class)

# Drop the loop variables so the module namespace exposes only the aliases.
del _policy_name, _policy_class