begin
    using ReinforcementLearning
    using Flux
    using Statistics
    using Plots
    using Distributions
end
md"""
Again, we use a environment model to describe the **Grid World** in **Chapter 4.2**.
"""
begin
    # The two corners (1,1) and (4,4) are the terminal states.
    isterminal(s::CartesianIndex{2}) = s == CartesianIndex(1,1) || s == CartesianIndex(4,4)

    # Deterministic dynamics: every step costs a reward of -1 until a terminal
    # state is reached; moves that would leave the 4×4 grid keep the agent in place.
    function nextstep(s::CartesianIndex{2}, a::CartesianIndex{2})
        s′ = s + a
        if isterminal(s) || s′[1] < 1 || s′[1] > 4 || s′[2] < 1 || s′[2] > 4
            s′ = s
        end
        r = isterminal(s) ? 0.0 : -1.0
        [(r, isterminal(s′), LinearIndices((4,4))[s′]) => 1.0]
    end

    const ACTIONS = [
        CartesianIndex(-1, 0),  # up
        CartesianIndex(1, 0),   # down
        CartesianIndex(0, 1),   # right
        CartesianIndex(0, -1),  # left
    ]

    # A tabular model that caches the full transition table (16 states × 4 actions).
    struct GridWorldEnvModel <: AbstractEnvironmentModel
        cache
    end

    GridWorldEnvModel() = GridWorldEnvModel(
        Dict(
            (s, a) => nextstep(CartesianIndices((4,4))[s], ACTIONS[a])
            for s in 1:16 for a in 1:4
        )
    )

    (m::GridWorldEnvModel)(s, a) = m.cache[(s, a)]

    RLBase.state_space(m::GridWorldEnvModel) = Base.OneTo(16)
    RLBase.action_space(m::GridWorldEnvModel) = Base.OneTo(4)
end
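md"""
As a quick sanity check of the dynamics (this cell is illustrative and not part of the original notebook), stepping *up* from the cell just below the terminal corner (1,1) should cost a reward of -1 and land in terminal state 1:
"""

nextstep(CartesianIndex(2, 1), CartesianIndex(-1, 0))  # expected: [(-1.0, true, 1) => 1.0]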
V = TabularVApproximator(n_state=16, opt=Descent(1.0))  # one value estimate per state; Descent(1.0) applies each correction with step size 1
p = TabularRandomPolicy(table=Dict(s => fill(0.25, 4) for s in 1:16))  # equiprobable random policy: each of the 4 actions has probability 0.25 in every state
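md"""
With the value approximator `V`, the random policy `p`, and the grid-world model, iterative policy evaluation repeatedly applies the expected Bellman backup to every state (here with ``\gamma = 1``):

```math
V(s) \leftarrow \sum_a \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a) \bigl[r + \gamma V(s')\bigr]
```
"""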
model = GridWorldEnvModel()  # 16 states × 4 actions, 64 cached deterministic transitions
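md"""
The 4×4 value matrix below is the result of running iterative policy evaluation for the random policy against this model until the values (almost) stop changing. The evaluation cell itself is not reproduced here; what follows is a minimal self-contained sketch of that sweep written directly against `model`. The plain vector `v` and the fixed sweep count are illustrative choices, not the original notebook's code, which updates the library's tabular approximator instead.
"""

let
    γ = 1.0                  # undiscounted, as in the book's grid-world example
    v = zeros(16)            # one value estimate per state, initialised to 0
    for _ in 1:500           # repeated full sweeps over the state space
        for s in 1:16
            # expected Bellman backup under the equiprobable random policy
            v[s] = sum(
                0.25 * sum(prob * (r + γ * v[s′]) for ((r, _, s′), prob) in model(s, a))
                for a in 1:4
            )
        end
    end
    reshape(v, 4, 4)         # view the 16 values on the 4×4 grid
end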
4×4 Matrix{Float64}:
   0.0     -13.9993  -19.999   -21.9989
 -13.9993  -17.9992  -19.9991  -19.9991
 -19.999   -19.9991  -17.9992  -13.9994
 -21.9989  -19.9991  -13.9994    0.0
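md"""
These are the familiar state values of the equiprobable random policy on the 4×4 grid world: 0 at the two terminal corners and roughly -14, -18, -20, and -22 elsewhere. The small deviations from the exact integers (e.g. -13.9993 instead of -14.0) just mean the evaluation stopped slightly short of full convergence.
"""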