# GenDataset_BasicDQN_CartPole

# Source code · Author · Update time

using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses

"""
    RL.Experiment(::Val{:GenDataset}, ::Val{:BasicDQN}, ::Val{:CartPole}, type;
                  dataset_size = 10000, seed = 123)

Build an `Experiment` in which a `BasicDQNLearner` agent interacts with `CartPoleEnv`
while its transitions are recorded in a `CircularArraySARTTrajectory`, so that the
trajectory can afterwards be used as a dataset.

# Arguments
- `type::AbstractString`: which kind of dataset to collect. Must be one of
  - `"random"`: an `EpsilonGreedyExplorer` with `ϵ_stable = 1.0`, run for
    `dataset_size + 1` steps;
  - `"medium"`: an exponentially decaying `EpsilonGreedyExplorer`
    (`ϵ_stable = 0.01`, `decay_steps = 500`), run for `dataset_size + 1` steps;
  - `"expert"`: the same decaying explorer, run for `10000 + dataset_size + 1` steps.

# Keywords
- `dataset_size = 10000`: the trajectory is created with capacity `dataset_size + 1`.
- `seed = 123`: seed of the `StableRNG` shared by the environment, the network
  initialisers, the learner and the explorer.

# Throws
- `ArgumentError` if `type` is not one of the values listed above.
"""
function RL.Experiment(
    ::Val{:GenDataset},
    ::Val{:BasicDQN},
    ::Val{:CartPole},
    type::AbstractString;
    dataset_size = 10000,
    seed = 123,
)
    rng = StableRNG(seed)
    env = CartPoleEnv(; T = Float32, rng = rng)
    ns, na = length(state(env)), length(action_space(env))

    create_greedy_explorer() = EpsilonGreedyExplorer(
            kind = :exp,
            ϵ_stable = 0.01,
            decay_steps = 500,
            rng = rng,
        )

    create_random_explorer() = EpsilonGreedyExplorer(
            ϵ_stable = 1.0,
            rng = rng,
        )

    if type == "random"
        explorer = create_random_explorer()
        trajectory_num = dataset_size
    elseif type == "medium"
        explorer = create_greedy_explorer()
        trajectory_num = dataset_size
    elseif type == "expert"
        explorer = create_greedy_explorer()
        # Run 10000 additional steps; the trajectory capacity below is still
        # `dataset_size + 1`, so (assuming the circular buffer drops its oldest
        # entries) only the latest transitions remain in the dataset.
        trajectory_num = 10000 + dataset_size
    else
        # Logging alone would fall through and fail later with an unrelated
        # UndefVarError on `explorer`, so reject the value explicitly.
        throw(ArgumentError(
            "unsupported dataset type \"$type\"; " *
            "expected \"random\", \"medium\" or \"expert\"",
        ))
    end

    policy = Agent(
        policy = QBasedPolicy(
            learner = BasicDQNLearner(
                approximator = NeuralNetworkApproximator(
                    model = Chain(
                        Dense(ns, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, 128, relu; init = glorot_uniform(rng)),
                        Dense(128, na; init = glorot_uniform(rng)),
                    ) |> cpu,
                    optimizer = ADAM(),
                ),
                batch_size = 32,
                min_replay_history = 100,
                loss_func = huber_loss,
                rng = rng,
            ),
            explorer = explorer,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = dataset_size+1,
            state = Vector{Float32} => (ns,),
        ),
    )

    # The progress bar is disabled whenever the `CI` environment variable is set.
    stop_condition = StopAfterStep(trajectory_num+1, is_show_progress=!haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    return Experiment(
        policy,
        env,
        stop_condition,
        hook,
        "# Collect $type CartPole dataset generated by BasicDQN",
    )
end

# This page was generated using DemoCards.jl and Literate.jl.