JuliaRL_ED_OpenSpiel(kuhn_poker)

This experiment trains the Exploitability Descent (ED) algorithm on OpenSpiel's kuhn_poker and tracks the NashConv (exploitability) of the joint policy over the course of training.

using ReinforcementLearning
using StableRNGs
using OpenSpiel
using Flux

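# A hook that records the NashConv (a measure of exploitability) of the joint policy
# during training and anneals the agents' learning rate between episodes.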
mutable struct KuhnOpenEDHook <: AbstractHook
    results::Vector{Float64}
end

function (hook::KuhnOpenEDHook)(::PreEpisodeStage, policy, env)
    # record the NashConv (exploitability) of the current joint policy
    push!(hook.results, RLZoo.nash_conv(policy, env))

    # anneal each agent's learning rate: the Descent optimiser (the second entry
    # in the optimiser chain) gets a step size of 1 / sqrt(number of episodes so far)
    for (_, agent) in policy.agents
        agent.learner.optimizer[2].eta = 1.0 / sqrt(length(hook.results))
    end
end

function (hook::KuhnOpenEDHook)(::PostExperimentStage, policy, env)
    reset!(env)

    # record the NashConv of the final policy after training
    push!(hook.results, RLZoo.nash_conv(policy, env))
end
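
Since the `PreEpisodeStage` callback runs once before every episode, `results` collects one NashConv value per episode, and the `PostExperimentStage` callback appends a final value for the trained policy after the experiment finishes.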

function RL.Experiment(
    ::Val{:JuliaRL},
    ::Val{:ED},
    ::Val{:OpenSpiel},
    game;
    seed = 123,
)
    rng = StableRNG(seed)

    env = OpenSpielEnv(game)
    wrapped_env = ActionTransformedEnv(
        env,
        action_mapping = a -> RLBase.current_player(env) == chance_player(env) ? a : Int(a - 1),
        action_space_mapping = as -> RLBase.current_player(env) == chance_player(env) ?
            as : Base.OneTo(num_distinct_actions(env.game)),
    )
    wrapped_env = DefaultStateStyleEnv{InformationSet{Array}()}(wrapped_env)
    player = 0 # or 1
    ns, na = length(state(wrapped_env, player)), length(action_space(wrapped_env, player))
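
`ActionTransformedEnv` maps the 1-based actions chosen on the Julia side back to OpenSpiel's 0-based action ids (actions at chance nodes are passed through unchanged), and `DefaultStateStyleEnv` exposes the information-set tensor as the state. `ns` and `na` are then the input and output sizes of the networks defined next.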

    create_network() = Chain(
        Dense(ns, 64, relu; init = glorot_uniform(rng)),
        Dense(64, na; init = glorot_uniform(rng))
    )

    create_learner() = NeuralNetworkApproximator(
        model = create_network(),
        optimizer = Flux.Optimise.Optimiser(WeightDecay(0.001), Descent())
    )
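
The optimiser chains `WeightDecay(0.001)` (index 1) with `Descent()` (index 2), which is why the hook above anneals `optimizer[2].eta`: only the plain gradient-descent step size changes, while the weight-decay coefficient stays fixed.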

    EDmanager = EDManager(
        Dict(
            player => EDPolicy(
                1 - player, # opponent
                create_learner(), # neural network learner
                WeightedSoftmaxExplorer(), # explorer
            ) for player in players(env) if player != chance_player(env)
        )
    )
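
`EDManager` holds one `EDPolicy` per non-chance player. Each policy stores the opponent's id, the neural-network learner, and a `WeightedSoftmaxExplorer` that turns the network's outputs into a stochastic action choice.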

    stop_condition = StopAfterEpisode(500, is_show_progress=!haskey(ENV, "CI"))
    hook = KuhnOpenEDHook(Float64[])

    Experiment(EDmanager, wrapped_env, stop_condition, hook, "# play OpenSpiel $game with ED algorithm")
end
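
Instead of going through the `E` string macro below, the experiment defined above can presumably also be constructed by calling the method directly (a sketch, assuming the game is passed as its OpenSpiel name string):

RL.Experiment(Val(:JuliaRL), Val(:ED), Val(:OpenSpiel), "kuhn_poker"; seed = 123)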

using Plots
ex = E`JuliaRL_ED_OpenSpiel(kuhn_poker)`
run(ex)
plot(ex.hook.results, xlabel="episode", ylabel="nash_conv")
┌ Warning: unexpected player -1, falling back to default state value.
└ @ ReinforcementLearningEnvironments ~/work/ReinforcementLearning.jl/ReinforcementLearning.jl/src/ReinforcementLearningEnvironments/src/environments/3rd_party/open_spiel.jl:127
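
The last entry of `ex.hook.results` is the NashConv of the final policy; values close to zero indicate that the learned joint policy is close to a Nash equilibrium. A minimal follow-up sketch (the output file name is just an example):

println("final NashConv: ", ex.hook.results[end])
savefig("JuliaRL_ED_OpenSpiel_kuhn_poker.png")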


This page was generated using DemoCards.jl and Literate.jl.