Skip to content

Run code format #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions src/common/CircularArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,28 @@ const CircularArraySARTSATraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySARTSATraces(;
capacity::Int,
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity+1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySARTSATraces) =
minimum(map(capacity, t.traces))
19 changes: 10 additions & 9 deletions src/common/CircularArraySARTSTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,28 @@ const CircularArraySARTSTraces = Traces{
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySARTSTraces(;
capacity::Int,
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity+1)) +
MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
Traces(
action = CircularArrayBuffer{action_eltype}(action_size..., capacity),
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySARTSTraces) =
minimum(map(capacity, t.traces))
26 changes: 16 additions & 10 deletions src/common/CircularArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ const CircularArraySLARTTraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:CircularArrayBuffer}},
<:Trace{<:CircularArrayBuffer},
<:Trace{<:CircularArrayBuffer},
}
},
}

function CircularArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
legal_actions_mask = Bool => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -26,12 +26,18 @@ function CircularArraySLARTTraces(;
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(CircularArrayBuffer{state_eltype}(state_size..., capacity + 1)) +
MultiplexTraces{LL′}(CircularArrayBuffer{legal_actions_mask_eltype}(legal_actions_mask_size..., capacity + 1)) +
MultiplexTraces{LL′}(
CircularArrayBuffer{legal_actions_mask_eltype}(
legal_actions_mask_size...,
capacity + 1,
),
) +
MultiplexTraces{AA′}(CircularArrayBuffer{action_eltype}(action_size..., capacity + 1)) +
Traces(
reward=CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal=CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
reward = CircularArrayBuffer{reward_eltype}(reward_size..., capacity),
terminal = CircularArrayBuffer{terminal_eltype}(terminal_size..., capacity),
)
end

CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) = minimum(map(capacity,t.traces))
CircularArrayBuffers.capacity(t::CircularArraySLARTTraces) =
minimum(map(capacity, t.traces))
13 changes: 9 additions & 4 deletions src/common/CircularPrioritizedTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@ struct CircularPrioritizedTraces{T,names,Ts} <: AbstractTraces{names,Ts}
default_priority::Float32
end

function CircularPrioritizedTraces(traces::AbstractTraces{names,Ts}; default_priority) where {names,Ts}
function CircularPrioritizedTraces(
traces::AbstractTraces{names,Ts};
default_priority,
) where {names,Ts}
new_names = (:key, :priority, names...)
new_Ts = Tuple{Int,Float32,Ts.parameters...}
c = capacity(traces)
CircularPrioritizedTraces{typeof(traces),new_names,new_Ts}(
CircularVectorBuffer{Int}(c),
SumTree(c),
traces,
default_priority
default_priority,
)
end

Expand Down Expand Up @@ -60,6 +63,8 @@ function Base.getindex(ts::CircularPrioritizedTraces, s::Symbol)
end
end

Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} = NamedTuple{names}(map(k -> t[k][i], names))
Base.getindex(t::CircularPrioritizedTraces{<:Any,names}, i) where {names} =
NamedTuple{names}(map(k -> t[k][i], names))

capacity(t::CircularPrioritizedTraces) = ReinforcementLearningTrajectories.capacity(t.traces)
capacity(t::CircularPrioritizedTraces) =
ReinforcementLearningTrajectories.capacity(t.traces)
15 changes: 7 additions & 8 deletions src/common/ElasticArraySARTSATraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ const ElasticArraySARTSATraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySARTSATraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -24,8 +24,7 @@ function ElasticArraySARTSATraces(;
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end

20 changes: 10 additions & 10 deletions src/common/ElasticArraySARTSTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ const ElasticArraySARTSTraces = Traces{
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySARTSTraces(;
state=Int => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ())

state = Int => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)

state_eltype, state_size = state
action_eltype, action_size = action
reward_eltype, reward_size = reward
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
Traces(
MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) + Traces(
action = ElasticArray{action_eltype}(undef, action_size..., 0),
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
20 changes: 11 additions & 9 deletions src/common/ElasticArraySLARTTraces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ const ElasticArraySLARTTraces = Traces{
<:MultiplexTraces{AA′,<:Trace{<:ElasticArray}},
<:Trace{<:ElasticArray},
<:Trace{<:ElasticArray},
}
},
}

