Skip to content
This repository has been archived by the owner on Aug 11, 2023. It is now read-only.

Automatic JuliaFormatter.jl run #91

Merged
merged 1 commit into from
Oct 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/CommonRLInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ CRL.@provide CRL.clone(env::CommonRLEnvs) = CommonRLEnv(copy(env.env))

CRL.@provide function CRL.act!(env::CommonRLEnvs, a)
env.env(a)
get_reward(env.env)
return get_reward(env.env)
end

CRL.valid_actions(x::CommonRLEnvs) = get_legal_actions(x.env)
Expand Down
32 changes: 16 additions & 16 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,33 +31,33 @@ end
#####

function get_env_traits()
[eval(x) for x in RLBase.ENV_API if endswith(String(x), "Style")]
return [eval(x) for x in RLBase.ENV_API if endswith(String(x), "Style")]
end

Base.show(io::IO, t::MIME"text/plain", env::AbstractEnv) =
show(io, MIME"text/markdown"(), env)

function Base.show(io::IO, t::MIME"text/markdown", env::AbstractEnv)
show(io, t, Markdown.parse("""
# $(get_name(env))
return show(io, t, Markdown.parse("""
# $(get_name(env))

## Traits
| Trait Type | Value |
|:---------- | ----- |
$(join(["|$(string(f))|$(f(env))|" for f in get_env_traits()], "\n"))
## Traits
| Trait Type | Value |
|:---------- | ----- |
$(join(["|$(string(f))|$(f(env))|" for f in get_env_traits()], "\n"))

## Actions
$(get_actions(env))
## Actions
$(get_actions(env))

## Players
$(join(["- `$p`" for p in get_players(env)], "\n"))
## Players
$(join(["- `$p`" for p in get_players(env)], "\n"))

## Current Player
`$(get_current_player(env))`
## Current Player
`$(get_current_player(env))`

## Is Environment Terminated?
$(get_terminal(env) ? "Yes" : "No")
"""))
## Is Environment Terminated?
$(get_terminal(env) ? "Yes" : "No")
"""))
end

#####
Expand Down
45 changes: 22 additions & 23 deletions src/implementations/environments.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ StateCachedEnv(env) = StateCachedEnv(get_state(env), env, true)

function (env::StateCachedEnv)(args...; kwargs...)
env.env(args...; kwargs...)
env.is_state_cached = false
return env.is_state_cached = false
end

function get_state(env::StateCachedEnv, args...; kwargs...)
Expand Down Expand Up @@ -148,7 +148,6 @@ end
get_reward(env::RewardOverriddenEnv, args...; kwargs...) =
foldl(|>, env.processors; init = get_reward(env.env, args...; kwargs...))


#####
# MaxTimeoutEnv
#####
Expand All @@ -164,7 +163,7 @@ MaxTimeoutEnv(env::E, max_t::Int; current_t::Int = 1) where {E<:AbstractEnv} =

function (env::MaxTimeoutEnv)(args...; kwargs...)
env.env(args...; kwargs...)
env.current_t = env.current_t + 1
return env.current_t = env.current_t + 1
end

# partial constructor to allow chaining
Expand Down Expand Up @@ -280,7 +279,7 @@ for f in
for i in 2:n
selectdim(cache, ndims(cache), i) .= $f(env[i], args...; kwargs...)
end
cache
return cache
end
end

Expand All @@ -289,34 +288,34 @@ get_actions(env::MultiThreadEnv, args...; kwargs...) =
get_current_player(env::MultiThreadEnv) = [get_current_player(x) for x in env.envs]

function Base.show(io::IO, t::MIME"text/markdown", env::MultiThreadEnv)
show(io, t, Markdown.parse("""
# MultiThreadEnv
return show(io, t, Markdown.parse("""
# MultiThreadEnv

## Num of threads
## Num of threads

$(Threads.nthreads())
$(Threads.nthreads())

## Num of inner environments
## Num of inner environments

$(length(env.envs)) replicates of `$(get_name(env.envs[1]))`
$(length(env.envs)) replicates of `$(get_name(env.envs[1]))`

## Traits of inner environment
| Trait Type | Value |
|:---------- | ----- |
$(join(["|$(string(f))|$(f(env))|" for f in get_env_traits()], "\n"))
## Traits of inner environment
| Trait Type | Value |
|:---------- | ----- |
$(join(["|$(string(f))|$(f(env))|" for f in get_env_traits()], "\n"))

## Actions of inner environment
$(get_actions(env[1]))
## Actions of inner environment
$(get_actions(env[1]))

## Players
$(join(["- `$p`" for p in get_players(env.envs[1])], "\n"))
## Players
$(join(["- `$p`" for p in get_players(env.envs[1])], "\n"))

## Current Player
$(join(["`$x`" for x in get_current_player(env)], ","))
## Current Player
$(join(["`$x`" for x in get_current_player(env)], ","))

## Is Environment Terminated?
$(get_terminal(env))
"""))
## Is Environment Terminated?
$(get_terminal(env))
"""))
end

