Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Add random start policy #62

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/implementations/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ end

"""
    MultiThreadEnv(f, n)

Build a `MultiThreadEnv` from `n` environments, each obtained by a separate call to the
zero-argument factory `f`.
"""
function MultiThreadEnv(f, n)
    envs = [f() for _ in 1:n]
    return MultiThreadEnv(envs)
end

# Delegate the listed `Base` functions on a `MultiThreadEnv` to its `envs` field.
# NOTE(review): these two lines look like the before/after pair of this diff hunk (the
# second adds `Base.iterate`) — confirm against the raw patch before applying both.
@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.setindex!
@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.setindex!, Base.iterate

function (env::MultiThreadEnv)(actions)
@sync for i in 1:length(env)
Expand Down Expand Up @@ -199,7 +199,7 @@ for f in
end

# Collect `get_actions` of every wrapped environment (extra positional and keyword
# arguments are forwarded unchanged) and wrap the results in a single space object.
# NOTE(review): the two body lines below look like the before/after pair of this diff hunk
# (`TupleSpace` replaced by `VectSpace`) — confirm against the raw patch.
get_actions(env::MultiThreadEnv, args...; kwargs...) =
TupleSpace([get_actions(x, args...; kwargs...) for x in env.envs])
VectSpace([get_actions(x, args...; kwargs...) for x in env.envs])
"""
    get_current_player(env::MultiThreadEnv)

Return the current player of each wrapped environment, collected in `env.envs` order.
"""
function get_current_player(env::MultiThreadEnv)
    return [get_current_player(inner) for inner in env.envs]
end

function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv)
Expand Down
35 changes: 33 additions & 2 deletions src/implementations/policies.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export RandomPolicy
export RandomPolicy, RandomStartPolicy

using Random

Expand Down Expand Up @@ -35,7 +35,6 @@ RandomPolicy(::FullActionSet, env::AbstractEnv, rng) = RandomPolicy(nothing, rng

# No stored action space: sample with `rand` from whatever `get_legal_actions(env)` returns.
function (p::RandomPolicy{Nothing})(env)
    legal = get_legal_actions(env)
    return rand(p.rng, legal)
end

# Stored action space: sample from it directly; `env` is not consulted.
function (p::RandomPolicy)(env)
    return rand(p.rng, p.action_space)
end

# Multi-threaded environment: apply the policy to each element yielded by iterating `env`.
function (p::RandomPolicy)(env::MultiThreadEnv)
    return [p(inner) for inner in env]
end

# TODO: TBD
# Ideally we should return a Categorical distribution.
Expand All @@ -52,6 +51,7 @@ function get_prob(p::RandomPolicy{Nothing}, env)
end

"""
    get_prob(p::RandomPolicy, env, a)

Return `1 / length(p.action_space)`; neither `env` nor `a` is inspected.
"""
function get_prob(p::RandomPolicy, env, a)
    return 1 / length(p.action_space)
end

"""
    get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv, a)

Return one value per sub-space in `p.action_space.data`, each computed as
`1 / length(sub)`.
"""
function get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv, a)
    return [1 / length(sub) for sub in p.action_space.data]
end

function get_prob(p::RandomPolicy{Nothing}, env, a)
legal_actions = get_legal_actions(env)
Expand All @@ -61,3 +61,34 @@ function get_prob(p::RandomPolicy{Nothing}, env, a)
0.0
end
end

#####
# RandomStartPolicy
#####

# Policy wrapper holding a main policy, a `RandomPolicy`, and a mutable counter.
# The methods defined on this type pick `random_policy` while `num_rand_start` is
# non-negative and `policy` once it has dropped below zero.
Base.@kwdef mutable struct RandomStartPolicy{P,R<:RandomPolicy} <: AbstractPolicy
# Main policy, used after the random-start budget is exhausted.
policy::P
# Policy used while the random-start budget lasts.
random_policy::R
# Remaining budget; decremented on every call of the wrapper, so it keeps going negative.
num_rand_start::Int
end

