Enriching Offline Reinforcement Learning Algorithms in ReinforcementLearning.jl

This is the phase 1 technical report of the summer OSPP project Enriching Offline Reinforcement Learning Algorithms in ReinforcementLearning.jl used for mid-term evaluation. The report is split into the following parts: Project Information, Project Schedule and Future Plan.


Table of contents

  1. Technical Report
    1. Project Information
    2. Project Schedule
        1. Basic framework
        2. Useful Components
          1. GaussianNetwork
          2. Variational Auto-Encoder (VAE)
        3. Offline RL Algorithms
          1. Benchmark
          2. Conservative Q-Learning (CQL)
          3. Critic Regularizer Regression (CRR)
          4. Policy in the Latent Action Space (PLAS)
        4. Other Work
        5. Conclusion
    3. Future Plan

Technical Report

This technical report is the first evaluation report of the project "Enriching Offline Reinforcement Learning Algorithms in ReinforcementLearning.jl" in OSPP. It has three components: project information, project schedule, and future plan.

Project Information

| Date | Work |
| --- | --- |
| Prior - June 30 | Preliminary research, including algorithm papers, ReinforcementLearning.jl library code, etc. |
| The first phase | |
| July 1 - July 15 | Design and build the framework of offline RL. |
| July 16 - July 31 | Implement and experiment with offline DQN and offline SAC as benchmarks. |
| August 1 - August 15 | Write built-in documentation and technical report. Implement and experiment with CRR. |
| The second phase | |
| August 16 - August 31 | Implement and experiment with PLAS. |
| September 1 - September 15 | Research, implement and experiment with new SOTA offline RL algorithms. |
| September 16 - September 30 | Write built-in documentation and technical report. Buffer for unexpected delay. |
| After project | Carry on fixing issues and maintaining implemented algorithms. |

Project Schedule

This part mainly introduces the results of the first phase.

Basic framework

To run and test the offline algorithm, we first implemented OfflinePolicy.

Base.@kwdef struct OfflinePolicy{L,T} <: AbstractPolicy
    learner::L
    dataset::T
    continuous::Bool
    batch_size::Int
end

This implementation of OfflinePolicy is modeled on QBasedPolicy. It provides a continuous parameter to support both continuous and discrete action spaces. learner is the specific algorithm that learns and provides the policy. dataset and batch_size are used to sample data for learning.

We also implement the corresponding functions π, update! and sample. π selects the action, whose form is determined by the type of action space. update! is used in two stages: in the PreExperiment stage, it pre-trains algorithms that define a pretrain_step parameter; in the PreAct stage, it trains the learner. Inside update!, we call sample to draw a batch of data from the dataset. As ReinforcementLearningDataset.jl develops, the sample function will be deprecated.
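As a rough illustration of what drawing a batch involves, the sketch below samples transitions uniformly with replacement from a toy dataset. The NamedTuple layout, the field names and the function name sample_batch are assumptions made for this example; they are not the package's dataset type or its sample function.

```julia
using Random

# Hypothetical toy dataset: a NamedTuple of arrays, one column/entry per transition.
# This layout is an assumption for illustration only.
function sample_batch(rng::AbstractRNG, dataset::NamedTuple, batch_size::Int)
    n = size(dataset.state, 2)             # number of stored transitions
    inds = rand(rng, 1:n, batch_size)      # uniform sampling with replacement
    batch = (
        state = dataset.state[:, inds],
        action = dataset.action[inds],
        reward = dataset.reward[inds],
        terminal = dataset.terminal[inds],
        next_state = dataset.next_state[:, inds],
    )
    return inds, batch
end

rng = MersenneTwister(0)
dataset = (
    state = rand(rng, Float32, 4, 100),
    action = rand(rng, 1:2, 100),
    reward = rand(rng, Float32, 100),
    terminal = rand(rng, Bool, 100),
    next_state = rand(rng, Float32, 4, 100),
)
inds, batch = sample_batch(rng, dataset, 64)
```

Returning the indices together with the batch mirrors the `inds, batch = sample(...)` call pattern used in the pre-training code later in this report.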

With this framework, we can call the offline version of an existing algorithm with almost no additional code, so offline DQN and offline SAC can be implemented and benchmarked quickly. For example:

offline_dqn_policy = OfflinePolicy(
    learner = DQNLearner(
        # Omit specific code
    ),
    dataset = dataset,
    continuous = false,
    batch_size = 64,
)

To make this possible, we unified the parameter names across algorithms so that different learners are compatible with OfflinePolicy.

Useful Components

GaussianNetwork

GaussianNetwork models a normal distribution $\mathcal{N}(\mu,\sigma^2)$, which is often used in tasks with a continuous action space. It consists of three neural network chains:

Base.@kwdef struct GaussianNetwork{P,U,S}
    pre::P = identity
    μ::U
    logσ::S
    min_σ::Float32 = 0f0
    max_σ::Float32 = Inf32
end

We implement an evaluation function and an inference function for GaussianNetwork. Given a state, the evaluation function returns the mean and log-standard deviation. With is_sampling set, it instead samples an action from the distribution and can also return the log-probability of that action in the given state. Calling the inference function with a state and an action returns the log-likelihood of that action in the given state.

### Evaluation
function (model::GaussianNetwork)(state; is_sampling::Bool=false, is_return_log_prob::Bool=false)
    # Omit specific code
    if is_sampling
        if is_return_log_prob
            return tanh.(z), logp_π
        else
            return tanh.(z)
        end
    else
        return μ, logσ
    end
end
### Inference
function (model::GaussianNetwork)(state, action)
    # Omit specific code
    return logp_π
end
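The sampling branch above returns tanh.(z) together with logp_π. A minimal sketch of how such a tanh-squashed sample and its log-probability are commonly computed is given below. It takes μ and logσ directly as arrays instead of calling a network, omits the min_σ/max_σ clamping, and the function names and the small constant ϵ are assumptions for illustration rather than the omitted library code.

```julia
using Random

# Log-density of a diagonal Gaussian, summed over action dimensions (rows).
function normal_logpdf(z, μ, logσ)
    σ = exp.(logσ)
    c = 0.5f0 * log(2f0 * Float32(pi))
    return sum(-0.5f0 .* ((z .- μ) ./ σ) .^ 2 .- logσ .- c; dims = 1)
end

# Reparameterized sample squashed by tanh, with the change-of-variables correction:
# log π(a|s) = log N(z; μ, σ) - Σ log(1 - tanh(z)^2)
function sample_squashed(rng, μ, logσ; ϵ = 1f-6)
    z = μ .+ exp.(logσ) .* randn(rng, Float32, size(μ))
    a = tanh.(z)
    logp = normal_logpdf(z, μ, logσ) .- sum(log.(1f0 .- a .^ 2 .+ ϵ); dims = 1)
    return a, logp
end

μ = zeros(Float32, 2, 5)       # 2 action dimensions, batch of 5
logσ = zeros(Float32, 2, 5)
a, logp = sample_squashed(MersenneTwister(1), μ, logσ)
```

The ϵ term only guards against log(0) when tanh saturates; it slightly biases the log-probability, which is usually considered acceptable in practice.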

Variational Auto-Encoder (VAE)

In offline reinforcement learning tasks, a VAE is often used to learn from datasets to approximate the behavior policy.

The VAE we implemented contains two neural networks: encoder and decoder (link).

Base.@kwdef struct VAE{E, D}
    encoder::E
    decoder::D
end

In the encoding stage, it accepts a state and an action as input and outputs the mean and standard deviation of the distribution. The latent action is then obtained by sampling from the resulting distribution. In the decoding stage, the state and latent action are used as input to reconstruct the action.

During training, we call the vae_loss function to get the reconstruction loss and KL loss. The specific task determines the ratio of these two losses.

function vae_loss(model::VAE, state, action)
    # Omit specific code
    return recon_loss, kl_loss
end
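For reference, one common form of the two losses is written out below. It assumes a mean-squared-error reconstruction term and a standard normal prior on the latent action; the weight $w$, the batch size $N$ and the latent dimension $d$ are notation introduced here, and the omitted code may differ in these choices.

```latex
\mathcal{L}_{\text{recon}} = \frac{1}{N}\sum_{i=1}^{N} \lVert a_i - \hat{a}_i \rVert^2,
\qquad
\mathcal{L}_{\text{KL}} = -\frac{1}{2N}\sum_{i=1}^{N}\sum_{j=1}^{d}
  \left(1 + \log\sigma_{ij}^2 - \mu_{ij}^2 - \sigma_{ij}^2\right),
\qquad
\mathcal{L} = \mathcal{L}_{\text{recon}} + w\,\mathcal{L}_{\text{KL}}
```

Here $\hat{a}_i$ is the reconstructed action, and $\mu_{ij}$, $\sigma_{ij}$ are the encoder outputs for sample $i$ and latent dimension $j$. Returning the two terms separately, as vae_loss does, lets each algorithm pick its own $w$.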

In the specific algorithm, the functions that may need to be called are as follows:

### Encode + decode
function (model::VAE)(state, action)
    ### Omit specific code
    return a, μ, σ
end
### Decode
function decode(model::VAE, state, z)
    ### Omit specific code
    return a
end

Offline RL Algorithms

We used existing algorithms and hooks to create datasets in several environments (such as CartPole and Pendulum) for training the offline RL algorithms. This work can guide the subsequent development of the package ReinforcementLearningDataset.jl. For example:

gen_dataset("JuliaRL-CartPole-DQN", policy, env)

Benchmark

We implemented and experimented with offline DQN (discrete action space) and offline SAC (continuous action space) as benchmarks. The performance of offline DQN in the CartPole environment:

The performance of offline SAC in the Pendulum environment:

Conservative Q-Learning (CQL)

CQL is an efficient and straightforward Q-value constraint method. Other offline RL algorithms can easily use this constraint to improve performance. Therefore, we implement CQL as a common component (link). For other algorithms, we only need to add CQL loss to their loss.

function calculate_CQL_loss(q_value, qa_value)
    cql_loss = mean(log.(sum(exp.(q_value), dims=1)) .- qa_value)
    return cql_loss
end
### DQN loss
gs = gradient(params(Q)) do
    q = Q(s)[a]
    loss = loss_func(G, q)
    ignore() do
        learner.loss = loss
    end
    loss + calculate_CQL_loss(Q(s), q)
end

After adding the CQL loss, the performance of offline DQN improves.

Currently, this function only supports discrete action space and CQL(H) method.
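The CQL(H) term is the log-sum-exp of the Q-values over all actions minus the Q-value of the dataset action, averaged over the batch. A minimal sketch of a numerically stable variant is shown below; it subtracts the per-state maximum before exponentiating so that large Q-values do not overflow. The function name cql_h_loss and its argument layout are assumptions for this example, and the code illustrates the computation only (it is not tuned for automatic differentiation).

```julia
using Statistics

# q_values: (n_actions, batch) matrix of Q(s, ·)
# actions:  index of the dataset action for each column
function cql_h_loss(q_values::AbstractMatrix, actions::AbstractVector{<:Integer})
    m = maximum(q_values; dims = 1)                              # per-state maximum
    lse = vec(m .+ log.(sum(exp.(q_values .- m); dims = 1)))     # stable log-sum-exp
    q_data = [q_values[a, i] for (i, a) in enumerate(actions)]   # Q(s, a) for dataset actions
    return mean(lse .- q_data)
end

q = Float32[1 2; 3 0]            # 2 actions, batch of 2
loss = cql_h_loss(q, [2, 1])     # dataset actions pick Q = 3 and Q = 2
```

Converting the log-sum-exp row to a vector before subtracting keeps exactly one loss value per sample.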

Critic Regularizer Regression (CRR)

CRR is a behavior-cloning-based method. To filter out bad actions and enable learning better policies from low-quality data, CRR uses the advantage function to regularize the actor's learning objective. Pseudocode is as follows:

In different tasks, $f$ has different choices:

$$f=\mathbb{I}[A_\theta(s,a)>0]\quad \text{or}\quad f=e^{A_\theta(s,a)/\beta}$$

We implemented discrete CRR and continuous CRR (link). The brief function parameters are as follows:

mutable struct CRRLearner{Aq, At, R} <: AbstractLearner
    ### Omit other parameters
    approximator::Aq # Actor-Critic
    target_approximator::At # Actor-Critic
    policy_improvement_mode::Symbol
    ratio_upper_bound::Float32
    beta::Float32
    advantage_estimator::Symbol
    m::Int
    continuous::Bool
end

The parameter continuous indicates the type of action space. policy_improvement_mode selects the weight function $f$: with policy_improvement_mode=:binary we use the first (indicator) form; otherwise we use the second (exponential) form, which requires the parameters ratio_upper_bound (the upper bound on the value of $f$) and beta. We also provide two ways to estimate the advantage function, chosen with advantage_estimator=:mean or :max. In the discrete case, $A(s,a)$ can be computed directly. In the continuous case, we sample m Q-values to estimate the advantage function.
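To make the roles of these parameters concrete, here is a minimal sketch of the continuous-case advantage estimate and the two weight functions. It operates on plain arrays of Q-values instead of networks; the function names, the :exp symbol and the default values are assumptions for illustration and are not taken from the CRRLearner code.

```julia
using Statistics

# q_sa:      Q(s, a) for the dataset actions (length-B vector)
# q_samples: (m, B) matrix of Q(s, a_j) for m actions sampled from the current policy
function crr_advantage(q_sa, q_samples; estimator::Symbol = :mean)
    baseline = estimator == :mean ? mean(q_samples; dims = 1) : maximum(q_samples; dims = 1)
    return q_sa .- vec(baseline)
end

# Weight f applied to the behavior-cloning term of the actor loss.
function crr_weight(adv; mode::Symbol = :binary, beta = 1f0, ratio_upper_bound = 20f0)
    if mode == :binary
        return Float32.(adv .> 0)                           # f = 1[A > 0]
    else
        return min.(exp.(adv ./ beta), ratio_upper_bound)   # f = exp(A / β), clipped
    end
end

q_sa = Float32[1, 2]
q_samples = Float32[0 3; 2 5]      # m = 2 sampled actions, batch of 2
adv_mean = crr_advantage(q_sa, q_samples; estimator = :mean)
adv_max = crr_advantage(q_sa, q_samples; estimator = :max)
w_binary = crr_weight(adv_mean; mode = :binary)
w_exp = crr_weight(Float32[0, 10]; mode = :exp)
```

The :max estimator gives a more pessimistic advantage than :mean, so fewer dataset actions receive a positive weight under the binary form.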

Different action spaces will also affect the implementation of the Actor-Critic. In the discrete case, the Actor outputs logits of all actions in a given state. Gaussian networks are used to model the Actor in the continuous case.

Performance curve of discrete CRR algorithm in CartPole:

The continuous CRR algorithm still has some bugs and performs poorly.

Policy in the Latent Action Space (PLAS)

PLAS is a policy constraint method suitable for continuous control tasks. Unlike BCQ and BEAR, PLAS implicitly constrains the policy to output actions within the support of the behavior policy through the latent action space:

PLAS pre-trains a CVAE (Conditional Variational Auto-Encoder) to constrain the policy. In the pre-training phase, PLAS samples state-action pairs from the dataset to train the CVAE. In the training phase, PLAS learns a deterministic policy that maps a state to a latent action, then uses the CVAE decoder to map the latent action to an action. Both mappings (state to latent action, and latent action to action) pass through a tanh function to bound the output range.
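The two-step mapping described above can be sketched as follows. The weight matrices are toy stand-ins for the trained latent policy and CVAE decoder, and the names plas_action, max_latent and max_action are assumptions introduced for this example only.

```julia
# Hypothetical toy weights standing in for trained networks.
W_policy = Float32[0.5 -0.2 0.1]           # state (3) -> latent action (1)
W_decoder = Float32[0.3 0.1 -0.4 0.8]      # [state; latent] (4) -> action (1)

latent_policy(s) = W_policy * s
decoder(s, z) = W_decoder * vcat(s, z)

function plas_action(s; max_latent = 2f0, max_action = 1f0)
    z = max_latent .* tanh.(latent_policy(s))   # bounded latent action
    a = max_action .* tanh.(decoder(s, z))      # bounded environment action
    return a
end

s = Float32[0.1 0.2; -0.3 0.4; 0.5 -0.6]   # 3 state dimensions, batch of 2
a = plas_action(s)
```

Bounding the latent action keeps the decoder input in the region it saw during pre-training, which is how the constraint to the behavior policy's support is enforced implicitly.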

Pre-training the VAE has several advantages: it accelerates convergence, it is easier to train when the action space is complex, and it makes it possible to import existing VAE models. Its pseudocode is as follows:

Please refer to this link for specific code (link). The brief function parameters are as follows:

mutable struct PLASLearner{BA1, BA2, BC1, BC2, V, R} <: AbstractLearner
    ### Omit other parameters
    policy::BA1
    target_policy::BA2
    qnetwork1::BC1
    qnetwork2::BC2
    target_qnetwork1::BC1
    target_qnetwork2::BC2
    vae::V
    λ::Float32
    pretrain_step::Int
end

If an algorithm requires pre-training, specify the pretrain_step parameter and implement the corresponding update! function. We modified the run function and added an interface:

function (agent::Agent)(stage::PreExperimentStage, env::AbstractEnv)
    update!(agent.policy, agent.trajectory, env, stage)
end

function RLBase.update!(p::OfflinePolicy, traj::AbstractTrajectory, ::AbstractEnv, ::PreExperimentStage)
    l = p.learner
    if in(:pretrain_step, fieldnames(typeof(l)))
        println("Pretrain...")
        for _ in 1:l.pretrain_step
            inds, batch = sample(l.rng, p.dataset, p.batch_size)
            update!(l, batch)
        end
    end
end

In PLAS, we use conditional statements to select training components:

function RLBase.update!(l::PLASLearner, batch::NamedTuple{SARTS})
    if l.update_step == 0
        update_vae!(l, batch)
    else
        update_learner!(l, batch)
    end
end

λ is the parameter of clipped double Q-learning (used for Critic training), a small trick to reduce overestimation. Actor training uses the standard policy gradient method.
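A minimal sketch of the critic target under this scheme is given below, assuming the weighted min/max form used by BCQ-style methods: $y = r + \gamma(1-d)\,[\lambda \min(Q_1', Q_2') + (1-\lambda)\max(Q_1', Q_2')]$. The function name and the default values of γ and λ are assumptions for illustration, not values read from the PLASLearner code.

```julia
# r: rewards, terminal: episode-end flags,
# q1_next / q2_next: target critic values at the next state and next action.
function clipped_double_q_target(r, terminal, q1_next, q2_next; γ = 0.99f0, λ = 0.75f0)
    q_next = λ .* min.(q1_next, q2_next) .+ (1 - λ) .* max.(q1_next, q2_next)
    return r .+ γ .* (1 .- terminal) .* q_next
end

r = Float32[1, 1]
terminal = [false, true]
y = clipped_double_q_target(r, terminal, Float32[2, 2], Float32[4, 4])
```

With λ = 1 the target reduces to the plain minimum of the two critics; smaller values blend in the maximum and make the estimate less pessimistic.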

Performance curve of PLAS algorithm in Pendulum (pretrain_step=1000):

However, the action perturbation component in PLAS has not yet been completed and needs to be implemented in the second stage.

Other Work

In addition to the above work, we also did the following:

Conclusion

During this process, we learned a lot:

Future Plan

The following is our future plan:

| Date | Work |
| --- | --- |
| August 16 - August 23 | Debug and finish CRR and PLAS. |
| August 24 - August 31 | Read the paper and Python code of UWAC. |
| September 1 - September 7 | Implement and experiment with UWAC. |
| September 8 - September 15 | Read the paper and Python code of FisherBRC. |
| September 16 - September 23 | Implement and experiment with FisherBRC. |
| September 24 - September 30 | Write built-in documentation and technical report. Buffer for unexpected delay. |
| After project | Carry on fixing issues and maintaining implemented algorithms. |

First, we need to fix the bugs in continuous CRR and finish the action perturbation component in PLAS. Current progress is slightly ahead of the original schedule, so we can implement more modern offline RL algorithms. The current plan includes UWAC and FisherBRC, both published at ICML'21. Here we briefly introduce these two algorithms:

In this way, the implemented algorithms will cover the mainstream policy constraint methods in offline reinforcement learning (including distribution matching, support constraint, implicit constraint, and behavior cloning).
