GenDataset_BasicDQN_CartPole
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses
function RL.Experiment(
    ::Val{:GenDataset},
    ::Val{:BasicDQN},
    ::Val{:CartPole},
    type::AbstractString;
    dataset_size = 10000,
    seed = 123,
)
    rng = StableRNG(seed)
    env = CartPoleEnv(; T = Float32, rng = rng)
    # dimensions of the state vector and of the discrete action space
    ns, na = length(state(env)), length(action_space(env))
    # ϵ-greedy explorer with exponentially decaying ϵ, used while the agent learns
    create_greedy_explorer() = EpsilonGreedyExplorer(
        kind = :exp,
        ϵ_stable = 0.01,
        decay_steps = 500,
        rng = rng,
    )
    # explorer that always acts uniformly at random (ϵ fixed at 1.0)
    create_random_explorer() = EpsilonGreedyExplorer(
        ϵ_stable = 1.0,
        rng = rng,
    )
if type == "random"
explorer = create_random_explorer()
trajectory_num = dataset_size
elseif type == "medium"
explorer = create_greedy_explorer()
trajectory_num = dataset_size
elseif type == "expert"
explorer = create_greedy_explorer()
trajectory_num = 10000 + dataset_size
    else
        throw(ArgumentError("type must be \"random\", \"medium\" or \"expert\", got \"$type\""))
    end
    policy = Agent(
        policy = QBasedPolicy(
            learner = BasicDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, na; init = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                batch_size = 32,
                min_replay_history = 100,
                loss_func = huber_loss,
                rng = rng,
            ),
            explorer = explorer,
        ),
        # the circular buffer keeps only the most recent `dataset_size + 1`
        # transitions; they form the generated dataset
        trajectory = CircularArraySARTTrajectory(
            capacity = dataset_size + 1,
            state = Vector{Float32} => (ns,),
        ),
    )
    stop_condition = StopAfterStep(trajectory_num + 1, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(policy, env, stop_condition, hook, "# Collect $type CartPole dataset generated by BasicDQN")
end
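A minimal usage sketch follows, assuming the standard ReinforcementLearning.jl workflow in which the returned Experiment is executed with run and its trajectory traces can be read by symbol (e.g. traj[:terminal]); the "expert" argument, the keyword values, and the names ex and traj are only illustrative:

using ReinforcementLearning

# instantiate the "expert" variant with the default dataset size and seed
ex = RL.Experiment(
    Val(:GenDataset),
    Val(:BasicDQN),
    Val(:CartPole),
    "expert";
    dataset_size = 10000,
    seed = 123,
)

run(ex)                        # interact with CartPole until the stop condition triggers
traj = ex.policy.trajectory    # circular SART buffer holding the collected transitions
length(traj[:terminal])        # number of stored transitions (at most dataset_size + 1)

After the run finishes, the dataset of interest is whatever remains in the agent's circular trajectory, so it can be copied out of traj or serialized as needed.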
This page was generated using DemoCards.jl and Literate.jl.