diff --git a/src/common/CircularArraySARTSATraces.jl b/src/common/CircularArraySARTSATraces.jl index 53678b7..346c556 100644 --- a/src/common/CircularArraySARTSATraces.jl +++ b/src/common/CircularArraySARTSATraces.jl @@ -9,27 +9,28 @@ const CircularArraySARTSATraces = Traces{ <:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, - } + }, } function CircularArraySARTSATraces(; capacity::Int, - state=Int => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Int => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) state_eltype, state_size = state action_eltype, action_size = action reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) + - MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end -CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces)) +CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = + minimum(map(capacity, t.traces)) diff --git a/src/common/CircularArraySARTSTraces.jl b/src/common/CircularArraySARTSTraces.jl index eb43038..968789b 100644 --- a/src/common/CircularArraySARTSTraces.jl +++ b/src/common/CircularArraySARTSTraces.jl @@ -9,27 +9,28 @@ const CircularArraySARTSTraces = Traces{ <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, - } + }, } function CircularArraySARTSTraces(; capacity::Int, - state=Int => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Int => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) state_eltype, state_size = state action_eltype, action_size = action reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) + + MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + Traces( action = CircularArrayBuffer{action_eltype}(action_size..., capacity), - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end -CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces)) +CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = + minimum(map(capacity, t.traces)) diff --git a/src/common/CircularArraySLARTTraces.jl b/src/common/CircularArraySLARTTraces.jl index 0677906..ab334ac 100644 --- a/src/common/CircularArraySLARTTraces.jl +++ b/src/common/CircularArraySLARTTraces.jl @@ -8,16 +8,16 @@ const CircularArraySLARTTraces = Traces{ <:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}}, <:Trace{<:CircularArrayBuffer}, <:Trace{<:CircularArrayBuffer}, - } + }, } function CircularArraySLARTTraces(; capacity::Int, - state=Int => (), - legal_actions_mask=Bool => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Int => (), + legal_actions_mask = Bool => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) state_eltype, state_size = state action_eltype, action_size = action @@ -26,12 +26,18 @@ function CircularArraySLARTTraces(; terminal_eltype, terminal_size = terminal MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) + - MultiplexTraces{LL′}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) + + MultiplexTraces{LL′}( + CircularArrayBuffer{legal_actions_mask_eltype}( + legal_actions_mask_size..., + capacity + 1, + ), + ) + MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) + Traces( - reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity), - terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), + reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity), + terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity), ) end -CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces)) \ No newline at end of file +CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = + minimum(map(capacity, t.traces)) diff --git a/src/common/CircularPrioritizedTraces.jl b/src/common/CircularPrioritizedTraces.jl index 76581af..233da5e 100644 --- a/src/common/CircularPrioritizedTraces.jl +++ b/src/common/CircularPrioritizedTraces.jl @@ -9,7 +9,10 @@ struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts} default_priority::Float32 end -function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts} +function CircularPrioritizedTraces( + traces::AbstractTraces{names,Ts}; + default_priority, +) where {names,Ts} new_names = (:key, :priority, names...) new_Ts = Tuple{Int,Float32,Ts.parameters...} c = capacity(traces) @@ -17,7 +20,7 @@ function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_pri CircularVectorBuffer{Int}(c), SumTree(c), traces, - default_priority + default_priority, ) end @@ -60,6 +63,8 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol) end end -Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names)) +Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = + NamedTuple{names}(map(k -> t[k][i], names)) -capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces) +capacity(t::CircularPrioritizedTraces) = + ReinforcementLearningTrajectories.capacity(t.traces) diff --git a/src/common/ElasticArraySARTSATraces.jl b/src/common/ElasticArraySARTSATraces.jl index 0644396..5f9e4e1 100644 --- a/src/common/ElasticArraySARTSATraces.jl +++ b/src/common/ElasticArraySARTSATraces.jl @@ -7,14 +7,14 @@ const ElasticArraySARTSATraces = Traces{ <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}}, <:Trace{<:ElasticArray}, <:Trace{<:ElasticArray}, - } + }, } function ElasticArraySARTSATraces(; - state=Int => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Int => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) state_eltype, state_size = state action_eltype, action_size = action @@ -24,8 +24,7 @@ function ElasticArraySARTSATraces(; MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) + Traces( - reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), - terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), + reward = ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0), ) end - diff --git a/src/common/ElasticArraySARTSTraces.jl b/src/common/ElasticArraySARTSTraces.jl index c833882..cf25471 100644 --- a/src/common/ElasticArraySARTSTraces.jl +++ b/src/common/ElasticArraySARTSTraces.jl @@ -7,24 +7,24 @@ const ElasticArraySARTSTraces = Traces{ <:Trace{<:ElasticArray}, <:Trace{<:ElasticArray}, <:Trace{<:ElasticArray}, - } + }, } function ElasticArraySARTSTraces(; - state=Int => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => ()) - + state = Int => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), +) + state_eltype, state_size = state action_eltype, action_size = action reward_eltype, reward_size = reward terminal_eltype, terminal_size = terminal - MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + - Traces( + MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + Traces( action = ElasticArray{action_eltype}(undef, action_size..., 0), - reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), - terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), + reward = ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0), ) end diff --git a/src/common/ElasticArraySLARTTraces.jl b/src/common/ElasticArraySLARTTraces.jl index 517eb7f..e5495f1 100644 --- a/src/common/ElasticArraySLARTTraces.jl +++ b/src/common/ElasticArraySLARTTraces.jl @@ -8,16 +8,16 @@ const ElasticArraySLARTTraces = Traces{ <:MultiplexTraces{AA′,<:Trace{<:ElasticArray}}, <:Trace{<:ElasticArray}, <:Trace{<:ElasticArray}, - } + }, } function ElasticArraySLARTTraces(; capacity::Int, - state=Int => (), - legal_actions_mask=Bool => (), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Int => (), + legal_actions_mask = Bool => (), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) state_eltype, state_size = state action_eltype, action_size = action @@ -26,10 +26,12 @@ function ElasticArraySLARTTraces(; terminal_eltype, terminal_size = terminal MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + - MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) + + MultiplexTraces{LL′}( + ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0), + ) + MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) + Traces( - reward=ElasticArray{reward_eltype}(undef, reward_size..., 0), - terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0), + reward = ElasticArray{reward_eltype}(undef, reward_size..., 0), + terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0), ) end diff --git a/src/common/sum_tree.jl b/src/common/sum_tree.jl index 486ff28..fe9f276 100644 --- a/src/common/sum_tree.jl +++ b/src/common/sum_tree.jl @@ -139,7 +139,7 @@ function correct_sample(t::SumTree, leaf_ind) p = t.tree[leaf_ind] # walk backwards until p != 0 or until leftmost leaf reached tmp_ind = leaf_ind - while iszero(p) && (tmp_ind-1)*2 > length(t.tree) + while iszero(p) && (tmp_ind - 1) * 2 > length(t.tree) tmp_ind -= 1 p = t.tree[tmp_ind] end @@ -151,7 +151,7 @@ function correct_sample(t::SumTree, leaf_ind) end return p, tmp_ind end - + function Base.get(t::SumTree, v) parent_ind = 1 @@ -185,7 +185,7 @@ Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t) function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T} inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n) - for i in 1:n + for i = 1:n v = (i - 1 + rand(rng, T)) / n ind, p = get(t, v * t.tree[1]) inds[i] = ind diff --git a/src/controllers.jl b/src/controllers.jl index e769947..70b15bd 100644 --- a/src/controllers.jl +++ b/src/controllers.jl @@ -1,4 +1,5 @@ -export InsertSampleRatioController, AsyncInsertSampleRatioController, EpisodeSampleRatioController +export InsertSampleRatioController, + AsyncInsertSampleRatioController, EpisodeSampleRatioController """ InsertSampleRatioController(;ratio=1., threshold=1) @@ -43,10 +44,11 @@ end function AsyncInsertSampleRatioController( ratio, threshold, - ; ch_in_sz=1, - ch_out_sz=1, - n_inserted=0, - n_sampled=0 + ; + ch_in_sz = 1, + ch_out_sz = 1, + n_inserted = 0, + n_sampled = 0, ) AsyncInsertSampleRatioController( ratio, @@ -54,7 +56,7 @@ function AsyncInsertSampleRatioController( n_inserted, n_sampled, Channel(ch_in_sz), - Channel(ch_out_sz) + Channel(ch_out_sz), ) end @@ -75,14 +77,14 @@ end function on_insert!(c::EpisodeSampleRatioController, n::Int, x::NamedTuple) if n > 0 - c.n_episodes += sum(x.terminal) + c.n_episodes += sum(x.terminal) end end function on_sample!(c::EpisodeSampleRatioController) - if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio + if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio c.n_sampled += 1 return true end return false -end \ No newline at end of file +end diff --git a/src/episodes.jl b/src/episodes.jl index d4314d7..83efc44 100644 --- a/src/episodes.jl +++ b/src/episodes.jl @@ -22,7 +22,8 @@ If `traces` is a capacitated buffer, such as a CircularArraySARTSTraces, then th EpisodesBuffer assumes that individual transitions are `push!`ed. Appending is not yet supported. """ -mutable struct EpisodesBuffer{names, E, T<:AbstractTraces{names, E},B,S} <: AbstractTraces{names,E} +mutable struct EpisodesBuffer{names,E,T<:AbstractTraces{names,E},B,S} <: + AbstractTraces{names,E} traces::T sampleable_inds::S step_numbers::B @@ -42,25 +43,26 @@ end # Capacity of an EpisodesBuffer is the capacity of the underlying traces + 1 for certain cases function is_capacity_plus_one(traces::AbstractTraces) - if any(t->t isa MultiplexTraces, traces.traces) + if any(t -> t isa MultiplexTraces, traces.traces) # MultiplexTraces buffer next_state or next_action, so we need to add one to the capacity return true else false end end -is_capacity_plus_one(traces::CircularPrioritizedTraces) = is_capacity_plus_one(traces.traces) +is_capacity_plus_one(traces::CircularPrioritizedTraces) = + is_capacity_plus_one(traces.traces) function EpisodesBuffer(traces::AbstractTraces) cap = is_capacity_plus_one(traces) ? capacity(traces) + 1 : capacity(traces) @assert isempty(traces) "EpisodesBuffer must be initialized with empty traces." if !isinf(cap) - legalinds = CircularBuffer{Bool}(cap) + legalinds = CircularBuffer{Bool}(cap) step_numbers = CircularBuffer{Int}(cap) eplengths = deepcopy(step_numbers) EpisodesBuffer(traces, legalinds, step_numbers, eplengths) else - legalinds = BitVector() + legalinds = BitVector() step_numbers = Vector{Int}() eplengths = deepcopy(step_numbers) EpisodesBuffer(traces, legalinds, step_numbers, eplengths) @@ -68,7 +70,8 @@ function EpisodesBuffer(traces::AbstractTraces) end function Base.getindex(es::EpisodesBuffer, idx::Int...) - @boundscheck all(es.sampleable_inds[idx...]) || throw(BoundsError(es.sampleable_inds, idx)) + @boundscheck all(es.sampleable_inds[idx...]) || + throw(BoundsError(es.sampleable_inds, idx)) getindex(es.traces, idx...) end @@ -81,7 +84,8 @@ capacity(eb::EpisodesBuffer) = capacity(eb.traces) Base.size(eb::EpisodesBuffer) = size(eb.traces) Base.length(eb::EpisodesBuffer) = length(eb.traces) Base.keys(eb::EpisodesBuffer) = keys(eb.traces) -Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = keys(eb.traces.traces) +Base.keys(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = + keys(eb.traces.traces) function Base.show(io::IO, m::MIME"text/plain", eb::EpisodesBuffer{names}) where {names} s = nameof(typeof(eb)) t = eb.traces @@ -91,15 +95,16 @@ end ispartial_insert(traces::Traces, xs) = length(xs) < length(traces.traces) #this is the number of traces it contains not the number of steps. ispartial_insert(eb::EpisodesBuffer, xs) = ispartial_insert(eb.traces, xs) -ispartial_insert(traces::CircularPrioritizedTraces, xs) = ispartial_insert(traces.traces, xs) +ispartial_insert(traces::CircularPrioritizedTraces, xs) = + ispartial_insert(traces.traces, xs) function pad!(trace::Trace) pad!(trace.parent) return nothing end -pad!(vect::ElasticArray{T, Vector{T}}) where {T} = push!(vect, zero(T)) -pad!(vect::ElasticVector{T, Vector{T}}) where {T} = push!(vect, zero(T)) +pad!(vect::ElasticArray{T,Vector{T}}) where {T} = push!(vect, zero(T)) +pad!(vect::ElasticVector{T,Vector{T}}) where {T} = push!(vect, zero(T)) pad!(buf::CircularArrayBuffer{T,N,A}) where {T,N,A} = push!(buf, zero(T)) pad!(vect::Vector{T}) where {T} = push!(vect, zero(T)) @@ -133,7 +138,8 @@ end fill_multiplex(eb::EpisodesBuffer) = fill_multiplex(eb.traces) -fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = fill_multiplex(eb.traces.traces) +fill_multiplex(eb::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}) = + fill_multiplex(eb.traces.traces) max_length(eb::EpisodesBuffer) = max_length(eb.traces) @@ -145,7 +151,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.episodes_lengths, 0) push!(eb.sampleable_inds, 0) elseif !partial #typical inserting - if haskey(eb,:next_action) # if trace has next_action + if haskey(eb, :next_action) # if trace has next_action if eb.step_numbers[end] > 1 # and if there are sufficient steps in the current episode eb.sampleable_inds[end-1] = 1 # steps are indexable one step later end @@ -155,7 +161,7 @@ function Base.push!(eb::EpisodesBuffer, xs::NamedTuple) push!(eb.sampleable_inds, 0) #this one is no longer ep_length = last(eb.step_numbers) push!(eb.episodes_lengths, ep_length) - startidx = max(1,length(eb.step_numbers) - last(eb.step_numbers)) + startidx = max(1, length(eb.step_numbers) - last(eb.step_numbers)) eb.episodes_lengths[startidx:end] .= ep_length push!(eb.step_numbers, ep_length + 1) elseif partial diff --git a/src/normalization.jl b/src/normalization.jl index 598d55a..35bc73e 100644 --- a/src/normalization.jl +++ b/src/normalization.jl @@ -1,4 +1,5 @@ -import OnlineStats: OnlineStats, Group, Moments, fit!, OnlineStat, Weight, EqualWeight, mean, std +import OnlineStats: + OnlineStats, Group, Moments, fit!, OnlineStat, Weight, EqualWeight, mean, std export scalar_normalizer, array_normalizer, NormalizedTraces, Normalizer import MacroTools.@forward @@ -11,11 +12,15 @@ struct Normalizer{OS<:OnlineStat} os::OS end -@forward Normalizer.os OnlineStats.mean, OnlineStats.std, Base.iterate, normalize, Base.length +@forward Normalizer.os OnlineStats.mean, +OnlineStats.std, +Base.iterate, +normalize, +Base.length #Treats last dim as batch dim function OnlineStats.fit!(n::Normalizer, data::AbstractArray) - for d in eachslice(data, dims=ndims(data)) + for d in eachslice(data, dims = ndims(data)) fit!(n.os, vec(d)) end n @@ -74,13 +79,17 @@ end function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractArray) xn = similar(x) - for (i, slice) in enumerate(eachslice(x, dims=ndims(x))) - xn[repeat([:], ndims(x) - 1)..., i] .= reshape(normalize(os, vec(slice)), size(x)[1:end-1]...) + for (i, slice) in enumerate(eachslice(x, dims = ndims(x))) + xn[repeat([:], ndims(x) - 1)..., i] .= + reshape(normalize(os, vec(slice)), size(x)[1:end-1]...) end return xn end -function normalize(os::Group{<:AbstractVector{<:Moments}}, x::AbstractVector{<:AbstractArray}) +function normalize( + os::Group{<:AbstractVector{<:Moments}}, + x::AbstractVector{<:AbstractArray}, +) xn = similar(x) for (i, el) in enumerate(x) xn[i] = normalize(os, vec(el)) @@ -96,7 +105,7 @@ have equal weights in the computation of the moments. See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations. """ -scalar_normalizer(; weight::Weight=EqualWeight()) = Normalizer(Moments(weight=weight)) +scalar_normalizer(; weight::Weight = EqualWeight()) = Normalizer(Moments(weight = weight)) """ array_normalizer(size::Tuple{Int}; weights = OnlineStats.EqualWeight()) @@ -108,7 +117,8 @@ By default, all samples have equal weights in the computation of the moments. See the [OnlineStats documentation](https://joshday.github.io/OnlineStats.jl/stable/weights/) to use variants such as exponential weights to favor the most recent observations. """ -array_normalizer(size::NTuple{N,Int}; weight::Weight=EqualWeight()) where {N} = Normalizer(Group([Moments(weight=weight) for _ in 1:prod(size)])) +array_normalizer(size::NTuple{N,Int}; weight::Weight = EqualWeight()) where {N} = + Normalizer(Group([Moments(weight = weight) for _ = 1:prod(size)])) """ NormalizedTraces(traces::AbstractTraces, normalizers::NamedTuple) @@ -142,12 +152,16 @@ traj = Trajectory( ) ``` """ -struct NormalizedTraces{names,TT,T<:AbstractTraces{names,TT},normnames,N} <: AbstractTraces{names,TT} +struct NormalizedTraces{names,TT,T<:AbstractTraces{names,TT},normnames,N} <: + AbstractTraces{names,TT} traces::T normalizers::NamedTuple{normnames,N} end -function NormalizedTraces(traces::AbstractTraces{names,TT}; trace_normalizer_pairs...) where {names} where {TT} +function NormalizedTraces( + traces::AbstractTraces{names,TT}; + trace_normalizer_pairs..., +) where {names} where {TT} for key in keys(trace_normalizer_pairs) @assert key in keys(traces) "Traces do not have key $key, valid keys are $(keys(traces))." end @@ -155,7 +169,8 @@ function NormalizedTraces(traces::AbstractTraces{names,TT}; trace_normalizer_pai for trace in traces.traces #check if all traces of MultiplexTraces are in pairs if trace isa MultiplexTraces - if length(intersect(keys(trace), keys(trace_normalizer_pairs))) in [0, length(keys(trace))] #check if none or all keys are in normalizers + if length(intersect(keys(trace), keys(trace_normalizer_pairs))) in + [0, length(keys(trace))] #check if none or all keys are in normalizers continue else #if not then one is missing present_key = only(intersect(keys(trace), keys(trace_normalizer_pairs))) @@ -180,10 +195,18 @@ function Base.show(io::IO, ::MIME"text/plain", t::NormalizedTraces{names,T}) whe end end -@forward NormalizedTraces.traces Base.length, Base.size, Base.lastindex, Base.firstindex, Base.view, Base.pop!, Base.popfirst!, Base.empty!, Base.parent +@forward NormalizedTraces.traces Base.length, +Base.size, +Base.lastindex, +Base.firstindex, +Base.view, +Base.pop!, +Base.popfirst!, +Base.empty!, +Base.parent for f in (:push!, :pushfirst!, :append!, :prepend!) - @eval function Base.$f(nt::NormalizedTraces, x::T) where T + @eval function Base.$f(nt::NormalizedTraces, x::T) where {T} for key in intersect(keys(nt.normalizers), fieldnames(T)) fit!(nt.normalizers[key], getfield(x, key)) end @@ -191,19 +214,28 @@ for f in (:push!, :pushfirst!, :append!, :prepend!) end end -function StatsBase.sample(s::BatchSampler, nt::NormalizedTraces, names, weights = StatsBase.UnitWeights{Int}(length(nt))) +function StatsBase.sample( + s::BatchSampler, + nt::NormalizedTraces, + names, + weights = StatsBase.UnitWeights{Int}(length(nt)), +) inds = StatsBase.sample(s.rng, 1:length(nt), weights, s.batchsize) - maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data + maybe_normalize(data, key) = + key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data NamedTuple{names}(collect(maybe_normalize(nt[x][inds], x)) for x in names) end function Base.getindex(nt::NormalizedTraces, inds) - maybe_normalize(data, key) = key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data - NamedTuple{keys(nt.traces)}(collect(maybe_normalize(nt.traces[x][inds], x)) for x in keys(nt.traces)) + maybe_normalize(data, key) = + key in keys(nt.normalizers) ? normalize(nt.normalizers[key], data) : data + NamedTuple{keys(nt.traces)}( + collect(maybe_normalize(nt.traces[x][inds], x)) for x in keys(nt.traces) + ) end function Base.getindex(nt::NormalizedTraces, s::Symbol) getindex(nt.traces, s) end -ispartial_insert(traces::NormalizedTraces, xs) = ispartial_insert(traces.traces, xs) \ No newline at end of file +ispartial_insert(traces::NormalizedTraces, xs) = ispartial_insert(traces.traces, xs) diff --git a/src/patch.jl b/src/patch.jl index 200bb80..0908626 100644 --- a/src/patch.jl +++ b/src/patch.jl @@ -3,4 +3,4 @@ import StackViews: StackView lazy_stack(x) = StackView(x) -lazy_stack(x::AbstractVector{<:Number}) = x \ No newline at end of file +lazy_stack(x::AbstractVector{<:Number}) = x diff --git a/src/rendering.jl b/src/rendering.jl index f2c11a8..430cca8 100644 --- a/src/rendering.jl +++ b/src/rendering.jl @@ -1,35 +1,58 @@ using Term -const TRACE_COLORS = ("bright_green", "hot_pink", "bright_blue", "light_coral", "bright_cyan", "sandy_brown", "violet") +const TRACE_COLORS = ( + "bright_green", + "hot_pink", + "bright_blue", + "light_coral", + "bright_cyan", + "sandy_brown", + "violet", +) -Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes,Trajectory}) = tprint(io, convert(Term.AbstractRenderable, t; width=displaysize(io)[2]) |> string) +Base.show(io::IO, ::MIME"text/plain", t::Union{Trace,Traces,Episode,Episodes,Trajectory}) = + tprint(io, convert(Term.AbstractRenderable, t; width = displaysize(io)[2]) |> string) -inner_convert(::Type{Term.AbstractRenderable}, s::String; style="gray1", width=88) = Panel(s, width=width, style=style, justify=:center) -inner_convert(t::Type{Term.AbstractRenderable}, x::Union{Symbol,Number}; kw...) = inner_convert(t, string(x); kw...) +inner_convert(::Type{Term.AbstractRenderable}, s::String; style = "gray1", width = 88) = + Panel(s, width = width, style = style, justify = :center) +inner_convert(t::Type{Term.AbstractRenderable}, x::Union{Symbol,Number}; kw...) = + inner_convert(t, string(x); kw...) -function inner_convert(::Type{Term.AbstractRenderable}, x::AbstractArray; style="gray1", width=88) +function inner_convert( + ::Type{Term.AbstractRenderable}, + x::AbstractArray; + style = "gray1", + width = 88, +) t = string(nameof(typeof(x))) s = replace(string(size(x)), " " => "") - Panel(t * "\n" * s, style=style, justify=:center, width=width) + Panel(t * "\n" * s, style = style, justify = :center, width = width) end -function inner_convert(::Type{Term.AbstractRenderable}, x; style="gray1", width=88) +function inner_convert(::Type{Term.AbstractRenderable}, x; style = "gray1", width = 88) s = string(nameof(typeof(x))) - Panel(s, style=style, justify=:center, width=width) + Panel(s, style = style, justify = :center, width = width) end -Base.convert(T::Type{Term.AbstractRenderable}, t::Trace{<:AbstractArray}; kw...) = convert(T, Trace(collect(eachslice(t.x, dims=ndims(t.x)))); kw..., type=typeof(t), subtitle="size: $(size(t.x))") -Base.convert(T::Type{Term.AbstractRenderable}, t::NormalizedTrace; kw...) = convert(T, t.trace; kw..., type = typeof(t)) +Base.convert(T::Type{Term.AbstractRenderable}, t::Trace{<:AbstractArray}; kw...) = convert( + T, + Trace(collect(eachslice(t.x, dims = ndims(t.x)))); + kw..., + type = typeof(t), + subtitle = "size: $(size(t.x))", +) +Base.convert(T::Type{Term.AbstractRenderable}, t::NormalizedTrace; kw...) = + convert(T, t.trace; kw..., type = typeof(t)) function Base.convert( ::Type{Term.AbstractRenderable}, t::Trace{<:AbstractVector}; - width=88, - n_head=2, - n_tail=1, - name="Trace", - style=TRACE_COLORS[mod1(hash(name), length(TRACE_COLORS))], - type=typeof(t), - subtitle="size: $(size(t.x))" + width = 88, + n_head = 2, + n_tail = 1, + name = "Trace", + style = TRACE_COLORS[mod1(hash(name), length(TRACE_COLORS))], + type = typeof(t), + subtitle = "size: $(size(t.x))", ) title = "$name: [italic]$type[/italic] " min_width = min(width, length(title) - 4) @@ -38,16 +61,50 @@ function Base.convert( if n == 0 content = "" elseif 1 <= n <= n_head + n_tail - content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x) + content = mapreduce( + x -> inner_convert( + Term.AbstractRenderable, + x, + style = style, + width = min_width - 6, + ), + /, + t.x, + ) else - content = mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[1:n_head]) / - TextBox("...", justify=:center, width=min_width - 6) / - mapreduce(x -> inner_convert(Term.AbstractRenderable, x, style=style, width=min_width - 6), /, t.x[end-n_tail+1:end]) + content = + mapreduce( + x -> inner_convert( + Term.AbstractRenderable, + x, + style = style, + width = min_width - 6, + ), + /, + t.x[1:n_head], + ) / TextBox("...", justify = :center, width = min_width - 6) / mapreduce( + x -> inner_convert( + Term.AbstractRenderable, + x, + style = style, + width = min_width - 6, + ), + /, + t.x[end-n_tail+1:end], + ) end - Panel(content, width=min_width, title=title, subtitle=subtitle, subtitle_justify=:right, style=style, subtitle_style="yellow") + Panel( + content, + width = min_width, + title = title, + subtitle = subtitle, + subtitle_justify = :right, + style = style, + subtitle_style = "yellow", + ) end -function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width=88) +function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width = 88) max_len = mapreduce(length, max, t.traces) min_len = mapreduce(length, min, t.traces) if max_len - min_len == 1 @@ -58,78 +115,104 @@ function Base.convert(::Type{Term.AbstractRenderable}, t::Traces; width=88) N = length(t.traces) max_inner_width = ceil(Int, (width - 6 * 2) / N) Panel( - mapreduce(((i, x),) -> convert(Term.AbstractRenderable, t[x]; width=max_inner_width, name=x, n_tail=n_tails[i], style=TRACE_COLORS[mod1(i, length(TRACE_COLORS))]), *, enumerate(keys(t))), - title="Traces", - style="yellow3", - subtitle="$N traces in total", - subtitle_justify=:right, - width=width, - fit=true + mapreduce( + ((i, x),) -> convert( + Term.AbstractRenderable, + t[x]; + width = max_inner_width, + name = x, + n_tail = n_tails[i], + style = TRACE_COLORS[mod1(i, length(TRACE_COLORS))], + ), + *, + enumerate(keys(t)), + ), + title = "Traces", + style = "yellow3", + subtitle = "$N traces in total", + subtitle_justify = :right, + width = width, + fit = true, ) end -function Base.convert(::Type{Term.AbstractRenderable}, e::Episode; width=88) +function Base.convert(::Type{Term.AbstractRenderable}, e::Episode; width = 88) Panel( - convert(Term.AbstractRenderable, e.traces; width=width - 6), - title="Episode", - style="green_yellow", - subtitle=e[] ? "Episode END" : "Episode growing...", - subtitle_justify=:right, - width=width, - fit=true + convert(Term.AbstractRenderable, e.traces; width = width - 6), + title = "Episode", + style = "green_yellow", + subtitle = e[] ? "Episode END" : "Episode growing...", + subtitle_justify = :right, + width = width, + fit = true, ) end -function Base.convert(::Type{Term.AbstractRenderable}, e::Episodes; width=88) +function Base.convert(::Type{Term.AbstractRenderable}, e::Episodes; width = 88) n = length(e) if n == 0 content = "" elseif n == 1 - content = convert(Term.AbstractRenderable, e[1], width=width - 6) + content = convert(Term.AbstractRenderable, e[1], width = width - 6) elseif n == 2 - content = convert(Term.AbstractRenderable, e[1], width=width - 6) / - convert(Term.AbstractRenderable, e[end], width=width - 6) + content = + convert(Term.AbstractRenderable, e[1], width = width - 6) / + convert(Term.AbstractRenderable, e[end], width = width - 6) else - content = convert(Term.AbstractRenderable, e[1], width=width - 6) / - TextBox("...", justify=:center, width=width - 6) / - convert(Term.AbstractRenderable, e[end], width=width - 6) + content = + convert(Term.AbstractRenderable, e[1], width = width - 6) / + TextBox("...", justify = :center, width = width - 6) / + convert(Term.AbstractRenderable, e[end], width = width - 6) end Panel( content, - title="Episodes", - subtitle="$n episodes in total", - subtitle_justify=:right, - width=width, - fit=true, - style="wheat1" + title = "Episodes", + subtitle = "$n episodes in total", + subtitle_justify = :right, + width = width, + fit = true, + style = "wheat1", ) end -function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width=88) +function Base.convert(r::Type{Term.AbstractRenderable}, t::Trajectory; width = 88) Panel( - convert(r, t.container; width=width - 8) / - Panel(convert(Term.Tree, t.sampler); title="sampler", style="yellow3", fit=true, width=width - 8) / - Panel(convert(Term.Tree, t.controller); title="controller", style="yellow3", fit=true, width=width - 8); - title="Trajectory", - style="yellow3", - width=width, - fit=true + convert(r, t.container; width = width - 8) / Panel( + convert(Term.Tree, t.sampler); + title = "sampler", + style = "yellow3", + fit = true, + width = width - 8, + ) / Panel( + convert(Term.Tree, t.controller); + title = "controller", + style = "yellow3", + fit = true, + width = width - 8, + ); + title = "Trajectory", + style = "yellow3", + width = width, + fit = true, ) end # general converter -Base.convert(::Type{Term.Tree}, x) = Tree(to_tree_body(x); title=to_tree_title(x)) +Base.convert(::Type{Term.Tree}, x) = Tree(to_tree_body(x); title = to_tree_title(x)) Base.convert(::Type{Term.Tree}, x::Tree) = x function to_tree_body(x) pts = propertynames(x) if length(pts) > 0 - Dict("$p => $(summary(getproperty(x, p)))" => to_tree_body(getproperty(x, p)) for p in pts) + Dict( + "$p => $(summary(getproperty(x, p)))" => to_tree_body(getproperty(x, p)) for + p in pts + ) else x end end -to_tree_title(x) = "$(summary(x))" \ No newline at end of file +to_tree_title(x) = "$(summary(x))" diff --git a/src/samplers.jl b/src/samplers.jl index 21628bd..fe9ed25 100644 --- a/src/samplers.jl +++ b/src/samplers.jl @@ -1,5 +1,12 @@ using Random -export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler +export EpisodesSampler, + Episode, + BatchSampler, + NStepBatchSampler, + MetaSampler, + MultiBatchSampler, + DummySampler, + MultiStepSampler struct SampleGenerator{S,T} sampler::S @@ -40,21 +47,35 @@ end Uniformly sample **ONE** batch of `batchsize` examples for each trace specified in `names`. If `names` is not set, all the traces will be sampled. """ -BatchSampler(batchsize; kw...) = BatchSampler(; batchsize=batchsize, kw...) +BatchSampler(batchsize; kw...) = BatchSampler(; batchsize = batchsize, kw...) BatchSampler(; kw...) = BatchSampler{nothing}(; kw...) -BatchSampler{names}(batchsize; kw...) where {names} = BatchSampler{names}(; batchsize=batchsize, kw...) -BatchSampler{names}(; batchsize, rng=Random.GLOBAL_RNG) where {names} = BatchSampler{names}(batchsize, rng) - -StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = StatsBase.sample(s, t, keys(t)) -StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = StatsBase.sample(s, t, names) - -function StatsBase.sample(s::BatchSampler, t::AbstractTraces, names, weights = StatsBase.UnitWeights{Int}(length(t))) +BatchSampler{names}(batchsize; kw...) where {names} = + BatchSampler{names}(; batchsize = batchsize, kw...) +BatchSampler{names}(; batchsize, rng = Random.GLOBAL_RNG) where {names} = + BatchSampler{names}(batchsize, rng) + +StatsBase.sample(s::BatchSampler{nothing}, t::AbstractTraces) = + StatsBase.sample(s, t, keys(t)) +StatsBase.sample(s::BatchSampler{names}, t::AbstractTraces) where {names} = + StatsBase.sample(s, t, names) + +function StatsBase.sample( + s::BatchSampler, + t::AbstractTraces, + names, + weights = StatsBase.UnitWeights{Int}(length(t)), +) inds = StatsBase.sample(s.rng, 1:length(t), weights, s.batchsize) NamedTuple{names}(map(x -> collect(t[Val(x)][inds]), names)) end function StatsBase.sample(s::BatchSampler, t::EpisodesBuffer, names) - StatsBase.sample(s, t.traces, names, StatsBase.FrequencyWeights(t.sampleable_inds[1:end-1])) + StatsBase.sample( + s, + t.traces, + names, + StatsBase.FrequencyWeights(t.sampleable_inds[1:end-1]), + ) end # !!! avoid iterating an empty trajectory @@ -68,20 +89,33 @@ end ##### -StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = StatsBase.sample(s, t, keys(t.traces)) +StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = + StatsBase.sample(s, t, keys(t.traces)) -function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}, names) +function StatsBase.sample( + s::BatchSampler, + e::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}, + names, +) t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) w .*= e.sampleable_inds[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) - NamedTuple{(:key, :priority, names...)}((t.keys[inds], p[inds], map(x -> collect(t.traces[Val(x)][inds]), names)...)) + NamedTuple{(:key, :priority, names...)}(( + t.keys[inds], + p[inds], + map(x -> collect(t.traces[Val(x)][inds]), names)..., + )) end function StatsBase.sample(s::BatchSampler, t::CircularPrioritizedTraces, names) inds, priorities = rand(s.rng, t.priorities, s.batchsize) - NamedTuple{(:key, :priority, names...)}((t.keys[inds], priorities, map(x -> collect(t.traces[Val(x)][inds]), names)...)) + NamedTuple{(:key, :priority, names...)}(( + t.keys[inds], + priorities, + map(x -> collect(t.traces[Val(x)][inds]), names)..., + )) end ##### @@ -148,7 +182,7 @@ struct MultiBatchSampler{S} n::Int end -StatsBase.sample(m::MultiBatchSampler, t) = [StatsBase.sample(m.sampler, t) for _ in 1:m.n] +StatsBase.sample(m::MultiBatchSampler, t) = [StatsBase.sample(m.sampler, t) for _ = 1:m.n] function Base.iterate(s::SampleGenerator{<:MultiBatchSampler}) if length(s.traces) > 0 @@ -178,7 +212,7 @@ to an integer > 1. This samples the (stacksize - 1) previous states. This is use of partial observability, for example when the state is approximated by `stacksize` consecutive frames. """ -mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG} +mutable struct NStepBatchSampler{names,S<:Union{Nothing,Int},R<:AbstractRNG} n::Int # !!! n starts from 1 γ::Float32 batchsize::Int @@ -187,10 +221,16 @@ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int}, R <: AbstractRN end NStepBatchSampler(t::AbstractTraces; kw...) = NStepBatchSampler{keys(t)}(; kw...) -function NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.default_rng()) where {names} +function NStepBatchSampler{names}(; + n, + γ, + batchsize = 32, + stacksize = nothing, + rng = Random.default_rng(), +) where {names} @assert n >= 1 "n must be ≥ 1." ss = stacksize == 1 ? nothing : stacksize - NStepBatchSampler{names, typeof(ss), typeof(rng)}(n, γ, batchsize, ss, rng) + NStepBatchSampler{names,typeof(ss),typeof(rng)}(n, γ, batchsize, ss, rng) end #return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index. @@ -210,49 +250,92 @@ function StatsBase.sample(s::NStepBatchSampler{names}, ts) where {names} StatsBase.sample(s, ts, Val(names)) end -function StatsBase.sample(s::NStepBatchSampler, t::EpisodesBuffer, ::Val{names}) where names +function StatsBase.sample( + s::NStepBatchSampler, + t::EpisodesBuffer, + ::Val{names}, +) where {names} weights, ns = valid_range(s, t) - inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize) + inds = StatsBase.sample( + s.rng, + 1:length(t), + StatsBase.FrequencyWeights(weights[1:end-1]), + s.batchsize, + ) fetch(s, t, Val(names), inds, ns) end -function fetch(s::NStepBatchSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names - NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names)) +function fetch( + s::NStepBatchSampler, + ts::EpisodesBuffer, + ::Val{names}, + inds, + ns, +) where {names} + NamedTuple{names}( + map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names), + ) end #state and next_state have specialized fetch methods due to stacksize -fetch(::NStepBatchSampler{names, Nothing}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[inds] -fetch(s::NStepBatchSampler{names, Int}, trace::AbstractTrace, ::Val{:state}, inds, ns) where {names} = trace[[x + i for i in -s.stacksize+1:0, x in inds]] -fetch(::NStepBatchSampler{names, Nothing}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[inds .+ ns .- 1] -fetch(s::NStepBatchSampler{names, Int}, trace::RelativeTrace{1,0}, ::Val{:next_state}, inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in -s.stacksize+1:0, (idx,x) in enumerate(inds)]] +fetch( + ::NStepBatchSampler{names,Nothing}, + trace::AbstractTrace, + ::Val{:state}, + inds, + ns, +) where {names} = trace[inds] +fetch( + s::NStepBatchSampler{names,Int}, + trace::AbstractTrace, + ::Val{:state}, + inds, + ns, +) where {names} = trace[[x + i for i = -s.stacksize+1:0, x in inds]] +fetch( + ::NStepBatchSampler{names,Nothing}, + trace::RelativeTrace{1,0}, + ::Val{:next_state}, + inds, + ns, +) where {names} = trace[inds.+ns.-1] +fetch( + s::NStepBatchSampler{names,Int}, + trace::RelativeTrace{1,0}, + ::Val{:next_state}, + inds, + ns, +) where {names} = + trace[[x + ns[idx] - 1 + i for i = -s.stacksize+1:0, (idx, x) in enumerate(inds)]] #reward due to discounting function fetch(s::NStepBatchSampler, trace::AbstractTrace, ::Val{:reward}, inds, ns) rewards = Vector{eltype(trace)}(undef, length(inds)) - for (i,idx) in enumerate(inds) + for (i, idx) in enumerate(inds) rewards_to_go = trace[idx:idx+ns[i]-1] - rewards[i] = foldr((x,y)->x + s.γ*y, rewards_to_go) + rewards[i] = foldr((x, y) -> x + s.γ * y, rewards_to_go) end return rewards end #terminal is that of the nth step -fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = trace[inds .+ ns .- 1] +fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val{:terminal}, inds, ns) = + trace[inds.+ns.-1] #right multiplex traces must be n-step sampled -fetch(::NStepBatchSampler, trace::RelativeTrace{1,0} , ::Val, inds, ns) = trace[inds .+ ns .- 1] +fetch(::NStepBatchSampler, trace::RelativeTrace{1,0}, ::Val, inds, ns) = trace[inds.+ns.-1] #normal trace types are fetched at inds fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds] #other types of trace are sampled normally -function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names} +function StatsBase.sample( + s::NStepBatchSampler{names}, + e::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}, +) where {names} t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) - valids, ns = valid_range(s,e) + valids, ns = valid_range(s, e) w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) - merge( - (key=t.keys[inds], priority=p[inds]), - fetch(s, e, Val(names), inds, ns) - ) + merge((key = t.keys[inds], priority = p[inds]), fetch(s, e, Val(names), inds, ns)) end """ @@ -262,21 +345,22 @@ A sampler that samples all Episodes present in the Trajectory and divides them i Episode containers. Truncated Episodes (e.g. due to the buffer capacity) are sampled as well. There will be at most one truncated episode and it will always be the first one. """ -struct EpisodesSampler{names} -end +struct EpisodesSampler{names} end EpisodesSampler() = EpisodesSampler{nothing}() #EpisodesSampler{names}() = new{names}() -struct Episode{names, N <: NamedTuple{names}} +struct Episode{names,N<:NamedTuple{names}} nt::N end @forward Episode.nt Base.keys, Base.haskey, Base.getindex -StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer) = StatsBase.sample(s,t,keys(t)) -StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where names = StatsBase.sample(s,t,names) +StatsBase.sample(s::EpisodesSampler{nothing}, t::EpisodesBuffer) = + StatsBase.sample(s, t, keys(t)) +StatsBase.sample(s::EpisodesSampler{names}, t::EpisodesBuffer) where {names} = + StatsBase.sample(s, t, names) function make_episode(t::EpisodesBuffer, range, names) nt = NamedTuple{names}(map(x -> collect(t[Val(x)][range]), names)) @@ -289,7 +373,7 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names) while idx < length(t) if t.sampleable_inds[idx] == 1 last_state_idx = idx + t.episodes_lengths[idx] - t.step_numbers[idx] - push!(ranges,idx:last_state_idx) + push!(ranges, idx:last_state_idx) idx = last_state_idx + 1 else idx += 1 @@ -309,7 +393,7 @@ Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each samp truncated by the end of its episode. This means that the dimensions of each sample are not the same. """ -struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG} +struct MultiStepSampler{names,S<:Union{Nothing,Int},R<:AbstractRNG} n::Int batchsize::Int stacksize::S @@ -317,10 +401,15 @@ struct MultiStepSampler{names, S <: Union{Nothing,Int}, R <: AbstractRNG} end MultiStepSampler(t::AbstractTraces; kw...) = MultiStepSampler{keys(t)}(; kw...) -function MultiStepSampler{names}(; n::Int, batchsize, stacksize=nothing, rng=Random.default_rng()) where {names} +function MultiStepSampler{names}(; + n::Int, + batchsize, + stacksize = nothing, + rng = Random.default_rng(), +) where {names} @assert n >= 1 "n must be ≥ 1." ss = stacksize == 1 ? nothing : stacksize - MultiStepSampler{names, typeof(ss), typeof(rng)}(n, batchsize, ss, rng) + MultiStepSampler{names,typeof(ss),typeof(rng)}(n, batchsize, ss, rng) end function valid_range(s::MultiStepSampler, eb::EpisodesBuffer) @@ -339,33 +428,59 @@ function StatsBase.sample(s::MultiStepSampler{names}, ts) where {names} StatsBase.sample(s, ts, Val(names)) end -function StatsBase.sample(s::MultiStepSampler, t::EpisodesBuffer, ::Val{names}) where names +function StatsBase.sample( + s::MultiStepSampler, + t::EpisodesBuffer, + ::Val{names}, +) where {names} weights, ns = valid_range(s, t) - inds = StatsBase.sample(s.rng, 1:length(t), StatsBase.FrequencyWeights(weights[1:end-1]), s.batchsize) + inds = StatsBase.sample( + s.rng, + 1:length(t), + StatsBase.FrequencyWeights(weights[1:end-1]), + s.batchsize, + ) fetch(s, t, Val(names), inds, ns) end -function fetch(s::MultiStepSampler, ts::EpisodesBuffer, ::Val{names}, inds, ns) where names - NamedTuple{names}(map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names)) +function fetch( + s::MultiStepSampler, + ts::EpisodesBuffer, + ::Val{names}, + inds, + ns, +) where {names} + NamedTuple{names}( + map(name -> collect(fetch(s, ts[name], Val(name), inds, ns[inds])), names), + ) end function fetch(::MultiStepSampler, trace, ::Val, inds, ns) - [trace[idx:(idx + ns[i] - 1)] for (i,idx) in enumerate(inds)] -end - -function fetch(s::MultiStepSampler{names, Int}, trace::AbstractTrace, ::Union{Val{:state}, Val{:next_state}}, inds, ns) where {names} - [trace[[idx + i + n - 1 for i in -s.stacksize+1:0, n in 1:ns[j]]] for (j,idx) in enumerate(inds)] -end - -function StatsBase.sample(s::MultiStepSampler{names}, e::EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces}) where {names} + [trace[idx:(idx+ns[i]-1)] for (i, idx) in enumerate(inds)] +end + +function fetch( + s::MultiStepSampler{names,Int}, + trace::AbstractTrace, + ::Union{Val{:state},Val{:next_state}}, + inds, + ns, +) where {names} + [ + trace[[idx + i + n - 1 for i = -s.stacksize+1:0, n = 1:ns[j]]] for + (j, idx) in enumerate(inds) + ] +end + +function StatsBase.sample( + s::MultiStepSampler{names}, + e::EpisodesBuffer{<:Any,<:Any,<:CircularPrioritizedTraces}, +) where {names} t = e.traces p = collect(deepcopy(t.priorities)) w = StatsBase.FrequencyWeights(p) - valids, ns = valid_range(s,e) + valids, ns = valid_range(s, e) w .*= valids[1:length(t)] inds = StatsBase.sample(s.rng, eachindex(w), w, s.batchsize) - merge( - (key=t.keys[inds], priority=p[inds]), - fetch(s, e, Val(names), inds, ns) - ) + merge((key = t.keys[inds], priority = p[inds]), fetch(s, e, Val(names), inds, ns)) end diff --git a/src/traces.jl b/src/traces.jl index 35e1f77..9ac9804 100644 --- a/src/traces.jl +++ b/src/traces.jl @@ -12,7 +12,8 @@ abstract type AbstractTrace{E} <: AbstractVector{E} end Base.convert(::Type{AbstractTrace}, x::AbstractTrace) = x -Base.summary(io::IO, t::AbstractTrace) = print(io, "$(length(t))-element $(nameof(typeof(t)))") +Base.summary(io::IO, t::AbstractTrace) = + print(io, "$(length(t))-element $(nameof(typeof(t)))") ##### @@ -31,7 +32,8 @@ struct Trace{T,E} <: AbstractTrace{E} parent::T end -Base.summary(io::IO, t::Trace{T}) where {T} = print(io, "$(length(t))-element$(length(t) > 0 ? 's' : "") $(nameof(typeof(t))){$T}") +Base.summary(io::IO, t::Trace{T}) where {T} = + print(io, "$(length(t))-element$(length(t) > 0 ? 's' : "") $(nameof(typeof(t))){$T}") function Trace(x::T) where {T<:AbstractArray} E = eltype(x) @@ -46,10 +48,25 @@ Adapt.adapt_structure(to, t::Trace) = Trace(Adapt.adapt_structure(to, t.parent)) Base.convert(::Type{AbstractTrace}, x::AbstractArray) = Trace(x) Base.size(x::Trace) = (size(x.parent, ndims(x.parent)),) -Base.getindex(s::Trace, I) = Base.maybeview(s.parent, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) -Base.setindex!(s::Trace, v, I) = setindex!(s.parent, v, ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))...) - -@forward Trace.parent Base.parent, Base.pushfirst!, Base.push!, Base.append!, Base.prepend!, Base.pop!, Base.popfirst!, Base.empty!, Base.eltype +Base.getindex(s::Trace, I) = Base.maybeview( + s.parent, + ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))..., +) +Base.setindex!(s::Trace, v, I) = setindex!( + s.parent, + v, + ntuple(i -> i == ndims(s.parent) ? I : (:), Val(ndims(s.parent)))..., +) + +@forward Trace.parent Base.parent, +Base.pushfirst!, +Base.push!, +Base.append!, +Base.prepend!, +Base.pop!, +Base.popfirst!, +Base.empty!, +Base.eltype #By default, AbstractTrace have infinity capacity (like a Vector). This method is specialized for #CircularArraySARTSTraces in common.jl. The functions below are made that way to avoid type piracy. @@ -88,7 +105,8 @@ Dedicated for `MultiplexTraces` to avoid scalar indexing when `view(view(t::Mult struct RelativeTrace{left,right,T,E} <: AbstractTrace{E} trace::Trace{T,E} end -RelativeTrace{left,right}(t::Trace{T,E}) where {left,right,T,E} = RelativeTrace{left,right,T,E}(t) +RelativeTrace{left,right}(t::Trace{T,E}) where {left,right,T,E} = + RelativeTrace{left,right,T,E}(t) Base.size(x::RelativeTrace{0,-1}) = (max(0, length(x.trace) - 1),) Base.size(x::RelativeTrace{1,0}) = (max(0, length(x.trace) - 1),) @@ -123,13 +141,18 @@ end function MultiplexTraces{names}(t) where {names} if length(names) != 2 - throw(ArgumentError("MultiplexTraces has exactly two sub traces, got $(length(names)) trace names")) + throw( + ArgumentError( + "MultiplexTraces has exactly two sub traces, got $(length(names)) trace names", + ), + ) end trace = convert(AbstractTrace, t) MultiplexTraces{names,typeof(trace),eltype(trace)}(trace) end -Adapt.adapt_structure(to, t::MultiplexTraces{names}) where {names} = MultiplexTraces{names}(Adapt.adapt_structure(to, t.trace)) +Adapt.adapt_structure(to, t::MultiplexTraces{names}) where {names} = + MultiplexTraces{names}(Adapt.adapt_structure(to, t.trace)) Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} = _getindex(t, Val(k)) @@ -145,8 +168,10 @@ Base.getindex(t::MultiplexTraces{names}, k::Symbol) where {names} = _getindex(t, return :($ex) end -Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = NamedTuple{names}((t.trace[I], t.trace[I+1])) -Base.getindex(t::MultiplexTraces{names}, I::AbstractArray{Int}) where {names} = NamedTuple{names}((t.trace[I], t.trace[I.+1])) +Base.getindex(t::MultiplexTraces{names}, I::Int) where {names} = + NamedTuple{names}((t.trace[I], t.trace[I+1])) +Base.getindex(t::MultiplexTraces{names}, I::AbstractArray{Int}) where {names} = + NamedTuple{names}((t.trace[I], t.trace[I.+1])) Base.size(t::MultiplexTraces) = (max(0, length(t.trace) - 1),) capacity(t::MultiplexTraces) = capacity(t.trace) @@ -154,13 +179,19 @@ capacity(t::MultiplexTraces) = capacity(t.trace) @forward MultiplexTraces.trace Base.parent, Base.pop!, Base.popfirst!, Base.empty! for f in (:push!, :pushfirst!, :append!, :prepend!) - @eval function Base.$f(t::MultiplexTraces{names}, x::NamedTuple{ks,Tuple{Ts}}) where {names,ks,Ts} + @eval function Base.$f( + t::MultiplexTraces{names}, + x::NamedTuple{ks,Tuple{Ts}}, + ) where {names,ks,Ts} k, v = first(ks), first(x) if k in names $f(t.trace, v) end end - @eval function Base.$f(t::MultiplexTraces{names}, x::RelativeTrace{left, right}) where {names, left, right} + @eval function Base.$f( + t::MultiplexTraces{names}, + x::RelativeTrace{left,right}, + ) where {names,left,right} if left == 0 #do not accept appending the second name as it would be appended twice $f(t[first(names)].trace, x.trace) end @@ -201,7 +232,7 @@ end index_ = build_trace_index(names, Trs) # Generate code, i.e. find the correct index for a given key ex = :() - + for name in names if QuoteNode(name) == QuoteNode(k) index_element = index_[k] @@ -240,7 +271,10 @@ function Base.:(+)(t1::Traces{k1,T,N,T1}, t2::AbstractTraces{k2,T2}) where {k1,T Traces{ks,typeof(ts),length(ks),Tuple{T1.types...,T2.types...}}(ts) end -function Base.:(+)(t1::Traces{k1,T1,N1,E1}, t2::Traces{k2,T2,N2,E2}) where {k1,T1,N1,E1,k2,T2,N2,E2} +function Base.:(+)( + t1::Traces{k1,T1,N1,E1}, + t2::Traces{k2,T2,N2,E2}, +) where {k1,T1,N1,E1,k2,T2,N2,E2} ks = (k1..., k2...) ts = (t1.traces..., t2.traces...) Traces{ks,typeof(ts),length(ks),Tuple{E1.types...,E2.types...}}(ts) @@ -250,7 +284,7 @@ Base.size(t::Traces) = (mapreduce(length, min, t.traces),) max_length(t::Traces) = mapreduce(length, max, t.traces) function capacity(t::Traces{names,Trs,N,E}) where {names,Trs,N,E} - minimum(map(idx->capacity(t[idx]), names)) + minimum(map(idx -> capacity(t[idx]), names)) end @generated function Base.push!(ts::Traces, xs::NamedTuple{N,T}) where {N,T} @@ -269,11 +303,15 @@ end return :($ex) end -@generated function Base.pushfirst!(ts::Traces{names,Trs,N,E}, ::Val{k}, v) where {names,Trs,N,E,k} +@generated function Base.pushfirst!( + ts::Traces{names,Trs,N,E}, + ::Val{k}, + v, +) where {names,Trs,N,E,k} index_ = build_trace_index(names, Trs) # Generate code, i.e. find the correct index for a given key ex = :() - + for name in names if QuoteNode(name) == QuoteNode(k) index_element = index_[k] @@ -285,11 +323,15 @@ end return :($ex) end -@generated function Base.push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v) where {names,Trs,N,E,k} +@generated function Base.push!( + ts::Traces{names,Trs,N,E}, + ::Val{k}, + v, +) where {names,Trs,N,E,k} index_ = build_trace_index(names, Trs) # Generate code, i.e. find the correct index for a given key ex = :() - + for name in names if QuoteNode(name) == QuoteNode(k) index_element = index_[k] diff --git a/src/trajectory.jl b/src/trajectory.jl index 2c7a041..9c8fb36 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -26,7 +26,12 @@ Base.@kwdef struct Trajectory{C,S,T,F} controller::T = InsertSampleRatioController() transformer::F = identity - function Trajectory(c::C, s::S, t::T=InsertSampleRatioController(), f=identity) where {C,S,T} + function Trajectory( + c::C, + s::S, + t::T = InsertSampleRatioController(), + f = identity, + ) where {C,S,T} if c isa EpisodesBuffer new{C,S,T,typeof(f)}(c, s, t, f) else @@ -35,7 +40,12 @@ Base.@kwdef struct Trajectory{C,S,T,F} end end - function Trajectory(container::C, sampler::S, controller::T, transformer) where {C,S,T<:AsyncInsertSampleRatioController} + function Trajectory( + container::C, + sampler::S, + controller::T, + transformer, + ) where {C,S,T<:AsyncInsertSampleRatioController} t = Threads.@spawn while true for msg in controller.ch_in if msg.f === Base.push! @@ -51,7 +61,8 @@ Base.@kwdef struct Trajectory{C,S,T,F} end if controller.n_inserted >= controller.threshold - if controller.n_sampled <= (controller.n_inserted - controller.threshold) * controller.ratio + if controller.n_sampled <= + (controller.n_inserted - controller.threshold) * controller.ratio batch = StatsBase.sample(sampler, container) put!(controller.ch_out, batch) controller.n_sampled += 1 @@ -62,13 +73,14 @@ Base.@kwdef struct Trajectory{C,S,T,F} bind(controller.ch_in, t) bind(controller.ch_out, t) - + new{C,S,T,typeof(transformer)}(container, sampler, controller, transformer) end end TrajectoryStyle(::Trajectory) = SyncTrajectoryStyle() -TrajectoryStyle(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = AsyncTrajectoryStyle() +TrajectoryStyle(::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = + AsyncTrajectoryStyle() Base.bind(::Trajectory, ::Task) = nothing @@ -89,9 +101,12 @@ struct CallMsg kw::Any end -Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.push!, (x,), NamedTuple())) -Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = put!(t.controller.ch_in, CallMsg(Base.append!, (x,), NamedTuple())) -Base.setindex!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, v, I...) = put!(t.controller.ch_in, CallMsg(Base.setindex!, (v, I...), NamedTuple())) +Base.push!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = + put!(t.controller.ch_in, CallMsg(Base.push!, (x,), NamedTuple())) +Base.append!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, x) = + put!(t.controller.ch_in, CallMsg(Base.append!, (x,), NamedTuple())) +Base.setindex!(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, v, I...) = + put!(t.controller.ch_in, CallMsg(Base.setindex!, (v, I...), NamedTuple())) function Base.append!(t::Trajectory, x) append!(t.container, x) @@ -126,7 +141,8 @@ StatsBase.sample(t::Trajectory) = StatsBase.sample(t.sampler, t.container) Keep sampling batches from the trajectory until the trajectory is not ready to be sampled yet due to the `controller`. """ -iter(t::Trajectory) = Iterators.takewhile(_ -> on_sample!(t), Iterators.cycle(SampleGenerator(t))) +iter(t::Trajectory) = + Iterators.takewhile(_ -> on_sample!(t), Iterators.cycle(SampleGenerator(t))) #The use of iterate(::SampleGenerator) has been suspended in v0.1.8 due to a significant drop in performance. function Base.iterate(t::Trajectory, args...) @@ -135,11 +151,13 @@ function Base.iterate(t::Trajectory, args...) else nothing end -end +end Base.IteratorSize(t::Trajectory) = Base.IteratorSize(iter(t)) -Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...) = iterate(t.controller.ch_out, args...) -Base.IteratorSize(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = Base.IteratorSize(t.controller.ch_out) +Base.iterate(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}, args...) = + iterate(t.controller.ch_out, args...) +Base.IteratorSize(t::Trajectory{<:Any,<:Any,<:AsyncInsertSampleRatioController}) = + Base.IteratorSize(t.controller.ch_out) Base.keys(t::Trajectory) = keys(t.container) Base.haskey(t::Trajectory, k) = k in keys(t) diff --git a/test/common.jl b/test/common.jl index 117d6ad..ae0c4d0 100644 --- a/test/common.jl +++ b/test/common.jl @@ -1,14 +1,14 @@ @testset "sum_tree" begin t = SumTree(8) - for i in 1:4 + for i = 1:4 push!(t, i) end @test length(t) == 4 @test size(t) == (4,) - for i in 5:16 + for i = 5:16 push!(t, i) end @@ -25,76 +25,77 @@ end @testset "CircularArraySARTSATraces" begin - t = CircularArraySARTSATraces(; - capacity=3, - state=Float32 => (2, 3), - action=Float32 => (2,), - reward=Float32 => (), - terminal=Bool => () - ) |> gpu + t = + CircularArraySARTSATraces(; + capacity = 3, + state = Float32 => (2, 3), + action = Float32 => (2,), + reward = Float32 => (), + terminal = Bool => (), + ) |> gpu @test t isa CircularArraySARTSATraces @test ReinforcementLearningTrajectories.capacity(t) == 3 @test CircularArrayBuffers.capacity(t) == 3 - push!(t, (state=ones(Float32, 2, 3),) |> gpu) - push!(t, (action=ones(Float32, 2), next_state=ones(Float32, 2, 3) * 2) |> gpu) + push!(t, (state = ones(Float32, 2, 3),) |> gpu) + push!(t, (action = ones(Float32, 2), next_state = ones(Float32, 2, 3) * 2) |> gpu) @test length(t) == 0 - push!(t, (reward=1.0f0, terminal=false) |> gpu) + push!(t, (reward = 1.0f0, terminal = false) |> gpu) @test length(t) == 0 # next_action is still missing - push!(t, (action=ones(Float32, 2) * 2,) |> gpu) + push!(t, (action = ones(Float32, 2) * 2,) |> gpu) @test length(t) == 1 - push!(t, (state=ones(Float32, 2, 3) * 3,) |> gpu) + push!(t, (state = ones(Float32, 2, 3) * 3,) |> gpu) @test length(t) == 1 # this will trigger the scalar indexing of CuArray CUDA.@allowscalar @test t[1] == ( - state=ones(Float32, 2, 3), - next_state=ones(Float32, 2, 3) * 2, - action=ones(Float32, 2), - next_action=ones(Float32, 2) * 2, - reward=1.0f0, - terminal=false, + state = ones(Float32, 2, 3), + next_state = ones(Float32, 2, 3) * 2, + action = ones(Float32, 2), + next_action = ones(Float32, 2) * 2, + reward = 1.0f0, + terminal = false, ) - push!(t, (reward=2.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 4, action=ones(Float32, 2) * 3) |> gpu) + push!(t, (reward = 2.0f0, terminal = false)) + push!(t, (state = ones(Float32, 2, 3) * 4, action = ones(Float32, 2) * 3) |> gpu) @test length(t) == 2 - push!(t, (reward=3.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 5, action=ones(Float32, 2) * 4) |> gpu) + push!(t, (reward = 3.0f0, terminal = false)) + push!(t, (state = ones(Float32, 2, 3) * 5, action = ones(Float32, 2) * 4) |> gpu) @test length(t) == 3 - push!(t, (reward=4.0f0, terminal=false)) - push!(t, (state=ones(Float32, 2, 3) * 6, action=ones(Float32, 2) * 5) |> gpu) - push!(t, (reward=5.0f0, terminal=false)) + push!(t, (reward = 4.0f0, terminal = false)) + push!(t, (state = ones(Float32, 2, 3) * 6, action = ones(Float32, 2) * 5) |> gpu) + push!(t, (reward = 5.0f0, terminal = false)) @test length(t) == 3 - push!(t, (action=ones(Float32, 2) * 6,) |> gpu) + push!(t, (action = ones(Float32, 2) * 6,) |> gpu) @test length(t) == 3 # this will trigger the scalar indexing of CuArray CUDA.@allowscalar @test t[1] == ( - state=ones(Float32, 2, 3) * 3, - next_state=ones(Float32, 2, 3) * 4, - action=ones(Float32, 2) * 3, - next_action=ones(Float32, 2) * 4, - reward=3.0f0, - terminal=false, + state = ones(Float32, 2, 3) * 3, + next_state = ones(Float32, 2, 3) * 4, + action = ones(Float32, 2) * 3, + next_action = ones(Float32, 2) * 4, + reward = 3.0f0, + terminal = false, ) CUDA.@allowscalar @test t[end] == ( - state=ones(Float32, 2, 3) * 5, - next_state=ones(Float32, 2, 3) * 6, - action=ones(Float32, 2) * 5, - next_action=ones(Float32, 2) * 6, - reward=5.0f0, - terminal=false, + state = ones(Float32, 2, 3) * 5, + next_state = ones(Float32, 2, 3) * 6, + action = ones(Float32, 2) * 5, + next_action = ones(Float32, 2) * 6, + reward = 5.0f0, + terminal = false, ) batch = t[1:3] @@ -107,16 +108,19 @@ end @testset "ElasticArraySARTSTraces" begin t = ElasticArraySARTSTraces(; - state=Float32 => (2, 3), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + state = Float32 => (2, 3), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) @test t isa ElasticArraySARTSTraces - push!(t, (state=ones(Float32, 2, 3), action=1)) - push!(t, (reward=1.0f0, terminal=false, state=ones(Float32, 2, 3) * 2, action=2)) + push!(t, (state = ones(Float32, 2, 3), action = 1)) + push!( + t, + (reward = 1.0f0, terminal = false, state = ones(Float32, 2, 3) * 2, action = 2), + ) @test length(t) == 1 @@ -127,12 +131,12 @@ end @testset "CircularArraySLARTTraces" begin t = CircularArraySLARTTraces(; - capacity=3, - state=Float32 => (2, 3), - legal_actions_mask=Bool => (5,), - action=Int => (), - reward=Float32 => (), - terminal=Bool => () + capacity = 3, + state = Float32 => (2, 3), + legal_actions_mask = Bool => (5,), + action = Int => (), + reward = Float32 => (), + terminal = Bool => (), ) @test t isa CircularArraySLARTTraces @@ -142,17 +146,15 @@ end @testset "CircularPrioritizedTraces-SARTS" begin t = CircularPrioritizedTraces( - CircularArraySARTSTraces(; - capacity=3 - ), - default_priority=1.0f0 + CircularArraySARTSTraces(; capacity = 3), + default_priority = 1.0f0, ) @test ReinforcementLearningTrajectories.capacity(t) == 3 - push!(t, (state=0, action=0)) + push!(t, (state = 0, action = 0)) - for i in 1:5 - push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) + for i = 1:5 + push!(t, (reward = 1.0f0, terminal = false, state = i, action = i)) end @test length(t) == 3 @@ -174,46 +176,42 @@ end #EpisodesBuffer t = CircularPrioritizedTraces( - CircularArraySARTSTraces(; - capacity=10 - ), - default_priority=1.0f0 + CircularArraySARTSTraces(; capacity = 10), + default_priority = 1.0f0, ) eb = EpisodesBuffer(t) push!(eb, (state = 1, action = 1)) for i = 1:5 - push!(eb, (state = i+1, action = i+1, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i + 1, reward = i, terminal = false)) end push!(eb, (state = 7, action = 7)) - for (j,i) = enumerate(8:11) - push!(eb, (state = i, action = i, reward = i-1, terminal = false)) + for (j, i) in enumerate(8:11) + push!(eb, (state = i, action = i, reward = i - 1, terminal = false)) end s = BatchSampler(1000) b = sample(s, eb) cm = counter(b[:state]) @test !haskey(cm, 6) @test !haskey(cm, 11) - @test all(in(keys(cm)), [1:5;7:10]) + @test all(in(keys(cm)), [1:5; 7:10]) eb[:priority, [1, 2]] = [0, 0] - @test eb[:priority] == [zeros(2);ones(8)] + @test eb[:priority] == [zeros(2); ones(8)] end @testset "CircularPrioritizedTraces-SARTSA" begin t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; - capacity=3 - ), - default_priority=1.0f0 + CircularArraySARTSATraces(; capacity = 3), + default_priority = 1.0f0, ) @test ReinforcementLearningTrajectories.capacity(t) == 3 - push!(t, (state=0, action=0)) + push!(t, (state = 0, action = 0)) - for i in 1:5 - push!(t, (reward=1.0f0, terminal=false, state=i, action=i)) + for i = 1:5 + push!(t, (reward = 1.0f0, terminal = false, state = i, action = i)) end @test length(t) == 3 @@ -237,28 +235,26 @@ end #EpisodesBuffer t = CircularPrioritizedTraces( - CircularArraySARTSATraces(; - capacity=10 - ), - default_priority=1.0f0 + CircularArraySARTSATraces(; capacity = 10), + default_priority = 1.0f0, ) eb = EpisodesBuffer(t) push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action = i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) end push!(eb, PartialNamedTuple((action = 6,))) push!(eb, (state = 7,)) for i = 8:11 - push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i - 1, reward = i - 1, terminal = false)) end - push!(eb, PartialNamedTuple((action=11,))) + push!(eb, PartialNamedTuple((action = 11,))) s = BatchSampler(1000) b = sample(s, eb) cm = counter(b[:state]) @test !haskey(cm, 6) @test !haskey(cm, 11) - @test all(in(keys(cm)), [1:5;7:10]) + @test all(in(keys(cm)), [1:5; 7:10]) end diff --git a/test/controllers.jl b/test/controllers.jl index 207a7bd..0e9bdcd 100644 --- a/test/controllers.jl +++ b/test/controllers.jl @@ -2,11 +2,11 @@ import ReinforcementLearningTrajectories: on_insert!, on_sample! @testset "controllers.jl" begin @testset "EpisodeSampleRatioController" begin #push - c = EpisodeSampleRatioController(ratio = 1/2, threshold = 5) - for st in 1:50 - transition = (state = 1, action = 2, reward = 5., terminal = (st % 5 == 0)) + c = EpisodeSampleRatioController(ratio = 1 / 2, threshold = 5) + for st = 1:50 + transition = (state = 1, action = 2, reward = 5.0, terminal = (st % 5 == 0)) on_insert!(c, 1, transition) - if st in 25:10:45 + if st in 25:10:45 @test on_sample!(c) @test !on_sample!(c) else @@ -14,9 +14,14 @@ import ReinforcementLearningTrajectories: on_insert!, on_sample! end end #append - c = EpisodeSampleRatioController(ratio = 1/2, threshold = 5) - for e in 1:20 - transitions = (state = ones(5), action = ones(5), reward = ones(5), terminal = [false, false, false, false, iseven(e)]) + c = EpisodeSampleRatioController(ratio = 1 / 2, threshold = 5) + for e = 1:20 + transitions = ( + state = ones(5), + action = ones(5), + reward = ones(5), + terminal = [false, false, false, false, iseven(e)], + ) on_insert!(c, length(first(transitions)), transitions) if e in 10:4:20 @test on_sample!(c) @@ -25,9 +30,25 @@ import ReinforcementLearningTrajectories: on_insert!, on_sample! @test !on_sample!(c) end end - c = EpisodeSampleRatioController(ratio = 1/4, threshold = 5) - for e in 1:10 - transitions = (state = ones(10), action = ones(10), reward = ones(10), terminal = [false, false, false, false, true, false, false, false, false, true]) + c = EpisodeSampleRatioController(ratio = 1 / 4, threshold = 5) + for e = 1:10 + transitions = ( + state = ones(10), + action = ones(10), + reward = ones(10), + terminal = [ + false, + false, + false, + false, + true, + false, + false, + false, + false, + true, + ], + ) on_insert!(c, length(first(transitions)), transitions) if e in 3:2:10 @test on_sample!(c) @@ -37,4 +58,4 @@ import ReinforcementLearningTrajectories: on_insert!, on_sample! end end end -end \ No newline at end of file +end diff --git a/test/episodes.jl b/test/episodes.jl index 2633307..6ea6126 100644 --- a/test/episodes.jl +++ b/test/episodes.jl @@ -4,10 +4,7 @@ using Test @testset "EpisodesBuffer" begin @testset "with circular SARTS traces" begin - eb = EpisodesBuffer( - CircularArraySARTSTraces(; - capacity=10) - ) + eb = EpisodesBuffer(CircularArraySARTSTraces(; capacity = 10)) # push first episode (five steps) push!(eb, (state = 1,)) @@ -15,14 +12,15 @@ using Test @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 for i = 1:5 - push!(eb, (state = i+1, action = i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 - @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + @test eb.episodes_lengths[end-i:end] == fill(i, i + 1) end - @test eb[end] == (state = 5, next_state = 6, action = 5, reward = 5, terminal = false) - @test eb.sampleable_inds == [1,1,1,1,1,0] + @test eb[end] == + (state = 5, next_state = 6, action = 5, reward = 5, terminal = false) + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0] @test length(eb.traces) == 5 # start second episode @@ -31,38 +29,40 @@ using Test @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 0] @test eb[:reward][6] == 0 # 6 is not a valid index, filled with dummy value zero @test_throws BoundsError eb[6] # 6 is not a valid index @test_throws BoundsError eb[7] # 7 is not a valid index # push four steps of second episode ep2_len = 0 - for (i,s) = enumerate(8:11) + for (i, s) in enumerate(8:11) ep2_len += 1 - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(ep2_len, ep2_len + 1) end - @test eb[end] == (state = 10, next_state = 11, action = 10, reward = 10, terminal = false) - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb[end] == + (state = 10, next_state = 11, action = 10, reward = 10, terminal = false) + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0] @test length(eb) == 10 # push two more steps of second episode, which replace the oldest steps in the buffer - for (i, s) = enumerate(12:13) + for (i, s) in enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end - @test eb[end] == (state = 12, next_state = 13, action = 12, reward = 12, terminal = false) - @test eb.sampleable_inds == [1,1,1,0,1,1,1,1,1,1,0] + @test eb[end] == + (state = 12, next_state = 13, action = 12, reward = 12, terminal = false) + @test eb.sampleable_inds == [1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0] # verify episode 2 - for (i,s) in enumerate(3:13) + for (i, s) in enumerate(3:13) if i in (4, 11) @test eb.sampleable_inds[i] == 0 continue @@ -75,52 +75,62 @@ using Test end # push third episode - push!(eb, (state = 14, )) + push!(eb, (state = 14,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 #push until it reaches it own start - for (i,s) in enumerate(15:26) - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + for (i, s) in enumerate(15:26) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) end @test eb.sampleable_inds == [fill(true, 10); [false]] @test eb.episodes_lengths == fill(length(15:26), 11) @test eb.step_numbers == [3:13;] popfirst!(eb) - @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 + @test length(eb) == + length(eb.sampleable_inds) - 1 == + length(eb.step_numbers) - 1 == + length(eb.episodes_lengths) - 1 == + 9 @test first(eb.step_numbers) == 4 pop!(eb) - @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8 + @test length(eb) == + length(eb.sampleable_inds) - 1 == + length(eb.step_numbers) - 1 == + length(eb.episodes_lengths) - 1 == + 8 @test last(eb.step_numbers) == 12 @test size(eb) == size(eb.traces) == (8,) empty!(eb) - @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) + @test size(eb) == + (0,) == + size(eb.traces) == + size(eb.sampleable_inds) == + size(eb.episodes_lengths) == + size(eb.step_numbers) end @testset "with SARTSA traces and PartialNamedTuple" begin - eb = EpisodesBuffer( - CircularArraySARTSATraces(; - capacity=10) - ) + eb = EpisodesBuffer(CircularArraySARTSATraces(; capacity = 10)) # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if length(eb) >= 1 @test eb.sampleable_inds[end-2] == 1 end @test eb.step_numbers[end] == i + 1 - @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + @test eb.episodes_lengths[end-i:end] == fill(i, i + 1) end - @test eb.sampleable_inds == [1,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 0, 0] push!(eb, PartialNamedTuple((action = 6,))) - @test eb.sampleable_inds == [1,1,1,1,1,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0] @test length(eb.traces) == 5 # start second episode @@ -129,15 +139,15 @@ using Test @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 0] @test eb[5][:next_action] == eb[:next_action][5] == 6 @test eb[:reward][6] == 0 # 6 is not a valid index, the reward there is dummy, filled as zero @test_throws BoundsError eb[6] # 6 is not a valid index ep2_len = 0 # push four steps of second episode - for (i,s) = enumerate(8:11) + for (i, s) in enumerate(8:11) ep2_len += 1 - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @@ -146,7 +156,7 @@ using Test @test eb.step_numbers[end] == i + 1 @test eb.episodes_lengths[end-i:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0] @test length(eb.traces) == 9 # an action is missing at this stage @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @@ -155,9 +165,9 @@ using Test end # push two more steps of second episode, which replace the oldest steps in the buffer - for (i, s) = enumerate(12:13) + for (i, s) in enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @@ -170,7 +180,7 @@ using Test @test length(eb.traces) == 10 # verify episode 2 - for (i,s) in enumerate(3:13) + for (i, s) in enumerate(3:13) if i in (4, 11) @test eb.sampleable_inds[i] == 0 continue @@ -189,32 +199,41 @@ using Test @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 # push until it reaches it own start - for (i,s) in enumerate(15:26) - push!(eb, (state = s, action = s-1, reward = s-1, terminal = false)) + for (i, s) in enumerate(15:26) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) end push!(eb, PartialNamedTuple((action = 26,))) @test eb.sampleable_inds == [fill(true, 10); [false]] @test eb.episodes_lengths == fill(length(15:26), 11) @test eb.step_numbers == [3:13;] step = popfirst!(eb) - @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 9 + @test length(eb) == + length(eb.sampleable_inds) - 1 == + length(eb.step_numbers) - 1 == + length(eb.episodes_lengths) - 1 == + 9 @test first(eb.step_numbers) == 4 step = pop!(eb) - @test length(eb) == length(eb.sampleable_inds) - 1 == length(eb.step_numbers) - 1 == length(eb.episodes_lengths) - 1 == 8 + @test length(eb) == + length(eb.sampleable_inds) - 1 == + length(eb.step_numbers) - 1 == + length(eb.episodes_lengths) - 1 == + 8 @test last(eb.step_numbers) == 12 @test size(eb) == size(eb.traces) == (8,) empty!(eb) - @test size(eb) == (0,) == size(eb.traces) == size(eb.sampleable_inds) == size(eb.episodes_lengths) == size(eb.step_numbers) + @test size(eb) == + (0,) == + size(eb.traces) == + size(eb.sampleable_inds) == + size(eb.episodes_lengths) == + size(eb.step_numbers) end @testset "with vector traces" begin - eb = EpisodesBuffer( - Traces(; - state=Int[], - reward=Int[]) - ) + eb = EpisodesBuffer(Traces(; state = Int[], reward = Int[])) push!(eb, (state = 1,)) # partial inserting for i = 1:15 - push!(eb, (state = i+1, reward =i)) + push!(eb, (state = i + 1, reward = i)) end @test length(eb.traces) == 15 @test eb.sampleable_inds == [fill(true, 15); [false]] @@ -222,31 +241,29 @@ using Test @test eb.step_numbers == [1:16;] push!(eb, (state = 1,)) # partial inserting for i = 1:15 - push!(eb, (state = i+1, reward =i)) + push!(eb, (state = i + 1, reward = i)) end - @test eb.sampleable_inds == [fill(true, 15); [false];fill(true, 15); [false]] + @test eb.sampleable_inds == [fill(true, 15); [false]; fill(true, 15); [false]] @test all(==(15), eb.episodes_lengths) - @test eb.step_numbers == [1:16;1:16] + @test eb.step_numbers == [1:16; 1:16] @test length(eb) == 31 end @testset "with ElasticArraySARTSTraces" begin - eb = EpisodesBuffer( - ElasticArraySARTSTraces() - ) + eb = EpisodesBuffer(ElasticArraySARTSTraces()) # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 - @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + @test eb.episodes_lengths[end-i:end] == fill(i, i + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0] @test length(eb.traces) == 5 # start second episode @@ -255,33 +272,33 @@ using Test @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 0] @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero @test_throws BoundsError eb[6] #6 is not a valid index ep2_len = 0 # push four steps of second episode - for (j,i) = enumerate(8:11) + for (j, i) in enumerate(8:11) ep2_len += 1 - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i - 1, reward = i - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,1,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0] @test length(eb.traces) == 10 # push two more steps of second episode, which replace the oldest steps in the buffer - for (i, s) = enumerate(12:13) + for (i, s) in enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 1 @test eb.step_numbers[end] == i + 1 + 4 @test eb.episodes_lengths[end-ep2_len:end] == fill(ep2_len, ep2_len + 1) end # verify episode 2 - for i in 3:13 + for i = 3:13 if i in (6, 13) @test eb.sampleable_inds[i] == 0 continue @@ -294,15 +311,15 @@ using Test end # push third episode - push!(eb, (state = 14, )) + push!(eb, (state = 14,)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 # push until it reaches it own start - for (i,s) in enumerate(15:26) - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + for (i, s) in enumerate(15:26) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) end @test eb.sampleable_inds[end-5:end] == [fill(true, 5); [false]] @test eb.episodes_lengths[end-10:end] == fill(length(15:26), 11) @@ -321,26 +338,24 @@ using Test end @testset "ElasticArraySARTSATraces with PartialNamedTuple" begin - eb = EpisodesBuffer( - ElasticArraySARTSATraces() - ) + eb = EpisodesBuffer(ElasticArraySARTSATraces()) # push first episode (five steps) push!(eb, (state = 1,)) @test eb.sampleable_inds[end] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 for i = 1:5 - push!(eb, (state = i+1, action =i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @test eb.sampleable_inds[end-2] == 1 end @test eb.step_numbers[end] == i + 1 - @test eb.episodes_lengths[end-i:end] == fill(i, i+1) + @test eb.episodes_lengths[end-i:end] == fill(i, i + 1) end push!(eb, PartialNamedTuple((action = 6,))) - @test eb.sampleable_inds == [1,1,1,1,1,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0] @test length(eb.traces) == 5 # start second episode @@ -349,15 +364,15 @@ using Test @test eb.sampleable_inds[end-1] == 0 @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 - @test eb.sampleable_inds == [1,1,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 0] @test eb[:action][6] == 6 @test eb[:next_action][5] == 6 @test eb[:reward][6] == 0 #6 is not a valid index, the reward there is dummy, filled as zero ep2_len = 0 # push four steps of second episode - for (j,i) = enumerate(8:11) + for (j, i) in enumerate(8:11) ep2_len += 1 - push!(eb, (state = i, action =i-1, reward = i-1, terminal = false)) + push!(eb, (state = i, action = i - 1, reward = i - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @@ -366,12 +381,12 @@ using Test @test eb.step_numbers[end] == j + 1 @test eb.episodes_lengths[end-j:end] == fill(ep2_len, ep2_len + 1) end - @test eb.sampleable_inds == [1,1,1,1,1,0,1,1,1,0,0] + @test eb.sampleable_inds == [1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0] @test length(eb.traces) == 9 # an action is missing at this stage # push two more steps of second episode, which replace the oldest steps in the buffer - for (i, s) = enumerate(12:13) + for (i, s) in enumerate(12:13) ep2_len += 1 - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) @test eb.sampleable_inds[end] == 0 @test eb.sampleable_inds[end-1] == 0 if eb.step_numbers[end] > 2 @@ -384,7 +399,7 @@ using Test @test length(eb.traces) == 12 # verify episode 2 - for i in 1:13 + for i = 1:13 if i in (6, 13) @test eb.sampleable_inds[i] == 0 continue @@ -403,8 +418,8 @@ using Test @test eb.episodes_lengths[end] == 0 @test eb.step_numbers[end] == 1 #push until it reaches it own start - for (i,s) in enumerate(15:26) - push!(eb, (state = s, action =s-1, reward = s-1, terminal = false)) + for (i, s) in enumerate(15:26) + push!(eb, (state = s, action = s - 1, reward = s - 1, terminal = false)) end push!(eb, PartialNamedTuple((action = 26,))) @test eb.sampleable_inds[end-10:end] == [fill(true, 10); [false]] diff --git a/test/normalization.jl b/test/normalization.jl index ae6f6f9..0305974 100644 --- a/test/normalization.jl +++ b/test/normalization.jl @@ -8,13 +8,13 @@ import OnlineStats: mean, std traj = Trajectory( container = nt, sampler = BatchSampler(1000), - controller = InsertSampleRatioController(ratio = Inf, threshold = 0) + controller = InsertSampleRatioController(ratio = Inf, threshold = 0), ) m = mean(0:4) s = std(0:4) - ss = std([0,1,2,2,3,4]) + ss = std([0, 1, 2, 2, 3, 4]) push!(traj, (state = fill(m, 5), action = 1)) #this also updates state moments - for i in 0:4 + for i = 0:4 r = ((1.0:5.0) .+ i) .% 5 push!(traj, (state = [r;], action = 1, reward = Float32(i), terminal = false)) end @@ -28,13 +28,13 @@ import OnlineStats: mean, std @test unnormalized_batch[:reward] == [0:4;] @test extrema(unnormalized_batch[:state]) == (0, 4) normalized_batch = nt[[1:5;]] - + normalized_batch = sample(traj) - @test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m)./ss) - @test all(extrema(normalized_batch[:next_state]) .≈ ((0, 4) .- m)./ss) - @test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m)./s) + @test all(extrema(normalized_batch[:state]) .≈ ((0, 4) .- m) ./ ss) + @test all(extrema(normalized_batch[:next_state]) .≈ ((0, 4) .- m) ./ ss) + @test all(extrema(normalized_batch[:reward]) .≈ ((0, 4) .- m) ./ s) #check for no mutation unnormalized_batch = t[[1:5;]] @test unnormalized_batch[:reward] == [0:4;] - @test extrema(unnormalized_batch[:state]) == (0, 4) -end \ No newline at end of file + @test extrema(unnormalized_batch[:state]) == (0, 4) +end diff --git a/test/samplers.jl b/test/samplers.jl index dc931a2..a679ba3 100644 --- a/test/samplers.jl +++ b/test/samplers.jl @@ -3,10 +3,7 @@ import ReinforcementLearningTrajectories.fetch @testset "BatchSampler" begin sz = 32 s = BatchSampler(sz) - t = Traces( - state=rand(3, 4, 5), - action=rand(1:4, 5), - ) + t = Traces(state = rand(3, 4, 5), action = rand(1:4, 5)) b = sample(s, t) @@ -15,15 +12,15 @@ import ReinforcementLearningTrajectories.fetch @test size(b.action) == (sz,) #In EpisodesBuffer - eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10)) + eb = EpisodesBuffer(CircularArraySARTSATraces(capacity = 10)) push!(eb, (state = 1,)) for i = 1:5 - push!(eb, (state = i+1, action = i, reward = i, terminal = false)) + push!(eb, (state = i + 1, action = i, reward = i, terminal = false)) end push!(eb, PartialNamedTuple((next_action = 6,))) push!(eb, (state = 7,)) - for (j,i) = enumerate(8:11) - push!(eb, (state = i, action = i-1, reward = i-1, terminal = false)) + for (j, i) in enumerate(8:11) + push!(eb, (state = i, action = i - 1, reward = i - 1, terminal = false)) end push!(eb, PartialNamedTuple((next_action = 11,))) @@ -32,20 +29,17 @@ import ReinforcementLearningTrajectories.fetch cm = counter(b[:state]) @test !haskey(cm, 6) @test !haskey(cm, 11) - @test all(in(keys(cm)), [1:5;7:10]) + @test all(in(keys(cm)), [1:5; 7:10]) end @testset "MetaSampler" begin t = Trajectory( - container=Traces( - a=Int[], - b=Bool[] - ), - sampler=MetaSampler(policy=BatchSampler(3), critic=BatchSampler(5)), + container = Traces(a = Int[], b = Bool[]), + sampler = MetaSampler(policy = BatchSampler(3), critic = BatchSampler(5)), ) push!(t, (a = 1,)) - for i in 1:10 - push!(t, (a=i, b=true)) + for i = 1:10 + push!(t, (a = i, b = true)) end batches = collect(t) @@ -56,16 +50,16 @@ import ReinforcementLearningTrajectories.fetch @testset "MultiBatchSampler" begin t = Trajectory( - container=Traces( - a=Int[], - b=Bool[] + container = Traces(a = Int[], b = Bool[]), + sampler = MetaSampler( + policy = BatchSampler(3), + critic = MultiBatchSampler(BatchSampler(5), 2), ), - sampler=MetaSampler(policy=BatchSampler(3), critic=MultiBatchSampler(BatchSampler(5), 2)), ) push!(t, (a = 1,)) - for i in 1:10 - push!(t, (a=i, b=true)) + for i = 1:10 + push!(t, (a = i, b = true)) end batches = collect(t) @@ -251,4 +245,4 @@ import ReinforcementLearningTrajectories.fetch terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds]) @test terminals == [[a == 5 ? 1 : 0 for a in acs] for acs in actions] end -end \ No newline at end of file +end diff --git a/test/sum_tree.jl b/test/sum_tree.jl index 8bb9d0f..f2d231c 100644 --- a/test/sum_tree.jl +++ b/test/sum_tree.jl @@ -1,15 +1,15 @@ -function gen_rand_sumtree(n, seed, type::DataType=Float32) +function gen_rand_sumtree(n, seed, type::DataType = Float32) rng = StableRNG(seed) a = SumTree(type, n) append!(a, rand(rng, type, n)) return a end -function gen_sumtree_with_zeros(n, seed, type::DataType=Float32) +function gen_sumtree_with_zeros(n, seed, type::DataType = Float32) a = gen_rand_sumtree(n, seed, type) b = rand(StableRNG(seed), Bool, n) return copy_multiply(a, b) -end +end function copy_multiply(stree, m) new_tree = deepcopy(stree) @@ -17,34 +17,55 @@ function copy_multiply(stree, m) return new_tree end -function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1) +function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters = 1) for _ in iters (_, p) = rand(rng, t) p == 0 && return false end return true end -sumtree_nozero(n::Integer, seed::Integer, iters=1) = sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters) -sumtree_nozero(n, seeds::AbstractVector, iters=1) = all(sumtree_nozero(n, seed, iters) for seed in seeds) +sumtree_nozero(n::Integer, seed::Integer, iters = 1) = + sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters) +sumtree_nozero(n, seeds::AbstractVector, iters = 1) = + all(sumtree_nozero(n, seed, iters) for seed in seeds) -function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length) +function sumtree_distribution!( + indices, + priorities, + t::SumTree, + rng::AbstractRNG, + iters = 1000 * t.length, +) for i = 1:iters indices[i], priorities[i] = rand(rng, t) end imap = countmap(indices) - est_pdf = Dict(k=>v/length(indices) for (k, v) in imap) - ex_pdf = Dict(k=>v/t.tree[1] for (k, v) in Dict(1:length(t) .=> t)) + est_pdf = Dict(k => v / length(indices) for (k, v) in imap) + ex_pdf = Dict(k => v / t.tree[1] for (k, v) in Dict(1:length(t) .=> t)) abserrs = [est_pdf[k] - ex_pdf[k] for k in keys(est_pdf)] return abserrs end -sumtree_distribution!(indices, priorities, n, seed, iters=1000*n) = sumtree_distribution!(indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters) -function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n) +sumtree_distribution!(indices, priorities, n, seed, iters = 1000 * n) = + sumtree_distribution!( + indices, + priorities, + gen_rand_sumtree(n, seed), + StableRNG(seed), + iters, + ) +function sumtree_distribution(n, seeds::AbstractVector, iters = 1000 * n) p = [zeros(Float32, iters) for _ = 1:Threads.nthreads()] i = [zeros(Float32, iters) for _ = 1:Threads.nthreads()] results = Vector{Vector{Float64}}(undef, length(seeds)) Threads.@threads for ix = 1:length(seeds) - results[ix] = sumtree_distribution!(i[Threads.threadid()], p[Threads.threadid()], gen_rand_sumtree(n, seeds[ix]), StableRNG(seeds[ix]), iters) + results[ix] = sumtree_distribution!( + i[Threads.threadid()], + p[Threads.threadid()], + gen_rand_sumtree(n, seeds[ix]), + StableRNG(seeds[ix]), + iters, + ) end return results end @@ -52,12 +73,14 @@ end @testset "SumTree" begin n = 1024 seeds = 1:100 - nozero_iters=1024 - distr_iters=1024*10_000 + nozero_iters = 1024 + distr_iters = 1024 * 10_000 abstol = 0.05 - maxerr=0.01 + maxerr = 0.01 @test sumtree_nozero(n, seeds, nozero_iters) - @test all(x->all(x .< maxerr) && sum(abs2, x) < abstol, - sumtree_distribution(n, seeds, distr_iters)) + @test all( + x -> all(x .< maxerr) && sum(abs2, x) < abstol, + sumtree_distribution(n, seeds, distr_iters), + ) end diff --git a/test/traces.jl b/test/traces.jl index 26ba16c..82069aa 100644 --- a/test/traces.jl +++ b/test/traces.jl @@ -1,21 +1,18 @@ @testset "Traces" begin - t = Traces(; - a=[1, 2], - b=Bool[0, 1] - ) + t = Traces(; a = [1, 2], b = Bool[0, 1]) @test length(t) == 2 - push!(t, (; a=3, b=true)) + push!(t, (; a = 3, b = true)) @test t[:a][end] == 3 @test t[:b][end] == true - append!(t, Traces(a=[4, 5], b=[false, false])) + append!(t, Traces(a = [4, 5], b = [false, false])) @test length(t[:a]) == 5 @test t[:b][end-1:end] == [false, false] - @test t[1] == (a=1, b=false) + @test t[1] == (a = 1, b = false) t_12 = t[1:2] @test t_12.a == [1, 2] @@ -40,43 +37,43 @@ end @test length(t) == 0 - push!(t, (; state=1)) - push!(t, (; next_state=2)) + push!(t, (; state = 1)) + push!(t, (; next_state = 2)) @test t[:state] == [1] @test t[:next_state] == [2] - @test t[1] == (state=1, next_state=2) + @test t[1] == (state = 1, next_state = 2) - append!(t, (; state=[3, 4])) + append!(t, (; state = [3, 4])) @test t[:state] == [1, 2, 3] @test t[:next_state] == [2, 3, 4] - @test t[end] == (state=3, next_state=4) + @test t[end] == (state = 3, next_state = 4) pop!(t) - t[end] == (state=2, next_state=3) + t[end] == (state = 2, next_state = 3) empty!(t) @test length(t) == 0 - t2 = MultiplexTraces{(:state, :next_state)}(Int[1,2,3,4]) + t2 = MultiplexTraces{(:state, :next_state)}(Int[1, 2, 3, 4]) append!(t, t2[:state]) - @test t[:state] == [1,2,3] - @test t[:next_state] == [2,3,4] + @test t[:state] == [1, 2, 3] + @test t[:next_state] == [2, 3, 4] end @testset "MergedTraces" begin - t1 = Traces(a=Int[]) - t2 = Traces(b=Bool[]) + t1 = Traces(a = Int[]) + t2 = Traces(b = Bool[]) t3 = t1 + t2 @test t3[:a] === t1[:a] @test t3[:b] === t2[:b] - push!(t3, (; a=1, b=false)) + push!(t3, (; a = 1, b = false)) @test length(t3) == 1 - @test t3[1] == (a=1, b=false) + @test t3[1] == (a = 1, b = false) - append!(t3, Traces(; a=[2, 3], b=[false, true])) + append!(t3, Traces(; a = [2, 3], b = [false, true])) @test length(t3) == 3 @test t3[:a][1:3] == [1, 2, 3] @@ -94,20 +91,20 @@ end t4 = MultiplexTraces{(:m, :n)}(Float64[]) t5 = t4 + t2 + t1 - push!(t5, (m=1.0, n=1.0, a=1, b=1)) + push!(t5, (m = 1.0, n = 1.0, a = 1, b = 1)) @test length(t5) == 1 - push!(t5, (m=2.0, a=2, b=0)) + push!(t5, (m = 2.0, a = 2, b = 0)) - @test t5[end] == (m=1.0, n=2.0, b=false, a=2) + @test t5[end] == (m = 1.0, n = 2.0, b = false, a = 2) - t6 = Traces(aa=Int[]) - t7 = Traces(bb=Bool[]) + t6 = Traces(aa = Int[]) + t7 = Traces(bb = Bool[]) t8 = (t1 + t2) + (t6 + t7) empty!(t8) - push!(t8, (a=1, b=false, aa=1, bb=false)) - append!(t8, Traces(a=[2, 3], b=[true, true], aa=[2, 3], bb=[true, true])) + push!(t8, (a = 1, b = false, aa = 1, bb = false)) + append!(t8, Traces(a = [2, 3], b = [true, true], aa = [2, 3], bb = [true, true])) @test length(t8) == 3 @@ -120,73 +117,77 @@ using ReinforcementLearningTrajectories: build_trace_index @testset "build_trace_index" begin t1 = CircularArraySARTSATraces(; - capacity=3, - state=Float32 => (2, 3), - action=Float32 => (2,), - reward=Float32 => (), - terminal=Bool => () + capacity = 3, + state = Float32 => (2, 3), + action = Float32 => (2,), + reward = Float32 => (), + terminal = Bool => (), ) - @test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict(:reward => 3, + @test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict( + :reward => 3, :next_state => 1, :state => 1, :action => 2, :next_action => 2, - :terminal => 4) + :terminal => 4, + ) - t2 = Traces(; a=[2, 3], b=[false, true]) + t2 = Traces(; a = [2, 3], b = [false, true]) build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2]) end @testset "build_trace_index ElasticArraySARTSATraces" begin t1 = ElasticArraySARTSATraces(; - state=Float32 => (2, 3), - action=Float32 => (2,), - reward=Float32 => (), - terminal=Bool => () + state = Float32 => (2, 3), + action = Float32 => (2,), + reward = Float32 => (), + terminal = Bool => (), ) - @test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict(:reward => 3, + @test build_trace_index(typeof(t1).parameters[1], typeof(t1).parameters[2]) == Dict( + :reward => 3, :next_state => 1, :state => 1, :action => 2, :next_action => 2, - :terminal => 4) + :terminal => 4, + ) - t2 = Traces(; a=[2, 3], b=[false, true]) + t2 = Traces(; a = [2, 3], b = [false, true]) build_trace_index(typeof(t2).parameters[1], typeof(t2).parameters[2]) end @testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin t1 = CircularArraySARTSATraces(; - capacity=3, - state=Float32 => (2, 3), - action=Float32 => (2,), - reward=Float32 => (), - terminal=Bool => () + capacity = 3, + state = Float32 => (2, 3), + action = Float32 => (2,), + reward = Float32 => (), + terminal = Bool => (), ) push!(t1, Val(:reward), 5) @test t1[:reward][1] == 5 @test size(Base.getindex(t1, :reward)) == (1,) - push!(t1, Val(:state), ones(2,3)) + push!(t1, Val(:state), ones(2, 3)) - @test t1[:state][1] == ones(2,3) + @test t1[:state][1] == ones(2, 3) - t2 = Traces(; a=[2, 3], b=[false, true]) + t2 = Traces(; a = [2, 3], b = [false, true]) push!(t2, Val(:a), 5) @test t2[:a][3] == 5 @test size(Base.getindex(t2, :a)) == (3,) - @test Base.getindex(t2, 1) == (; a = 2, b= false) + @test Base.getindex(t2, 1) == (; a = 2, b = false) end @testset "push!(ts::Traces{names,Trs,N,E}, ::Val{k}, v)" begin t1 = ElasticArraySARTSATraces( - state=Float32 => (2, 3), - action=Float32 => (2,), - reward=Float32 => (), - terminal=Bool => () + state = Float32 => (2, 3), + action = Float32 => (2,), + reward = Float32 => (), + terminal = Bool => (), ) push!(t1, Val(:reward), 5) @test t1[:reward][1] == 5 @@ -195,10 +196,10 @@ end @test size(Base.getindex(t1, :state)) == (0,) - t2 = Traces(; a=[2, 3], b=[false, true]) + t2 = Traces(; a = [2, 3], b = [false, true]) push!(t2, Val(:a), 5) @test t2[:a][3] == 5 @test size(Base.getindex(t2, :a)) == (3,) - @test Base.getindex(t2, 1) == (; a = 2, b= false) + @test Base.getindex(t2, 1) == (; a = 2, b = false) end diff --git a/test/trajectories.jl b/test/trajectories.jl index 426da07..3dba9e3 100644 --- a/test/trajectories.jl +++ b/test/trajectories.jl @@ -1,17 +1,11 @@ @testset "Default InsertSampleRatioController" begin - t = Trajectory( - container=Traces( - a=Int[], - b=Bool[] - ), - sampler=BatchSampler(3), - ) + t = Trajectory(container = Traces(a = Int[], b = Bool[]), sampler = BatchSampler(3)) batches = collect(t) @test length(batches) == 0 push!(t, (a = 1,)) - for i in 1:10 - push!(t, (a=i, b=true)) + for i = 1:10 + push!(t, (a = i, b = true)) end batches = collect(t) @@ -20,12 +14,9 @@ end @testset "trajectories" begin t = Trajectory( - container=Traces( - a=Int[], - b=Bool[] - ), - sampler=BatchSampler(3), - controller=InsertSampleRatioController(ratio=0.25, threshold=4) + container = Traces(a = Int[], b = Bool[]), + sampler = BatchSampler(3), + controller = InsertSampleRatioController(ratio = 0.25, threshold = 4), ) batches = [] @@ -37,8 +28,8 @@ end @test length(batches) == 0 # threshold not reached yet push!(t, (a = 1,)) - for i in 1:2 - push!(t, (a=i+1, b=true)) + for i = 1:2 + push!(t, (a = i + 1, b = true)) end for batch in t @@ -47,7 +38,7 @@ end @test length(batches) == 0 # threshold not reached yet - push!(t, (a=4, b=true)) + push!(t, (a = 4, b = true)) for batch in t push!(batches, batch) @@ -55,8 +46,8 @@ end @test length(batches) == 1 # 4 inserted, threshold is 4, ratio is 0.25 - for i in 5:7 - push!(t, (a=i, b=true)) + for i = 5:7 + push!(t, (a = i, b = true)) end for batch in t @@ -65,7 +56,7 @@ end @test length(batches) == 1 # 7 inserted, threshold is 4, ratio is 0.25 - push!(t, (a=8, b=true)) + push!(t, (a = 8, b = true)) for batch in t push!(batches, batch) @@ -74,32 +65,29 @@ end @test length(batches) == 2 # 8 inserted, ratio is 0.25 n = 400 - for i in 1:n - push!(t, (a=i, b=true)) + for i = 1:n + push!(t, (a = i, b = true)) end s = 0 for _ in t s += 1 end - @test s == n*0.25 + @test s == n * 0.25 end @testset "async trajectories" begin threshould = 100 ratio = 1 / 4 t = Trajectory( - container=Traces( - a=Int[], - b=Bool[] - ), - sampler=BatchSampler(3), - controller=AsyncInsertSampleRatioController(ratio, threshould) + container = Traces(a = Int[], b = Bool[]), + sampler = BatchSampler(3), + controller = AsyncInsertSampleRatioController(ratio, threshould), ) n = 100 - insert_task = @async for i in 1:n - append!(t, Traces(a=[i, i, i, i], b=[false, true, false, true])) + insert_task = @async for i = 1:n + append!(t, Traces(a = [i, i, i, i], b = [false, true, false, true])) end s = 0 @@ -108,4 +96,4 @@ end end sleep(1) @test s == (n - threshould * ratio) + 1 -end \ No newline at end of file +end