How to use hooks?

What are hooks?

During the interaction between agents and environments, we often want to collect some useful information. One straightforward approach is imperative programming: we write the code in a loop and execute it step by step.

while true
    env |> policy |> env
    # write your own logic here
    # like saving parameters, recording loss function, evaluating policy, etc.
    stop_condition(env, policy) && break
    is_terminated(env) && reset!(env)
end

The benefit of this approach is its great clarity: you are responsible for everything you write. This is also the encouraged approach for new users to try out different components in this package.

Another approach is declarative programming. We describe when and what we want to do during an experiment, put these descriptions together with the agent and environment, and finally call run to conduct the experiment. In this way, we can reuse common hooks and execution pipelines instead of writing a lot of duplicate code. Many existing reinforcement learning packages in Python use a set of configuration files to define the execution pipeline. We believe this is unnecessary in Julia: with the declarative programming approach, we gain much more flexibility.
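
For example, the snippet below (a minimal sketch built only from components that appear later on this page) runs a random policy on the CartPole environment for 10 episodes and records the total reward of each episode with a hook:

using ReinforcementLearning

hook = TotalRewardPerEpisode()
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), hook)
hook.rewards  # the total reward of each of the 10 episodes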

Now the question is how to design the hook. A natural choice is to wrap the commented part of the above pseudocode into a function:

while true
    env |> policy |> env
    hook(policy, env)
    stop_condition(env, policy) && break
    is_terminated(env) && reset!(env)
end

But sometimes we'd like to have more fine-grained control. So the calling of hooks is split into several different stages, each represented by a subtype of AbstractStage (for example, PreEpisodeStage and PostEpisodeStage, which are used below).
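
Conceptually, the loop then becomes something like the sketch below. This is a simplified illustration rather than the actual implementation of run; the per-step stage name PostActStage is an assumption here, while PreEpisodeStage and PostEpisodeStage are the ones used later on this page.

hook(PreEpisodeStage(), policy, env)
while true
    env |> policy |> env
    hook(PostActStage(), policy, env)    # assumed per-step stage, for illustration
    stop_condition(env, policy) && break
    if is_terminated(env)
        hook(PostEpisodeStage(), policy, env)
        reset!(env)
        hook(PreEpisodeStage(), policy, env)
    end
end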

How to define a customized hook?

By default, an instance of AbstractHook will do nothing when called with (hook::AbstractHook)(::AbstractStage, policy, env). So when writing a customized hook, you only need to implement the necessary runtime logic.

For example, assume we want to record the wall time of each episode.

julia> using ReinforcementLearning

julia> Base.@kwdef mutable struct TimeCostPerEpisode <: AbstractHook
           t::UInt64 = time_ns()
           time_costs::Vector{UInt64} = []
       end
Main.ex-how_to_use_hooks.TimeCostPerEpisode

julia> (h::TimeCostPerEpisode)(::PreEpisodeStage, policy, env) = h.t = time_ns()

julia> (h::TimeCostPerEpisode)(::PostEpisodeStage, policy, env) = push!(h.time_costs, time_ns()-h.t)

julia> h = TimeCostPerEpisode()
Main.ex-how_to_use_hooks.TimeCostPerEpisode(0x000001bbfc6938ed, UInt64[])

julia> run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(10), h)
Main.ex-how_to_use_hooks.TimeCostPerEpisode(0x000001bc30996c37, UInt64[0x0000000000000b55, 0x0000000000000708, 0x0000000000000514, 0x00000000000005dc, 0x00000000000002bc, 0x000000000000076c, 0x00000000000003e8, 0x000000000000044c, 0x0000000000000384, 0x00000000000003e8])

julia> h.time_costs
10-element Vector{UInt64}:
 0x0000000000000b55
 0x0000000000000708
 0x0000000000000514
 0x00000000000005dc
 0x00000000000002bc
 0x000000000000076c
 0x00000000000003e8
 0x000000000000044c
 0x0000000000000384
 0x00000000000003e8
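
Since time_ns reports wall time in nanoseconds, h.time_costs ./ 1e9 converts the recorded costs to seconds.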

Common hooks

Periodic jobs

Sometimes we'd like to run some functions periodically. Two handy hooks, DoEveryNStep and DoEveryNEpisode, are provided for this kind of task.
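
Both are constructed with a do-block taking (t, policy, env), where t is the current step or episode count respectively, plus a keyword n controlling the period. For example, a minimal sketch that prints a message every 100 steps (the body is arbitrary and for illustration only):

DoEveryNStep(;n=100) do t, policy, env
    # `t` is the number of steps executed so far
    println("the policy has been run for $t steps")
end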

The following are some typical usages.

Evaluating policy during training

julia> using Statistics: mean

julia> policy = RandomPolicy()
typename(RandomPolicy)
├─ action_space => typename(Nothing)
└─ rng => typename(Random._GLOBAL_RNG)

julia> run(
           policy,
           CartPoleEnv(),
           StopAfterEpisode(100),
           DoEveryNEpisode(;n=10) do t, policy, env
               # In real-world cases, the policy is usually wrapped in an Agent,
               # so we need to extract the inner policy to run it in *actor* mode.
               # Here, for illustration only, we simply use the original policy.
       
               # Note that we create a new instance of CartPoleEnv here to avoid
               # polluting the original env.
       
               hook = TotalRewardPerEpisode(;is_display_on_exit=false)
               run(policy, CartPoleEnv(), StopAfterEpisode(10), hook)
       
               # now you can report the result of the hook.
               println("avg reward at episode $t is: $(mean(hook.rewards))")
           end
       )
avg reward at episode 10 is: 23.1
avg reward at episode 20 is: 24.0
avg reward at episode 30 is: 21.2
avg reward at episode 40 is: 21.0
avg reward at episode 50 is: 21.7
avg reward at episode 60 is: 25.3
avg reward at episode 70 is: 22.9
avg reward at episode 80 is: 19.5
avg reward at episode 90 is: 15.2
avg reward at episode 100 is: 18.5
DoEveryNEpisode{PostEpisodeStage, Main.ex-how_to_use_hooks.var"#2#3"}(Main.ex-how_to_use_hooks.var"#2#3"(), 10, 100)

Save parameters

BSON.jl is recommended to save the parameters of a policy.

julia> using Flux

julia> using Flux.Losses: huber_loss

julia> using BSON

julia> env = CartPoleEnv(; T = Float32)
# CartPoleEnv

## Traits

| Trait Type        |                  Value |
|:----------------- | ----------------------:|
| NumAgentStyle     |          SingleAgent() |
| DynamicStyle      |           Sequential() |
| InformationStyle  | ImperfectInformation() |
| ChanceStyle       |           Stochastic() |
| RewardStyle       |           StepReward() |
| UtilityStyle      |           GeneralSum() |
| ActionStyle       |     MinimalActionSet() |
| StateStyle        |     Observation{Any}() |
| DefaultStateStyle |     Observation{Any}() |

## Is Environment Terminated?

No

## Action Space

`Base.OneTo(2)`

julia> ns, na = length(state(env)), length(action_space(env))
(4, 2)

julia> policy = Agent(
           policy = QBasedPolicy(
               learner = BasicDQNLearner(
                   approximator = NeuralNetworkApproximator(
                       model = Chain(
                           Dense(ns, 128, relu; init = glorot_uniform),
                           Dense(128, 128, relu; init = glorot_uniform),
                           Dense(128, na; init = glorot_uniform),
                       ) |> cpu,
                       optimizer = ADAM(),
                   ),
                   batch_size = 32,
                   min_replay_history = 100,
                   loss_func = huber_loss,
               ),
               explorer = EpsilonGreedyExplorer(
                   kind = :exp,
                   ϵ_stable = 0.01,
                   decay_steps = 500,
               ),
           ),
           trajectory = CircularArraySARTTrajectory(
               capacity = 1000,
               state = Vector{Float32} => (ns,),
           ),
       )
typename(Agent)
├─ policy => typename(QBasedPolicy)
│  ├─ learner => typename(BasicDQNLearner)
│  │  ├─ approximator => typename(NeuralNetworkApproximator)
│  │  │  ├─ model => typename(Flux.Chain)
│  │  │  │  └─ layers
│  │  │  │     ├─ 1
│  │  │  │     │  └─ typename(Flux.Dense)
│  │  │  │     │     ├─ weight => 128×4 Matrix{Float32}
│  │  │  │     │     ├─ bias => 128-element Vector{Float32}
│  │  │  │     │     └─ σ => typename(typeof(NNlib.relu))
│  │  │  │     ├─ 2
│  │  │  │     │  └─ typename(Flux.Dense)
│  │  │  │     │     ├─ weight => 128×128 Matrix{Float32}
│  │  │  │     │     ├─ bias => 128-element Vector{Float32}
│  │  │  │     │     └─ σ => typename(typeof(NNlib.relu))
│  │  │  │     └─ 3
│  │  │  │        └─ typename(Flux.Dense)
│  │  │  │           ├─ weight => 2×128 Matrix{Float32}
│  │  │  │           ├─ bias => 2-element Vector{Float32}
│  │  │  │           └─ σ => typename(typeof(identity))
│  │  │  └─ optimizer => typename(Flux.Optimise.ADAM)
│  │  │     ├─ eta => 0.001
│  │  │     ├─ beta
│  │  │     │  ├─ 1
│  │  │     │  │  └─ 0.9
│  │  │     │  └─ 2
│  │  │     │     └─ 0.999
│  │  │     └─ state => typename(IdDict)
│  │  ├─ loss_func => typename(typeof(Flux.Losses.huber_loss))
│  │  ├─ γ => 0.99
│  │  ├─ sampler => typename(BatchSampler)
│  │  │  ├─ batch_size => 32
│  │  │  ├─ cache => typename(Nothing)
│  │  │  └─ rng => typename(Random._GLOBAL_RNG)
│  │  ├─ min_replay_history => 100
│  │  ├─ rng => typename(Random._GLOBAL_RNG)
│  │  └─ loss => 0.0
│  └─ explorer => typename(EpsilonGreedyExplorer)
│     ├─ ϵ_stable => 0.01
│     ├─ ϵ_init => 1.0
│     ├─ warmup_steps => 0
│     ├─ decay_steps => 500
│     ├─ step => 1
│     ├─ rng => typename(Random._GLOBAL_RNG)
│     └─ is_training => true
└─ trajectory => typename(Trajectory)
   └─ traces => typename(NamedTuple)
      ├─ state => 4×0 CircularArrayBuffers.CircularArrayBuffer{Float32, 2}
      ├─ action => 0-element CircularArrayBuffers.CircularVectorBuffer{Int64}
      ├─ reward => 0-element CircularArrayBuffers.CircularVectorBuffer{Float32}
      └─ terminal => 0-element CircularArrayBuffers.CircularVectorBuffer{Bool}

julia> parameters_dir = mktempdir()
"/tmp/jl_LdnQRQ"

julia> run(
           policy,
           env,
           StopAfterStep(10_000),
           DoEveryNStep(n=1_000) do t, p, e
               ps = params(p)
               f = joinpath(parameters_dir, "parameters_at_step_$t.bson")
               BSON.@save f ps
               println("parameters at step $t saved to $f")
           end
       )

parameters at step 1000 saved to /tmp/jl_LdnQRQ/parameters_at_step_1000.bson
parameters at step 2000 saved to /tmp/jl_LdnQRQ/parameters_at_step_2000.bson
parameters at step 3000 saved to /tmp/jl_LdnQRQ/parameters_at_step_3000.bson
parameters at step 4000 saved to /tmp/jl_LdnQRQ/parameters_at_step_4000.bson
parameters at step 5000 saved to /tmp/jl_LdnQRQ/parameters_at_step_5000.bson
parameters at step 6000 saved to /tmp/jl_LdnQRQ/parameters_at_step_6000.bson
parameters at step 7000 saved to /tmp/jl_LdnQRQ/parameters_at_step_7000.bson
parameters at step 8000 saved to /tmp/jl_LdnQRQ/parameters_at_step_8000.bson
parameters at step 9000 saved to /tmp/jl_LdnQRQ/parameters_at_step_9000.bson
Progress: 100%|█████████████████████████████████████████| Time: 0:00:23
parameters at step 10000 saved to /tmp/jl_LdnQRQ/parameters_at_step_10000.bson
DoEveryNStep{Main.ex-how_to_use_hooks.var"#4#5"}(Main.ex-how_to_use_hooks.var"#4#5"(), 1000, 10000)
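
To restore a saved checkpoint later, load the file and copy the parameters back into a policy with the same structure. Below is a minimal sketch, assuming the policy is constructed exactly as above and using Flux.loadparams! to copy the arrays back:

using BSON, Flux

# load the checkpoint written by BSON.@save above; the dictionary key is :ps
data = BSON.load(joinpath(parameters_dir, "parameters_at_step_10000.bson"))
Flux.loadparams!(policy, data[:ps])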

Logging data

Below we demonstrate how to use TensorBoardLogger.jl to log runtime metrics. Users could also use other tools, such as wandb, through PyCall.jl (see the sketch at the end of this section).

julia> using TensorBoardLogger

julia> using Logging

julia> tf_log_dir = "logs"
"logs"

julia> lg = TBLogger(tf_log_dir, min_level = Logging.Info)
TBLogger:
	- Log level     : Info
	- Current step  : 0
	- Output        : /home/runner/work/ReinforcementLearning.jl/ReinforcementLearning.jl/docs/build/logs
	- open files    : 1

julia> total_reward_per_episode = TotalRewardPerEpisode()
TotalRewardPerEpisode(Float64[], 0.0, true)

julia> hook = ComposedHook(
           total_reward_per_episode,
           DoEveryNEpisode() do t, agent, env
               with_logger(lg) do
                   @info "training"  reward = total_reward_per_episode.rewards[end]
               end
           end
       )
ComposedHook{Tuple{TotalRewardPerEpisode, DoEveryNEpisode{PostEpisodeStage, Main.ex-how_to_use_hooks.var"#6#8"}}}((TotalRewardPerEpisode(Float64[], 0.0, true), DoEveryNEpisode{PostEpisodeStage, Main.ex-how_to_use_hooks.var"#6#8"}(Main.ex-how_to_use_hooks.var"#6#8"(), 1, 0)))

julia> run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(50), hook)
                    Total reward per episode
            ┌────────────────────────────────────────┐ 
         60 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠀⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢸⠀⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣼⠀⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡏⡆⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⡇⡇⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣸⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⢠⠃⡇⠀⠀│ 
   Score    │⠀⠀⠀⠀⠀⠀⠀⡀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⠀⠀⠀⠀⠀⠀⠀⣸⠀⠀⠀⠀⣷⢀⡄⠀⠀⢸⠀⡇⢀⡇│ 
            │⠀⠀⠀⠀⠀⠀⡜⠱⡄⠀⠀⢠⢤⠀⠀⠀⠀⣿⠀⠀⠀⠀⣧⠀⠀⡇⡇⠀⠀⠀⡇⠎⡇⢀⢆⢸⠀⡇⢸⢇│ 
            │⢸⠀⠀⠀⠀⠀⡇⠀⡇⠀⠀⢸⠈⡆⠀⠀⠀⡇⡇⠀⠀⠀⣿⠀⠀⡇⢱⠀⠀⠀⡇⠀⡇⡎⢸⡎⠀⡇⡇⢸│ 
            │⠀⡇⠀⢰⡇⢸⠀⠀⡇⠀⠀⡸⠀⢱⠀⠀⢰⠁⡇⠀⠀⢰⢹⢀⣤⠃⠘⡄⠀⢸⠀⠀⢣⡇⠈⡇⠀⢇⠇⢸│ 
            │⠀⡇⢀⠇⠸⠼⠀⠀⢸⠀⡷⡇⠀⠈⣆⡀⢸⠀⡇⢠⢤⢸⠀⡟⠹⠀⠀⢇⢀⢸⠀⠀⢸⠃⠀⠀⠀⢸⠀⠸│ 
            │⠀⠘⠞⠀⠀⠀⠀⠀⠘⢼⠀⠀⠀⠀⠁⠈⠃⠀⠉⠃⠀⠙⠀⠀⠀⠀⠀⠘⠎⠻⠀⠀⠘⠀⠀⠀⠀⠈⠀⠀│ 
            │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
          0 │⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀│ 
            └────────────────────────────────────────┘ 
            0                                       50
                             Episode
ComposedHook{Tuple{TotalRewardPerEpisode, DoEveryNEpisode{PostEpisodeStage, Main.ex-how_to_use_hooks.var"#6#8"}}}((TotalRewardPerEpisode([24.0, 11.0, 10.0, 15.0, 20.0, 14.0, 14.0, 26.0, 29.0, 26.0  …  11.0, 25.0, 27.0, 17.0, 35.0, 54.0, 12.0, 23.0, 32.0, 14.0], 0.0, true), DoEveryNEpisode{PostEpisodeStage, Main.ex-how_to_use_hooks.var"#6#8"}(Main.ex-how_to_use_hooks.var"#6#8"(), 1, 50)))

julia> readdir(tf_log_dir)
1-element Vector{String}:
 "events.out.tfevents.1.623688698735635e9.fv-az219-313"

Then run tensorboard --logdir logs and open the printed link in your browser. (You need to install tensorboard first.)
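
As for wandb, below is a minimal sketch of logging the same per-episode reward through PyCall.jl. It assumes the wandb Python package is installed in the Python environment used by PyCall; the project name is purely hypothetical.

using PyCall

wandb = pyimport("wandb")
wandb.init(project = "rl-hooks-demo")  # hypothetical project name

reward_hook = TotalRewardPerEpisode(; is_display_on_exit = false)
hook = ComposedHook(
    reward_hook,
    DoEveryNEpisode() do t, agent, env
        # log the total reward of the episode that just finished
        wandb.log(Dict("reward" => reward_hook.rewards[end]))
    end
)
run(RandomPolicy(), CartPoleEnv(), StopAfterEpisode(50), hook)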