Skip to content

Commit 22b04d0

Browse files
authored
Updated for compat with RLBase 0.11 (#11)
* updated for compatibility with RLBase 0.11 * bug fix for envpool
1 parent ceff32b commit 22b04d0

File tree

7 files changed

+108
-95
lines changed

7 files changed

+108
-95
lines changed

Project.toml

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "MPOPIS"
22
uuid = "e8a75bc8-90e1-4072-945a-20230e5738f6"
33
authors = ["Dylan Asmar <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.2.0"
55

66
[deps]
77
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
88
CovarianceEstimation = "587fd27a-f159-11e8-2dae-1979310e6154"
99
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1010
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
11+
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
1112
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
@@ -18,3 +19,16 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1819
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1920
ReinforcementLearning = "158674fc-8238-5cab-b5ba-03dfc80d1318"
2021
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
22+
23+
[compat]
24+
CSV = "0.10"
25+
CovarianceEstimation = "0.2"
26+
Distributions = "0.25"
27+
DomainSets = "0.7"
28+
IntervalSets = "0.7"
29+
Plots = "1"
30+
ProgressMeter = "1"
31+
PyCall = "1.96"
32+
Reexport = "1"
33+
ReinforcementLearning = "0.11"
34+
StatsBase = "0.34"

src/MPOPIS.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import CovarianceEstimation.SimpleCovariance
1616
@reexport using ReinforcementLearning
1717
import ReinforcementLearning.AbstractEnv
1818
import ReinforcementLearning.RLBase
19+
using DomainSets
1920
using Plots
2021
@reexport import Plots.plot
2122
using ProgressMeter

src/envs/car_racing.jl

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@ Base.show(io::IO, params::CarRacingEnvParams) = print(
2525
join(["$p=$(getfield(params, p))" for p in fieldnames(CarRacingEnvParams)], ","),
2626
)
2727

28-
mutable struct CarRacingEnv{A,T,R<:AbstractRNG} <: AbstractEnv
28+
mutable struct CarRacingEnv{T,R<:AbstractRNG} <: AbstractEnv
2929
params::CarRacingEnvParams{T}
30-
action_space::A
31-
observation_space::Space{Vector{ClosedInterval{T}}}
3230
state::Vector{T}
3331
done::Bool
3432
t::Int
@@ -136,25 +134,8 @@ function CarRacingEnv(
136134
rng=Random.GLOBAL_RNG
137135
)
138136

139-
action_space = ClosedInterval{Vector{T}}(
140-
[-1.0, -1.0],
141-
[1.0, 1.0],
142-
)
143-
observation_space = Space([
144-
-Inf .. Inf, # X position in XY plane (x = north, y = west)
145-
-Inf .. Inf, # Y position in XY plane (x = north, y = west)
146-
-π .. π, # yaw (rotation from x axis toward y axis [north to west])
147-
-Inf .. Inf, # Longitudinal velocity
148-
-Inf .. Inf, # Lateral velocity
149-
-Inf .. Inf, # yaw rate
150-
-params.δ_max .. params.δ_max, # steering angle
151-
-1.0 .. 1.0, # acceleration/brake amount [-1, 1]
152-
])
153-
154137
env = CarRacingEnv(
155138
params,
156-
action_space,
157-
observation_space,
158139
zeros(T, 8),
159140
false,
160141
0,
@@ -172,8 +153,25 @@ CarRacingEnv{T}(; kwargs...) where {T} = CarRacingEnv(; T=T, kwargs...)
172153

173154
Random.seed!(env::CarRacingEnv, seed) = Random.seed!(env.rng, seed)
174155

175-
RLBase.action_space(env::CarRacingEnv) = env.action_space
176-
RLBase.state_space(env::CarRacingEnv{T}) where {T} = env.observation_space
156+
function RLBase.action_space(::CarRacingEnv{T}) where {T}
157+
action_space = ClosedInterval{Vector{T}}([-1.0, -1.0], [1.0, 1.0])
158+
return action_space
159+
end
160+
161+
function RLBase.state_space(env::CarRacingEnv)
162+
state_space = ArrayProductDomain([
163+
-Inf .. Inf, # X position in XY plane (x = north, y = west)
164+
-Inf .. Inf, # Y position in XY plane (x = north, y = west)
165+
-π .. π, # yaw (rotation from x axis toward y axis [north to west])
166+
-Inf .. Inf, # Longitudinal velocity
167+
-Inf .. Inf, # Lateral velocity
168+
-Inf .. Inf, # yaw rate
169+
-env.params.δ_max .. env.params.δ_max, # steering angle
170+
-1.0 .. 1.0, # acceleration/brake amount [-1, 1]
171+
])
172+
return state_space
173+
end
174+
177175
RLBase.is_terminated(env::CarRacingEnv) = env.done
178176
RLBase.state(env::CarRacingEnv) = env.state
179177

@@ -214,7 +212,7 @@ function RLBase.reward(env::CarRacingEnv{T}) where {T}
214212
return rew
215213
end
216214

217-
function RLBase.reset!(env::CarRacingEnv{A,T}) where {A,T}
215+
function RLBase.reset!(env::CarRacingEnv{T}) where {T}
218216
ss_size = length(env.state)
219217
env.state = zeros(T, ss_size)
220218
env.state[3] = deg2rad(90)
@@ -224,7 +222,7 @@ function RLBase.reset!(env::CarRacingEnv{A,T}) where {A,T}
224222
nothing
225223
end
226224

227-
function RLBase.reset!(env::CarRacingEnv{A,T}, state::Vector{T}) where {A,T}
225+
function RLBase.reset!(env::CarRacingEnv{T}, state::Vector{T}) where {T}
228226
env.state = state
229227
env.t = 0
230228
env.done = false
@@ -237,16 +235,16 @@ end
237235
a[1] = Turn angle [-max turn angle, max turn angle] (-1 right turn, +1 left turn)
238236
a[2] = Pedal amount (-1 = full brake, 1 = full throttle)
239237
"""
240-
function (env::CarRacingEnv{<:ClosedInterval})(a::Vector{Float64})
241-
a in env.action_space || error("Action is not in action space")
238+
function (env::CarRacingEnv)(a::Vector{Float64})
239+
a in action_space(env) || error("Action is not in action space")
242240
_step!(env, a)
243241
end
244242

245-
function (env::CarRacingEnv{<:ClosedInterval})(a::Vector{Int})
243+
function (env::CarRacingEnv)(a::Vector{Int})
246244
env(Float64.(a))
247245
end
248246

249-
function (env::CarRacingEnv{<:ClosedInterval})(a::Matrix{Float64})
247+
function (env::CarRacingEnv)(a::Matrix{Float64})
250248
size(a)[2] == 1 || error("Only implented for one step")
251249
env(vec(a))
252250
end

src/envs/envpool_env.jl

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using PyCall
22

3-
mutable struct EnvpoolEnv{A,T,R<:AbstractRNG} <: AbstractEnv
3+
mutable struct EnvpoolEnv{T,R<:AbstractRNG} <: AbstractEnv
44
task::String
55
py_env::PyObject
6-
action_space::A
7-
observation_space::Space{Vector{ClosedInterval{T}}}
86
num_states::Int
97
num_envs::Int
108
info::Dict
@@ -62,26 +60,11 @@ function EnvpoolEnv(
6260

6361
py_env = py"get_envs_ep"(task, "gym", num_envs, frame_skip)
6462
env_data = py_env.reset()
65-
py_action_space = py_env.action_space
66-
67-
action_space = ClosedInterval{Vector{T}}(
68-
py_action_space.low,
69-
py_action_space.high,
70-
)
71-
72-
py_observation_space = py_env.observation_space
73-
py_obs_len = py_observation_space.shape[1]
74-
py_obs_low = py_observation_space.low
75-
py_obs_high = py_observation_space.high
76-
77-
observation_vec = [py_obs_low[ii] .. py_obs_high[ii] for ii in 1:py_obs_len]
78-
observation_space = Space(observation_vec)
79-
63+
py_obs_len = py_env.observation_space.shape[1]
64+
8065
env = EnvpoolEnv(
8166
task,
8267
py_env,
83-
action_space,
84-
observation_space,
8568
py_obs_len,
8669
num_envs,
8770
env_data[end],
@@ -98,16 +81,35 @@ function EnvpoolEnv(
9881
end
9982

10083
Random.seed!(env::EnvpoolEnv, seed) = Random.seed!(env.rng, seed)
101-
RLBase.action_space(env::EnvpoolEnv) = env.action_space
102-
RLBase.state_space(env::EnvpoolEnv) = env.observation_space
84+
85+
function RLBase.action_space(env::EnvpoolEnv{T}) where {T}
86+
py_action_space = env.py_env.action_space
87+
action_space = ClosedInterval{Vector{T}}(
88+
py_action_space.low,
89+
py_action_space.high,
90+
)
91+
return action_space
92+
end
93+
94+
function RLBase.state_space(env::EnvpoolEnv{T}) where {T}
95+
py_obs_len = env.py_env.observation_space.shape[1]
96+
py_obs_low = env.py_env.observation_space.low
97+
py_obs_high = env.py_env.observation_space.high
98+
99+
observation_vec = [py_obs_low[ii] .. py_obs_high[ii] for ii in 1:py_obs_len]
100+
observation_space = ArrayProductDomain(observation_vec)
101+
102+
return observation_space
103+
end
104+
103105
RLBase.is_terminated(env::EnvpoolEnv) = env.done
104106
RLBase.state(env::EnvpoolEnv) = env.state
105107
RLBase.reward(env::EnvpoolEnv) = env.rews
106108

107109
"""
108110
The keyword argument `restore` is used to restore the environments based on `acts`
109111
"""
110-
function RLBase.reset!(env::EnvpoolEnv{A,T}; restore=false) where {A,T}
112+
function RLBase.reset!(env::EnvpoolEnv{T}; restore=false) where {T}
111113
env_data = env.py_env.reset()
112114
env.info = env_data[end]
113115
env.rews = zeros(T, env.num_envs)

src/envs/multi-car_racing.jl

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11

2-
mutable struct MultiCarRacingEnv{A,T,R<:AbstractRNG} <: AbstractEnv
2+
mutable struct MultiCarRacingEnv{T,R<:AbstractRNG} <: AbstractEnv
33
N::Int
44
envs::Vector{CarRacingEnv}
5-
action_space::A
6-
observation_space::Space{Vector{A}}
75
state::Vector{T}
86
done::Bool
97
t::Int
@@ -39,41 +37,26 @@ function MultiCarRacingEnv(N=2;
3937
envs = Vector{CarRacingEnv}(undef, N)
4038
for ii in 1:N
4139
if length(car_params) >= ii
42-
cre = CarRacingEnv(car_params, T=T, dt=dt, δt=δt, track=track, rng=rng)
40+
cre = CarRacingEnv(car_params; T=T, dt=dt, δt=δt, track=track, rng=rng)
4341
else
44-
cre = CarRacingEnv(T=T, dt=dt, δt=δt, track=track, rng=rng)
42+
cre = CarRacingEnv(; T=T, dt=dt, δt=δt, track=track, rng=rng)
4543
end
4644
envs[ii] = cre
4745
end
48-
49-
50-
endpts_l = []
51-
endpts_r = []
52-
single_state_size = length(RLBase.state_space(envs[1]))
53-
obs_space_vec = Vector{ClosedInterval}(undef, N * single_state_size)
46+
47+
single_state_size = length(RLBase.state_space(envs[1]).domains)
5448
state = zeros(T, N * single_state_size)
55-
for (idx, en) in enumerate(envs)
56-
endpts_l = [endpts_l; leftendpoint(RLBase.action_space(en))]
57-
endpts_r = [endpts_r; rightendpoint(RLBase.action_space(en))]
58-
start_idx = single_state_size * (idx - 1) + 1
59-
end_idx = single_state_size * idx
60-
obs_space_vec[start_idx:end_idx] = RLBase.state_space(en)[:]
61-
end
62-
action_space = ClosedInterval{Vector{T}}(endpts_l, endpts_r)
63-
observation_space = Space(obs_space_vec)
64-
49+
6550
env = MultiCarRacingEnv(
6651
N,
6752
envs,
68-
action_space,
69-
observation_space,
7053
state,
7154
false,
7255
0,
7356
dt,
7457
δt,
7558
Track(track),
76-
rng,
59+
rng
7760
)
7861

7962
reset!(env)
@@ -89,14 +72,38 @@ function Random.seed!(env::MultiCarRacingEnv, seed)
8972
end
9073
end
9174

92-
RLBase.action_space(env::MultiCarRacingEnv) = env.action_space
93-
RLBase.state_space(env::MultiCarRacingEnv{T}) where {T} = env.observation_space
75+
function RLBase.action_space(env::MultiCarRacingEnv{T}) where {T}
76+
endpts_l = []
77+
endpts_r = []
78+
for en in env.envs
79+
endpts_l = [endpts_l; leftendpoint(RLBase.action_space(en))]
80+
endpts_r = [endpts_r; rightendpoint(RLBase.action_space(en))]
81+
end
82+
action_space = ClosedInterval{Vector{T}}(endpts_l, endpts_r)
83+
return action_space
84+
end
85+
86+
function RLBase.state_space(env::MultiCarRacingEnv{T}) where {T}
87+
envs = env.envs
88+
single_state_size = length(RLBase.state_space(envs[1]).domains)
89+
obs_space_vec = Vector{ClosedInterval}(undef, env.N * single_state_size)
90+
for (idx, en) in enumerate(envs)
91+
start_idx = single_state_size * (idx - 1) + 1
92+
end_idx = single_state_size * idx
93+
obs_space_vec[start_idx:end_idx] = RLBase.state_space(en)[:]
94+
end
95+
observation_space = ArrayProductDomain(obs_space_vec)
96+
return observation_space
97+
end
98+
99+
100+
94101
RLBase.is_terminated(env::MultiCarRacingEnv) = env.done
95102
RLBase.state(env::MultiCarRacingEnv) = env.state
96103

97104
function _update_states_env2envs(env::MultiCarRacingEnv)
98105
for (idx, en) in enumerate(env.envs)
99-
ss_size = length(RLBase.state_space(en))
106+
ss_size = length(RLBase.state_space(en).domains)
100107
start_idx = ss_size * (idx - 1) + 1
101108
end_idx = ss_size * idx
102109
en.state = env.state[start_idx:end_idx]
@@ -105,7 +112,7 @@ end
105112

106113
function _update_states_envs2env(env::MultiCarRacingEnv)
107114
for (idx, en) in enumerate(env.envs)
108-
ss_size = length(RLBase.state_space(en))
115+
ss_size = length(RLBase.state_space(en).domains)
109116
start_idx = ss_size * (idx - 1) + 1
110117
end_idx = ss_size * idx
111118
env.state[start_idx:end_idx] = en.state
@@ -150,7 +157,7 @@ function RLBase.reward(env::MultiCarRacingEnv{T}) where {T}
150157
return rew
151158
end
152159

153-
function RLBase.reset!(env::MultiCarRacingEnv{A,T}) where {A,T}
160+
function RLBase.reset!(env::MultiCarRacingEnv{T}) where {T}
154161
ss_size = length(env.state)
155162
ind_ss_size = round(Int, length(env.state) / env.N)
156163
env.envs[1].state = zeros(T, ind_ss_size)
@@ -172,7 +179,7 @@ function RLBase.reset!(env::MultiCarRacingEnv{A,T}) where {A,T}
172179
nothing
173180
end
174181

175-
function RLBase.reset!(env::MultiCarRacingEnv{A,T}, state::Vector{T}) where {A,T}
182+
function RLBase.reset!(env::MultiCarRacingEnv{T}, state::Vector{T}) where {T}
176183
env.state = state
177184
_update_states_env2envs(env)
178185
env.t = 0

src/examples/cartpole_example.jl

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11

22
# Modifications to the RLBase functions to work with different GMPPI algorithms
3-
function (env::CartPoleEnv{<:Base.OneTo{Int}})(a::Vector)
3+
function (env::CartPoleEnv)(a)
44
length(a) == 1 || error("Only implented for 1 step")
5-
env(a[1])
6-
end
7-
function (env::CartPoleEnv{<:ClosedInterval})(a::Vector)
8-
length(a) == 1 || error("Only implented for 1 step")
9-
env(a[1])
5+
RLBase.act!(env, a[1])
106
end
117

128
"""
@@ -189,4 +185,3 @@ function simulate_cartpole(;
189185
gif(anim, gif_name, fps=10)
190186
end
191187
end
192-

src/examples/mountaincar_example.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11

22

33
# Modifications to the RLBase functions to work with different GMPPI algorithms
4-
function (env::MountainCarEnv{<:ClosedInterval})(a::Vector)
4+
function (env::MountainCarEnv)(a)
55
length(a) == 1 || error("Only implented for 1 step")
6-
env(a[1])
7-
end
8-
function (env::MountainCarEnv{<:Base.OneTo{Int}})(a::Vector)
9-
length(a) == 1 || error("Only implented for 1 step")
10-
env(a[1])
6+
RLBase.act!(env, a[1])
117
end
128

139
# Modified MountainCar reward function

0 commit comments

Comments
 (0)