# !!! some might not be meaningful, use with caution.
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/policies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ function get_prob(p::RandomPolicy{Nothing}, env, ::AbstractChanceStyle)
n = sum(mask)
prob = zeros(length(mask))
prob[mask] .= 1 / n
prob
return prob
end

function get_prob(p::RandomPolicy{Nothing}, env, ::ExplicitStochastic)
Expand Down
2 changes: 1 addition & 1 deletion src/implementations/spaces/continuous_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ struct ContinuousSpace{T<:Number} <: AbstractSpace
high::T
function ContinuousSpace(low::T, high::T) where {T<:Number}
low < high || throw(ArgumentError("$low must be less than $high"))
new{T}(low, high)
return new{T}(low, high)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/implementations/spaces/dict_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ end
function DictSpace(ps::Pair{<:Union{Symbol,AbstractString},<:AbstractSpace}...)
data = Dict(ps)
K, V = typeof(data).parameters
DictSpace{K,V}(Dict(ps))
return DictSpace{K,V}(Dict(ps))
end

Base.eltype(::DictSpace{K}) where {K} = Dict{K}
Expand Down
13 changes: 6 additions & 7 deletions src/implementations/spaces/discrete_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,21 @@ The `span` can be of any iterators.
# Example

```julia-repl
julia> s = DiscreteSpace([1,2,3])
julia > s = DiscreteSpace([1, 2, 3])
DiscreteSpace{Array{Int64,1}}([1, 2, 3])

julia> 0 ∉ s
julia > 0 ∉ s
true

julia> 2 ∈ s
julia > 2 ∈ s
true

julia> s = DiscreteSpace(Set([:a, :c, :a, :b]))
julia > s = DiscreteSpace(Set([:a, :c, :a, :b]))
DiscreteSpace{Set{Symbol}}(Set(Symbol[:a, :b, :c]))

julia> s = DiscreteSpace(3)
julia > s = DiscreteSpace(3)
DiscreteSpace{UnitRange{Int64}}(1:3)
```

"""
struct DiscreteSpace{T} <: AbstractSpace
span::T
Expand All @@ -40,7 +39,7 @@ DiscreteSpace(high::T) where {T<:Integer} = DiscreteSpace(one(T), high)

function DiscreteSpace(low::T, high::T) where {T<:Integer}
high >= low || throw(ArgumentError("$high must be >= $low"))
DiscreteSpace(low:high)
return DiscreteSpace(low:high)
end

"""
Expand Down
4 changes: 2 additions & 2 deletions src/implementations/spaces/multi_continuous_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct MultiContinuousSpace{T<:AbstractArray} <: AbstractSpace
throw(ArgumentError("$(size(low)) != $(size(high)), size must match"))
all(map((l, h) -> l <= h, low, high)) ||
throw(ArgumentError("each element of $low must be ≤ than $high"))
new{T}(low, high)
return new{T}(low, high)
end
end

Expand All @@ -28,5 +28,5 @@ Base.in(xs, s::MultiContinuousSpace) =
Base.length(s::MultiContinuousSpace) = error("MultiContinuousSpace is uncountable")

function Random.rand(rng::AbstractRNG, s::MultiContinuousSpace{T}) where {T}
(s.high .- s.low) .* rand(rng, eltype(T), size(s.low)...) .+ s.low
return (s.high .- s.low) .* rand(rng, eltype(T), size(s.low)...) .+ s.low
end
2 changes: 1 addition & 1 deletion src/implementations/spaces/multi_discrete_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct MultiDiscreteSpace{T<:AbstractArray} <: AbstractSpace
function MultiDiscreteSpace(low::T, high::T) where {T<:AbstractArray}
all(map((l, h) -> l <= h, low, high)) ||
throw(ArgumentError("each element of $high must be ≥r $low"))
new{T}(low, high, reduce(*, map((l, h) -> h - l + 1, low, high)))
return new{T}(low, high, reduce(*, map((l, h) -> h - l + 1, low, high)))
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/implementations/spaces/vect_space.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ Here `I` represents the index of action in each inner space inside `s`.
"""
function Base.getindex(s::VectSpace, I::Vector{Int})
@assert length(s.data) == length(I)
[getindex(d, i) for (d, i) in zip(s.data, I)]
return [getindex(d, i) for (d, i) in zip(s.data, I)]
end
6 changes: 3 additions & 3 deletions src/inline_export.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ const ENV_API = []
const MULTI_AGENT_ENV_API = []

macro api(ex)
interfacem(__module__, __source__, ex, API)
return interfacem(__module__, __source__, ex, API)
end

macro env_api(ex)
interfacem(__module__, __source__, ex, ENV_API)
return interfacem(__module__, __source__, ex, ENV_API)
end

macro multi_agent_env_api(ex)
interfacem(__module__, __source__, ex, MULTI_AGENT_ENV_API)
return interfacem(__module__, __source__, ex, MULTI_AGENT_ENV_API)
end

function interfacem(__module__::Module, __source__::LineNumberNode, ex::Expr, store)
Expand Down
Loading