# Select an action for `env`. The counter is decremented first, so a policy created with
# `num_rand_start = N` answers the first `N` calls with `random_policy` and every later
# call with `policy`.
function (p::RandomStartPolicy)(env)
    p.num_rand_start -= 1
    active = p.num_rand_start < 0 ? p.policy : p.random_policy
    return active(env)
end

"""
    update!(p::RandomStartPolicy, experience)

Forward `experience` to the wrapped `p.policy`; `p.random_policy` is never updated here.
"""
function update!(p::RandomStartPolicy, experience)
    return update!(p.policy, experience)
end

# Generate `get_prob` and `get_priority` methods for `RandomStartPolicy`. Each forwards all
# remaining arguments to `policy` once `num_rand_start` is negative and to `random_policy`
# otherwise; unlike calling the policy, these queries do not touch the counter.
for f in (:get_prob, :get_priority)
    @eval function $f(p::RandomStartPolicy, args...)
        delegate = p.num_rand_start < 0 ? p.policy : p.random_policy
        return $f(delegate, args...)
    end
end
2 changes: 1 addition & 1 deletion src/implementations/spaces/spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ include("multi_continuous_space.jl")
include("discrete_space.jl")
include("empty_space.jl")
include("multi_discrete_space.jl")
include("tuple_space.jl")
include("vect_space.jl")
include("dict_space.jl")
10 changes: 0 additions & 10 deletions src/implementations/spaces/tuple_space.jl

This file was deleted.

10 changes: 10 additions & 0 deletions src/implementations/spaces/vect_space.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
export VectSpace

# Space made of a collection of sub-spaces stored in `data`. The methods defined on this
# type index `data`, take its `length`, and iterate it, so `T` must support those.
struct VectSpace{T} <: AbstractSpace
# The sub-spaces, one per position of a sample.
data::T
end

"""
    eltype(s::VectSpace)

Return the `Vector` type describing samples of `s`.

The element type is the `typejoin` of the `eltype`s of *all* sub-spaces in `s.data`.
Looking only at `s.data[1]` misreports heterogeneous spaces (e.g. an integer-valued space
followed by a float-valued one) and throws a `BoundsError` when `s.data` is empty. For
sub-spaces that share one element type the result is unchanged.
"""
function Base.eltype(s::VectSpace)
    # `Union{}` is the identity of `typejoin`, so an empty `data` yields `Vector{Union{}}`
    # instead of an error.
    return Vector{mapreduce(eltype, typejoin, s.data; init=Union{})}
end
"""
    in(xs, s::VectSpace)

Return `true` when `xs` has exactly as many entries as `s.data` has sub-spaces and every
entry lies in the sub-space at the same position.
"""
function Base.in(xs, s::VectSpace)
    # Check the lengths first: `zip` would silently stop at the shorter collection.
    length(xs) == length(s.data) || return false
    return all(x in sub for (x, sub) in zip(xs, s.data))
end
"""
    rand(rng::AbstractRNG, s::VectSpace)

Draw one sample from each sub-space of `s.data` using `rng`, collected in order.
"""
function Random.rand(rng::AbstractRNG, s::VectSpace)
    return [rand(rng, sub) for sub in s.data]
end
8 changes: 4 additions & 4 deletions test/spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,14 @@
test_samples(s)
end

@testset "TupleSpace and DictSpace" begin
s = TupleSpace([
@testset "VectSpace and DictSpace" begin
s = VectSpace([
DiscreteSpace(3),
ContinuousSpace(0.0, 1.0),
TupleSpace([DiscreteSpace(3), ContinuousSpace(0.0, 1.0)]), # recursive
VectSpace([DiscreteSpace(3), ContinuousSpace(0.0, 1.0)]), # recursive
DictSpace(
:a => MultiDiscreteSpace([2.0, 4.0]),
:b => TupleSpace([
:b => VectSpace([
MultiContinuousSpace([-1, -2], [2.5, 3.5]),
MultiDiscreteSpace([3, 2]),
]),
Expand Down