JuliaRL_PrioritizedDQN_CartPole
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses
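
# Construct the experiment: a DQN agent with prioritized experience replay
# playing the classic CartPole balancing task.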
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PrioritizedDQN},
    ::Val{:CartPole},
    ::Nothing;
    save_dir = nothing,
    seed = 123,
)
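    # Reproducible RNG; CartPole exposes a 4-dimensional state and 2 discrete
    # actions, so ns == 4 and na == 2 here.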
    rng = StableRNG(seed)
    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))
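
    # An Agent couples a learning policy with a trajectory (the replay buffer).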
    agent = Agent(
        policy = QBasedPolicy(
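            # PrioritizedDQNLearner: a DQN learner that samples transitions from the
            # replay buffer in proportion to their TD-error based priorities.
            # The online Q-network below maps the state to one value per action.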
            learner = PrioritizedDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, na; init = glorot_uniform(rng)),
                    ) |> gpu,
                    optimizer = ADAM(),
                ),
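                # Target network with the same architecture; it is synced from the
                # online network every `target_update_freq` steps to stabilise targets.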
                target_approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, na; init = glorot_uniform(rng)),
                    ) |> gpu,
                    optimizer = ADAM(),
                ),
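                # `agg = identity` keeps the Huber loss elementwise, so the learner can
                # weight each sample and update its priority from its own TD error.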
                loss_func = (ŷ, y) -> huber_loss(ŷ, y; agg = identity),
                stack_size = nothing,
                batch_size = 32,
                update_horizon = 1,
                min_replay_history = 100,
                update_freq = 1,
                target_update_freq = 100,
                rng = rng,
            ),
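            # ϵ-greedy exploration: ϵ decays exponentially (`kind = :exp`) towards
            # ϵ_stable = 0.01 over the first `decay_steps` steps.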
            explorer = EpsilonGreedyExplorer(
                kind = :exp,
                ϵ_stable = 0.01,
                decay_steps = 500,
                rng = rng,
            ),
        ),
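        # Prioritized circular replay buffer (PSART = Priority, State, Action,
        # Reward, Terminal) holding at most `capacity` transitions.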
        trajectory = CircularArrayPSARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
        ),
    )
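
    # Train for 10_000 environment steps and record the total reward of each episode.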
    stop_condition = StopAfterStep(10_000)
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook, "# Play CartPole with PrioritizedDQN")
end
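
# Instantiate the experiment via its registered name, run it, and plot the reward
# obtained in each episode.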
using Plots
pyplot() #hide
ex = E`JuliaRL_PrioritizedDQN_CartPole`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_PrioritizedDQN_CartPole.png") #hide
This page was generated using DemoCards.jl and Literate.jl.