JuliaRL_PLAS_Pendulum
using ReinforcementLearning
using StableRNGs
using Flux
using Flux.Losses
function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:PLAS},
    ::Val{:Pendulum},
    type::AbstractString;
    save_dir = nothing,
    seed = 123,
)
    rng = StableRNG(seed)
    inner_env = PendulumEnv(T = Float32, rng = rng)
    A = action_space(inner_env)
    low = A.left
    high = A.right
    ns = length(state(inner_env))
    na = 1
    latent_dims = 2
    trajectory_num = 10000
    dataset_size = 10000
    batch_size = 64

    # Rescale the policy's output from [-1, 1] to the environment's action range.
    env = ActionTransformedEnv(
        inner_env;
        action_mapping = x -> low + (x[1] + 1) * 0.5 * (high - low),
    )
    init = glorot_uniform(rng)

    # Deterministic policy network mapping states to latent actions.
    create_policy_net() = NeuralNetworkApproximator(
        model = Chain(
            Dense(ns, 64, relu; init = glorot_uniform(rng)),
            Dense(64, 64, relu; init = glorot_uniform(rng)),
            Dense(64, latent_dims; init = glorot_uniform(rng)),
        ),
        optimizer = ADAM(0.003),
    )

    # Q-network estimating the value of a state-action pair.
    create_q_net() = NeuralNetworkApproximator(
        model = Chain(
            Dense(ns + na, 64, relu; init = init),
            Dense(64, 64, relu; init = init),
            Dense(64, 1; init = init),
        ),
        optimizer = ADAM(0.003),
    )

    # Conditional VAE that decodes a latent action back into an environment action.
    create_vae_net() = NeuralNetworkApproximator(
        model = VAE(
            encoder = GaussianNetwork(
                pre = Chain(
                    Dense(ns + na, 64, relu),
                    Dense(64, 64, relu),
                ),
                μ = Chain(Dense(64, latent_dims, init = init)),
                logσ = Chain(Dense(64, latent_dims, init = init)),
            ),
            decoder = Chain(
                Dense(ns + latent_dims, 64, relu; init = init),
                Dense(64, 64, relu; init = init),
                Dense(64, na; init = init),
            ),
            latent_dims = latent_dims,
        ),
        optimizer = ADAM(0.003),
    )
    agent = Agent(
        policy = OfflinePolicy(
            learner = PLASLearner(
                policy = create_policy_net() |> cpu,
                target_policy = create_policy_net() |> cpu,
                qnetwork1 = create_q_net() |> cpu,
                qnetwork2 = create_q_net() |> cpu,
                target_qnetwork1 = create_q_net() |> cpu,
                target_qnetwork2 = create_q_net() |> cpu,
                vae = create_vae_net() |> cpu,
                batch_size = batch_size,
                pretrain_step = 1500,
                update_freq = 1,
            ),
            # Offline dataset collected with a SAC agent on Pendulum.
            dataset = gen_JuliaRL_dataset(:SAC, :Pendulum, type; dataset_size = dataset_size),
            continuous = true,
            batch_size = batch_size,
        ),
        trajectory = CircularArraySARTTrajectory(
            capacity = 1000,
            state = Vector{Float32} => (ns,),
            action = Vector{Float32} => (na,),
        ),
    )

    stop_condition = StopAfterStep(trajectory_num, is_show_progress = !haskey(ENV, "CI"))
    hook = TotalRewardPerEpisode()
    Experiment(agent, env, stop_condition, hook, "PLAS <-> Pendulum ($type dataset)")
end
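The method defined above is what the `E` string macro below dispatches to. As a minimal sketch, the experiment could also be constructed directly; the `Val` arguments mirror the method signature, and `"medium"` is the dataset type string.

# Hypothetical alternative to the `E` string macro: construct the experiment
# by calling the method defined above directly.
ex_direct = RL.Experiment(Val(:JuliaRL), Val(:PLAS), Val(:Pendulum), "medium")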
using Plots
pyplot() #hide
ex = E`JuliaRL_PLAS_Pendulum(medium)`
run(ex)
plot(ex.hook.rewards)
savefig("assets/JuliaRL_PLAS_Pendulum_medium.png") #hide
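The `type` argument selects the dataset quality forwarded to `gen_JuliaRL_dataset`. As a hypothetical sketch, assuming a `random` Pendulum dataset variant is also available in your installation, the experiment can be rerun on it for comparison:

# Hypothetical: repeat the experiment on a lower-quality dataset,
# assuming `gen_JuliaRL_dataset` also provides a `random` Pendulum variant.
ex_random = E`JuliaRL_PLAS_Pendulum(random)`
run(ex_random)
plot(ex_random.hook.rewards)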
This page was generated using DemoCards.jl and Literate.jl.