How to use hooks?

What are hooks?

During the interactions between agents and environments, we often want to collect some useful information. One straightforward approach is imperative programming: we write the code in a loop and execute it step by step.

while true
    action = plan!(policy, env)
    act!(env, action)

    # write your own logic here
    # like saving parameters, recording loss function, evaluating policy, etc.
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
end

The benefit of this approach is clarity: you are responsible for exactly what you write. This is also the recommended approach for new users who want to try out the different components of this package.
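To make the imperative pattern concrete without depending on the package, here is a minimal, self-contained sketch in plain Julia. The environment ToyWalk and the functions choose_action, step!, reward, is_done, reset_env! and imperative_run are hypothetical stand-ins for plan!, act!, is_terminated and reset!; the custom logic (recording each episode's return) lives directly inside the loop, as described above.

```julia
using Random

# Hypothetical toy environment (not part of ReinforcementLearning.jl):
# positions 1..5, start at 3; reaching either end terminates the episode,
# and reaching position 5 yields reward 1.
mutable struct ToyWalk
    pos::Int
end
ToyWalk() = ToyWalk(3)

is_done(env::ToyWalk) = env.pos == 1 || env.pos == 5
reward(env::ToyWalk) = env.pos == 5 ? 1.0 : 0.0
reset_env!(env::ToyWalk) = (env.pos = 3; env)
step!(env::ToyWalk, action::Int) = (env.pos += action; env)  # action is -1 or +1

# Hypothetical random policy: move left or right with equal probability.
choose_action(rng::AbstractRNG, env::ToyWalk) = rand(rng, [-1, 1])

function imperative_run(n_episodes::Int; seed::Int = 0)
    rng = MersenneTwister(seed)
    env = ToyWalk()
    returns = Float64[]        # total reward of each finished episode
    episode_return = 0.0
    while true
        action = choose_action(rng, env)
        step!(env, action)
        episode_return += reward(env)
        # Custom logic goes directly into the loop, e.g. recording returns.
        if is_done(env)
            push!(returns, episode_return)
            episode_return = 0.0
            length(returns) >= n_episodes && break  # stop condition
            reset_env!(env)
        end
    end
    return returns
end

returns = imperative_run(20)
```

Everything is visible in one place, which is the clarity argument made above; the cost is that this bookkeeping has to be rewritten for every new experiment.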

Another approach is declarative programming. We describe when and what we want to do during an experiment, put these descriptions together with the agent and environment, and finally call run to conduct the experiment. This way, we can reuse common hooks and execution pipelines instead of writing a lot of duplicate code. Many existing reinforcement learning Python packages use a set of configuration files to define the execution pipeline. We believe this is unnecessary in Julia: the declarative programming approach gives us much more flexibility.

Now the question is how to design the hook. A natural choice is to wrap the commented part of the pseudo-code above into a function call:

while true
    action = plan!(policy, env)
    act!(env, action)
    push!(hook, policy, env)
    check!(stop_condition, env, policy) && break
    is_terminated(env) && reset!(env)
end

But sometimes we'd like more fine-grained control, so the calls to hooks are split into several different stages: PreExperimentStage, PreEpisodeStage, PreActStage, PostActStage, PostEpisodeStage and PostExperimentStage.
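A rough sketch of how a staged run loop can dispatch to hooks, written in plain Julia so it runs on its own. The stage types, Hook, on_stage!, FixedLengthEnv, run_with_stages! and StepsPerEpisode are hypothetical names that only mirror the idea behind the package's stage types such as PreEpisodeStage and PostEpisodeStage; this is not the package's implementation. Note the fallback method that makes every stage a no-op unless a hook defines something more specific.

```julia
# Hypothetical stage types mirroring the package's stage idea.
abstract type Stage end
struct PreExperiment  <: Stage end
struct PreEpisode     <: Stage end
struct PreAct         <: Stage end
struct PostAct        <: Stage end
struct PostEpisode    <: Stage end
struct PostExperiment <: Stage end

abstract type Hook end
# Fallback: a hook does nothing at a stage unless a more specific method exists.
on_stage!(::Hook, ::Stage, env) = nothing

# Hypothetical toy environment: every episode lasts exactly `len` steps.
mutable struct FixedLengthEnv
    len::Int
    t::Int
end
FixedLengthEnv(len::Int) = FixedLengthEnv(len, 0)

function run_with_stages!(env::FixedLengthEnv, n_episodes::Int, hook::Hook)
    on_stage!(hook, PreExperiment(), env)
    for _ in 1:n_episodes
        env.t = 0                         # reset the environment
        on_stage!(hook, PreEpisode(), env)
        while env.t < env.len             # until the episode terminates
            on_stage!(hook, PreAct(), env)
            env.t += 1                    # "act" in the environment
            on_stage!(hook, PostAct(), env)
        end
        on_stage!(hook, PostEpisode(), env)
    end
    on_stage!(hook, PostExperiment(), env)
    return hook
end

# A hook that only cares about two stages; all others hit the fallback.
mutable struct StepsPerEpisode <: Hook
    current::Int
    steps::Vector{Int}
end
StepsPerEpisode() = StepsPerEpisode(0, Int[])

on_stage!(h::StepsPerEpisode, ::PostAct, env) = (h.current += 1)
function on_stage!(h::StepsPerEpisode, ::PostEpisode, env)
    push!(h.steps, h.current)
    h.current = 0
end

hook = run_with_stages!(FixedLengthEnv(4), 3, StepsPerEpisode())
```

Because the fallback exists, StepsPerEpisode implements only the two stages it needs, and multiple dispatch picks the most specific method at each call.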

How to define a customized hook?

By default, an instance of AbstractHook does nothing when called with push!(::AbstractHook, ::AbstractStage, ::AbstractPolicy, ::AbstractEnv). So when writing a customized hook, you only need to implement methods for the stages you care about.

For example, assume we want to record the wall time of each episode.

julia> using ReinforcementLearning

julia> Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook
           t::UInt64 = time_ns()
           time_costs::Vector{UInt64} = UInt64[]
       end

julia> # Annotate policy and env too: the package's fallback is
       # push!(::AbstractHook, ::AbstractStage, ::AbstractPolicy, ::AbstractEnv),
       # so leaving them untyped makes the call ambiguous (a MethodError).
       Base.push!(h::TimeCostPerEpisode, ::PreEpisodeStage, policy::AbstractPolicy, env::AbstractEnv) = h.t = time_ns()

julia> Base.push!(h::TimeCostPerEpisode, ::PostEpisodeStage, policy::AbstractPolicy, env::AbstractEnv) = push!(h.time_costs, time_ns() - h.t)

julia> h = TimeCostPerEpisode();

julia> run(RandomPolicy(), CartPoleEnv(), StopAfterNEpisodes(10), h);

julia> h.time_costs  # one wall-time measurement (in nanoseconds) per episode
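The method ambiguity that arises when policy and env are left untyped can be reproduced with plain Julia types. In this sketch the abstract types and on_event! are hypothetical stand-ins for AbstractHook, AbstractStage, AbstractPolicy, AbstractEnv and push!. A user method that is more specific than the fallback in the first two arguments but less specific in the last two is ambiguous; annotating all four arguments resolves it.

```julia
# Hypothetical stand-ins for AbstractHook / AbstractStage / AbstractPolicy / AbstractEnv.
abstract type MyAbstractHook end
abstract type MyAbstractStage end
abstract type MyAbstractPolicy end
abstract type MyAbstractEnv end

struct MyPreEpisodeStage <: MyAbstractStage end
struct MyRandomPolicy <: MyAbstractPolicy end
struct MyCartPole <: MyAbstractEnv end

mutable struct EpisodeCounter <: MyAbstractHook
    n::Int
end

# Library-style fallback: fully typed, does nothing.
on_event!(::MyAbstractHook, ::MyAbstractStage, ::MyAbstractPolicy, ::MyAbstractEnv) = nothing

# User method with untyped policy/env: narrower in arguments 1-2 but wider
# in arguments 3-4 than the fallback, so neither method is more specific.
on_event!(h::EpisodeCounter, ::MyPreEpisodeStage, policy, env) = (h.n += 1)

h = EpisodeCounter(0)
err = try
    on_event!(h, MyPreEpisodeStage(), MyRandomPolicy(), MyCartPole())
    nothing
catch e
    e
end
println(sprint(showerror, err))  # a MethodError reporting the ambiguity

# Fix: a method at least as specific as the fallback in every position.
on_event!(h::EpisodeCounter, ::MyPreEpisodeStage, ::MyAbstractPolicy, ::MyAbstractEnv) = (h.n += 1)
on_event!(h, MyPreEpisodeStage(), MyRandomPolicy(), MyCartPole())
```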

Common hooks

Periodic jobs

Sometimes, we'd like to run some functions periodically. Two handy hooks are provided for this kind of task: DoEveryNEpisodes and DoEveryNSteps.
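Conceptually, a periodic hook only needs a callback, a period and a counter. The following is a minimal sketch under that assumption; EveryNEpisodes and episode_finished! are hypothetical names, and this is not the package's DoEveryNEpisodes source.

```julia
# Hypothetical "call f every n episodes" hook.
mutable struct EveryNEpisodes{F}
    f::F    # callback, called as f(t) with the number of finished episodes
    n::Int  # period
    t::Int  # number of finished episodes so far
end
EveryNEpisodes(f; n::Int = 100) = EveryNEpisodes(f, n, 0)

# Meant to be called once at the end of every episode.
function episode_finished!(h::EveryNEpisodes)
    h.t += 1
    if h.t % h.n == 0
        h.f(h.t)
    end
    return h
end

fired_at = Int[]
hook = EveryNEpisodes(n = 10) do t
    push!(fired_at, t)   # record when the callback ran
end
for _ in 1:25
    episode_finished!(hook)
end
```

With a period of 10 and 25 finished episodes, the callback runs after episodes 10 and 20 only. The do-block syntax passes the callback as the first positional argument, which is the same calling style used with DoEveryNEpisodes and DoEveryNSteps in the examples that follow.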

The following are some typical usages.

Evaluating policy during training

julia> using Statistics: mean

julia> policy = RandomPolicy()
::RandomPolicy
├─ action_space::Nothing => nothing
└─ rng::TaskLocalRNG => Random.TaskLocalRNG()

julia> run(
           policy,
           CartPoleEnv(),
           StopAfterNEpisodes(100),
           DoEveryNEpisodes(;n=10) do t, policy, env
               # In real world cases, the policy is usually wrapped in an Agent,
               # we need to extract the inner policy to run it in the *actor* mode.
               # Here for illustration only, we simply use the original policy.
               # Note that we create a new instance of CartPoleEnv here to avoid
               # polluting the original env.
               hook = TotalRewardPerEpisode(;is_display_on_exit=false)
               run(policy, CartPoleEnv(), StopAfterNEpisodes(10), hook)
               # now you can report the result of the hook.
               println("avg reward at episode $t is: $(mean(hook.rewards))")
           end
       )
avg reward at episode 10 is: 16.5
avg reward at episode 20 is: 16.1
avg reward at episode 30 is: 20.8
avg reward at episode 40 is: 34.5
avg reward at episode 50 is: 19.6
avg reward at episode 60 is: 17.6
avg reward at episode 70 is: 19.7
avg reward at episode 80 is: 25.2
avg reward at episode 90 is: 15.3
avg reward at episode 100 is: 19.7
DoEveryNEpisodes{PostEpisodeStage, Main.var"#2#3"}(Main.var"#2#3"(), 10, 100)
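The comment in the example above stresses evaluating on a new environment instance so the training environment is not polluted. The self-contained sketch below illustrates that bookkeeping with a hypothetical CoinEnv (reward 1.0 per step, random episode length of 1 to 10 steps) and hypothetical helpers play_episode!, evaluate and train_with_periodic_eval. No policy or learning is involved; it only shows that periodic evaluation on a fresh instance leaves the training environment's state untouched.

```julia
using Random
using Statistics: mean

# Hypothetical toy environment: every step yields reward 1.0 and an episode
# lasts a random number of steps between 1 and 10.
mutable struct CoinEnv
    rng::MersenneTwister
    remaining::Int    # steps left in the current episode
    steps_taken::Int  # total steps ever taken in this instance
end
CoinEnv(seed::Int) = CoinEnv(MersenneTwister(seed), 0, 0)

# Reset the environment, play one episode and return its total reward.
function play_episode!(env::CoinEnv)
    env.remaining = rand(env.rng, 1:10)
    total = 0.0
    while env.remaining > 0
        env.remaining -= 1
        env.steps_taken += 1
        total += 1.0
    end
    return total
end

# Evaluate on a freshly created environment so the training environment's
# RNG stream and step counter are not affected.
function evaluate(make_env; n_episodes::Int = 10)
    eval_env = make_env()
    return mean(play_episode!(eval_env) for _ in 1:n_episodes)
end

function train_with_periodic_eval(; n_episodes::Int = 100, every::Int = 10)
    train_env = CoinEnv(1)
    train_total = 0.0
    avg_rewards = Float64[]
    for episode in 1:n_episodes
        train_total += play_episode!(train_env)  # one "training" episode
        if episode % every == 0
            # a new environment per evaluation, seeded by the episode number
            push!(avg_rewards, evaluate(() -> CoinEnv(episode)))
        end
    end
    return train_env, train_total, avg_rewards
end

train_env, train_total, avg_rewards = train_with_periodic_eval()
```

Since every reward is 1.0, the training environment's step counter equals the total training reward only because evaluation never touched it.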

Save parameters

JLD2.jl is recommended for saving the parameters of a policy.

julia> using ReinforcementLearning
julia> using JLD2
julia> env = RandomWalk1D()
# RandomWalk1D

## Traits

| Trait Type        |                Value |
|:----------------- | --------------------:|
| NumAgentStyle     |        SingleAgent() |
| DynamicStyle      |         Sequential() |
| InformationStyle  | PerfectInformation() |
| ChanceStyle       |      Deterministic() |
| RewardStyle       |     TerminalReward() |
| UtilityStyle      |         GeneralSum() |
| ActionStyle       |   MinimalActionSet() |
| StateStyle        | Observation{Int64}() |
| DefaultStateStyle | Observation{Int64}() |
| EpisodeStyle      |           Episodic() |

## Is Environment Terminated?

No

## State Space

`Base.OneTo(7)`

## Action Space

`Base.OneTo(2)`

## Current State

4

julia> ns, na = length(state_space(env)), length(action_space(env))
(7, 2)

julia> policy = Agent(
           QBasedPolicy(;
               learner = TDLearner(
                   TabularQApproximator(n_state = ns, n_action = na),
                   :SARS;
               ),
               explorer = EpsilonGreedyExplorer(ϵ_stable=0.01),
           ),
           Trajectory(
               CircularArraySARTSTraces(;
                   capacity = 1,
                   state = Int64 => (),
                   action = Int64 => (),
                   reward = Float64 => (),
                   terminal = Bool => (),
               ),
               DummySampler(),
               InsertSampleRatioController(),
           ),
       );

julia> parameters_dir = mktempdir()
"/tmp/jl_3SBHd8"

julia> run(
           policy,
           env,
           StopAfterNSteps(10_000),
           DoEveryNSteps(n=1_000) do t, p, e
               ps = policy.policy.learner.approximator
               f = joinpath(parameters_dir, "parameters_at_step_$t.jld2")
               JLD2.@save f ps
               println("parameters at step $t saved to $f")
           end
       )
parameters at step 1000 saved to /tmp/jl_3SBHd8/parameters_at_step_1000.jld2
parameters at step 2000 saved to /tmp/jl_3SBHd8/parameters_at_step_2000.jld2
parameters at step 3000 saved to /tmp/jl_3SBHd8/parameters_at_step_3000.jld2
parameters at step 4000 saved to /tmp/jl_3SBHd8/parameters_at_step_4000.jld2
parameters at step 5000 saved to /tmp/jl_3SBHd8/parameters_at_step_5000.jld2
parameters at step 6000 saved to /tmp/jl_3SBHd8/parameters_at_step_6000.jld2
parameters at step 7000 saved to /tmp/jl_3SBHd8/parameters_at_step_7000.jld2
parameters at step 8000 saved to /tmp/jl_3SBHd8/parameters_at_step_8000.jld2
parameters at step 9000 saved to /tmp/jl_3SBHd8/parameters_at_step_9000.jld2
parameters at step 10000 saved to /tmp/jl_3SBHd8/parameters_at_step_10000.jld2
DoEveryNSteps{Main.var"#4#5"}(Main.var"#4#5"(), 1000, 10000)
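JLD2 is an external package, so the following sketch swaps in Julia's standard-library Serialization module to show the same checkpoint-every-N-steps pattern without extra dependencies. The parameter matrix, the fake update rule and train_and_checkpoint! are hypothetical; with JLD2 you would save with JLD2.@save as in the example above.

```julia
using Serialization  # standard library, used here in place of JLD2

# Hypothetical stand-in for a 7-state x 2-action Q-table.
params = zeros(7, 2)
save_dir = mktempdir()
saved_files = String[]

# Hypothetical training loop: mutates `params` in place and writes a
# snapshot to disk every `every` steps.
function train_and_checkpoint!(params, save_dir, saved_files; n_steps = 10_000, every = 1_000)
    for t in 1:n_steps
        s, a = rand(1:size(params, 1)), rand(1:size(params, 2))
        params[s, a] += 0.01  # fake "learning" update
        if t % every == 0
            f = joinpath(save_dir, "parameters_at_step_$t.jls")
            serialize(f, params)  # writes the current values to the file
            push!(saved_files, f)
        end
    end
    return params
end

train_and_checkpoint!(params, save_dir, saved_files)

# The last checkpoint was taken at the final step, so it matches `params`.
restored = deserialize(last(saved_files))
```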