function ElasticArraySLARTTraces(;
capacity::Int,
state=Int => (),
legal_actions_mask=Bool => (),
action=Int => (),
reward=Float32 => (),
terminal=Bool => ()
state = Int => (),
legal_actions_mask = Bool => (),
action = Int => (),
reward = Float32 => (),
terminal = Bool => (),
)
state_eltype, state_size = state
action_eltype, action_size = action
Expand All @@ -26,10 +26,12 @@ function ElasticArraySLARTTraces(;
terminal_eltype, terminal_size = terminal

MultiplexTraces{SS′}(ElasticArray{state_eltype}(undef, state_size..., 0)) +
MultiplexTraces{LL′}(ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0)) +
MultiplexTraces{LL′}(
ElasticArray{legal_actions_mask_eltype}(undef, legal_actions_mask_size..., 0),
) +
MultiplexTraces{AA′}(ElasticArray{action_eltype}(undef, action_size..., 0)) +
Traces(
reward=ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal=ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
reward = ElasticArray{reward_eltype}(undef, reward_size..., 0),
terminal = ElasticArray{terminal_eltype}(undef, terminal_size..., 0),
)
end
6 changes: 3 additions & 3 deletions src/common/sum_tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ function correct_sample(t::SumTree, leaf_ind)
p = t.tree[leaf_ind]
# walk backwards until p != 0 or until leftmost leaf reached
tmp_ind = leaf_ind
while iszero(p) && (tmp_ind-1)*2 > length(t.tree)
while iszero(p) && (tmp_ind - 1) * 2 > length(t.tree)
tmp_ind -= 1
p = t.tree[tmp_ind]
end
Expand All @@ -151,7 +151,7 @@ function correct_sample(t::SumTree, leaf_ind)
end
return p, tmp_ind
end


function Base.get(t::SumTree, v)
parent_ind = 1
Expand Down Expand Up @@ -185,7 +185,7 @@ Random.rand(t::SumTree) = rand(Random.GLOBAL_RNG, t)

function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
inds, priorities = Vector{Int}(undef, n), Vector{T}(undef, n)
for i in 1:n
for i = 1:n
v = (i - 1 + rand(rng, T)) / n
ind, p = get(t, v * t.tree[1])
inds[i] = ind
Expand Down
20 changes: 11 additions & 9 deletions src/controllers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export InsertSampleRatioController, AsyncInsertSampleRatioController, EpisodeSampleRatioController
export InsertSampleRatioController,
AsyncInsertSampleRatioController, EpisodeSampleRatioController

"""
InsertSampleRatioController(;ratio=1., threshold=1)
Expand Down Expand Up @@ -43,18 +44,19 @@ end
function AsyncInsertSampleRatioController(
ratio,
threshold,
; ch_in_sz=1,
ch_out_sz=1,
n_inserted=0,
n_sampled=0
;
ch_in_sz = 1,
ch_out_sz = 1,
n_inserted = 0,
n_sampled = 0,
)
AsyncInsertSampleRatioController(
ratio,
threshold,
n_inserted,
n_sampled,
Channel(ch_in_sz),
Channel(ch_out_sz)
Channel(ch_out_sz),
)
end

Expand All @@ -75,14 +77,14 @@ end

function on_insert!(c::EpisodeSampleRatioController, n::Int, x::NamedTuple)
if n > 0
c.n_episodes += sum(x.terminal)
c.n_episodes += sum(x.terminal)
end
end

function on_sample!(c::EpisodeSampleRatioController)
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
if c.n_episodes >= c.threshold && c.n_sampled <= (c.n_episodes - c.threshold) * c.ratio
c.n_sampled += 1
return true
end
return false
end
end
Loading
Loading