diff --git a/src/implementations/environments.jl b/src/implementations/environments.jl index ed2b6de..c959a17 100644 --- a/src/implementations/environments.jl +++ b/src/implementations/environments.jl @@ -150,7 +150,7 @@ end MultiThreadEnv(f, n) = MultiThreadEnv([f() for _ in 1:n]) -@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.setindex! +@forward MultiThreadEnv.envs Base.getindex, Base.length, Base.setindex!, Base.iterate function (env::MultiThreadEnv)(actions) @sync for i in 1:length(env) @@ -199,7 +199,7 @@ for f in end get_actions(env::MultiThreadEnv, args...; kwargs...) = - TupleSpace([get_actions(x, args...; kwargs...) for x in env.envs]) + VectSpace([get_actions(x, args...; kwargs...) for x in env.envs]) get_current_player(env::MultiThreadEnv) = [get_current_player(x) for x in env.envs] function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv) diff --git a/src/implementations/policies.jl b/src/implementations/policies.jl index 31bed44..daea849 100644 --- a/src/implementations/policies.jl +++ b/src/implementations/policies.jl @@ -1,4 +1,4 @@ -export RandomPolicy +export RandomPolicy, RandomStartPolicy using Random @@ -35,7 +35,6 @@ RandomPolicy(::FullActionSet, env::AbstractEnv, rng) = RandomPolicy(nothing, rng (p::RandomPolicy{Nothing})(env) = rand(p.rng, get_legal_actions(env)) (p::RandomPolicy)(env) = rand(p.rng, p.action_space) -(p::RandomPolicy)(env::MultiThreadEnv) = [p(x) for x in env] # TODO: TBD # Ideally we should return a Categorical distribution. @@ -52,6 +51,7 @@ function get_prob(p::RandomPolicy{Nothing}, env) end get_prob(p::RandomPolicy, env, a) = 1 / length(p.action_space) +get_prob(p::RandomPolicy{<:VectSpace}, env::MultiThreadEnv, a) = [1/length(x) for x in p.action_space.data] function get_prob(p::RandomPolicy{Nothing}, env, a) legal_actions = get_legal_actions(env) @@ -61,3 +61,34 @@ function get_prob(p::RandomPolicy{Nothing}, env, a) 0.0 end end + +##### +# RandomStartPolicy +##### + +Base.@kwdef mutable struct RandomStartPolicy{P,R<:RandomPolicy} <: AbstractPolicy + policy::P + random_policy::R + num_rand_start::Int +end + +function (p::RandomStartPolicy)(env) + p.num_rand_start -= 1 + if p.num_rand_start < 0 + p.policy(env) + else + p.random_policy(env) + end +end + +update!(p::RandomStartPolicy, experience) = update!(p.policy, experience) + +for f in (:get_prob, :get_priority) + @eval function $f(p::RandomStartPolicy, args...) + if p.num_rand_start < 0 + $f(p.policy, args...) + else + $f(p.random_policy, args...) + end + end +end \ No newline at end of file diff --git a/src/implementations/spaces/spaces.jl b/src/implementations/spaces/spaces.jl index fddf2c7..dc3384f 100644 --- a/src/implementations/spaces/spaces.jl +++ b/src/implementations/spaces/spaces.jl @@ -3,5 +3,5 @@ include("multi_continuous_space.jl") include("discrete_space.jl") include("empty_space.jl") include("multi_discrete_space.jl") -include("tuple_space.jl") +include("vect_space.jl") include("dict_space.jl") diff --git a/src/implementations/spaces/tuple_space.jl b/src/implementations/spaces/tuple_space.jl deleted file mode 100644 index 5b9dbbf..0000000 --- a/src/implementations/spaces/tuple_space.jl +++ /dev/null @@ -1,10 +0,0 @@ -export TupleSpace - -struct TupleSpace{T} <: AbstractSpace - data::T -end - -Base.eltype(s::TupleSpace) = Tuple{(eltype(x) for x in s.data)...} -Base.in(xs, s::TupleSpace) = - length(xs) == length(s.data) && all(x in d for (x, d) in zip(xs, s.data)) -Random.rand(rng::AbstractRNG, s::TupleSpace) = Tuple(rand(rng, d) for d in s.data) diff --git a/src/implementations/spaces/vect_space.jl b/src/implementations/spaces/vect_space.jl new file mode 100644 index 0000000..e0b54a0 --- /dev/null +++ b/src/implementations/spaces/vect_space.jl @@ -0,0 +1,10 @@ +export VectSpace + +struct VectSpace{T} <: AbstractSpace + data::T +end + +Base.eltype(s::VectSpace) = Vector{eltype(s.data[1])} +Base.in(xs, s::VectSpace) = + length(xs) == length(s.data) && all(x in d for (x, d) in zip(xs, s.data)) +Random.rand(rng::AbstractRNG, s::VectSpace) = [rand(rng, d) for d in s.data] diff --git a/test/spaces.jl b/test/spaces.jl index 4e4d155..f354ca9 100644 --- a/test/spaces.jl +++ b/test/spaces.jl @@ -81,14 +81,14 @@ test_samples(s) end - @testset "TupleSpace and DictSpace" begin - s = TupleSpace([ + @testset "VectSpace and DictSpace" begin + s = VectSpace([ DiscreteSpace(3), ContinuousSpace(0.0, 1.0), - TupleSpace([DiscreteSpace(3), ContinuousSpace(0.0, 1.0)]), # recursive + VectSpace([DiscreteSpace(3), ContinuousSpace(0.0, 1.0)]), # recursive DictSpace( :a => MultiDiscreteSpace([2.0, 4.0]), - :b => TupleSpace([ + :b => VectSpace([ MultiContinuousSpace([-1, -2], [2.5, 3.5]), MultiDiscreteSpace([3, 2]), ]),