11using Random
2- export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler
2+ export EpisodesSampler, Episode, BatchSampler, NStepBatchSampler, MetaSampler, MultiBatchSampler, DummySampler, MultiStepSampler
33
44struct SampleGenerator{S,T}
55 sampler:: S
@@ -29,27 +29,27 @@ StatsBase.sample(::DummySampler, t) = t
2929export BatchSampler
3030
3131struct BatchSampler{names}
32- batch_size :: Int
32+ batchsize :: Int
3333 rng:: Random.AbstractRNG
3434end
3535
3636"""
37- BatchSampler{names}(;batch_size , rng=Random.GLOBAL_RNG)
38- BatchSampler{names}(batch_size ;rng=Random.GLOBAL_RNG)
37+ BatchSampler{names}(;batchsize , rng=Random.GLOBAL_RNG)
38+ BatchSampler{names}(batchsize ;rng=Random.GLOBAL_RNG)
3939
40- Uniformly sample **ONE** batch of `batch_size ` examples for each trace specified
40+ Uniformly sample **ONE** batch of `batchsize ` examples for each trace specified
4141in `names`. If `names` is not set, all the traces will be sampled.
4242"""
43- BatchSampler (batch_size ; kw... ) = BatchSampler (; batch_size = batch_size , kw... )
43+ BatchSampler (batchsize ; kw... ) = BatchSampler (; batchsize = batchsize , kw... )
4444BatchSampler (; kw... ) = BatchSampler {nothing} (; kw... )
45- BatchSampler {names} (batch_size ; kw... ) where {names} = BatchSampler {names} (; batch_size = batch_size , kw... )
46- BatchSampler {names} (; batch_size , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batch_size , rng)
45+ BatchSampler {names} (batchsize ; kw... ) where {names} = BatchSampler {names} (; batchsize = batchsize , kw... )
46+ BatchSampler {names} (; batchsize , rng= Random. GLOBAL_RNG) where {names} = BatchSampler {names} (batchsize , rng)
4747
4848StatsBase. sample (s:: BatchSampler{nothing} , t:: AbstractTraces ) = StatsBase. sample (s, t, keys (t))
4949StatsBase. sample (s:: BatchSampler{names} , t:: AbstractTraces ) where {names} = StatsBase. sample (s, t, names)
5050
5151function StatsBase. sample (s:: BatchSampler , t:: AbstractTraces , names, weights = StatsBase. UnitWeights {Int} (length (t)))
52- inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batch_size )
52+ inds = StatsBase. sample (s. rng, 1 : length (t), weights, s. batchsize )
5353 NamedTuple {names} (map (x -> collect (t[Val (x)][inds]), names))
5454end
5555
@@ -75,12 +75,12 @@ function StatsBase.sample(s::BatchSampler, e::EpisodesBuffer{<:Any, <:Any, <:Cir
7575 p = collect (deepcopy (t. priorities))
7676 w = StatsBase. FrequencyWeights (p)
7777 w .*= e. sampleable_inds[1 : end - 1 ]
78- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
78+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
7979 NamedTuple {(:key, :priority, names...)} ((t. keys[inds], p[inds], map (x -> collect (t. traces[Val (x)][inds]), names)... ))
8080end
8181
8282function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
83- inds, priorities = rand (s. rng, t. priorities, s. batch_size )
83+ inds, priorities = rand (s. rng, t. priorities, s. batchsize )
8484 NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x)][inds]), names)... ))
8585end
8686
@@ -165,41 +165,42 @@ end
165165export NStepBatchSampler
166166
167167"""
168- NStepBatchSampler{names}(; n, γ, batch_size=32, stack_size=nothing, rng=Random.GLOBAL_RNG)
168+
169+ NStepBatchSampler{names}(; n, γ, batchsize=32, stacksize=nothing, rng=Random.GLOBAL_RNG)
169170
170171Used to sample a discounted sum of consecutive rewards in the framework of n-step TD learning.
171172The "next" element of Multiplexed traces (such as the next_state or the next_action) will be
172173that in up to `n > 1` steps later in the buffer. The reward will be
173174the discounted sum of the `n` rewards, with `γ` as the discount factor.
174175
175- NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stack_size ` is set
176- to an integer > 1. This samples the (stack_size - 1) previous states. This is useful in the case
177- of partial observability, for example when the state is approximated by `stack_size ` consecutive
176+ NStepBatchSampler may also be used with n ≥ 1 to sample a "stack" of states if `stacksize ` is set
177+ to an integer > 1. This samples the (stacksize - 1) previous states. This is useful in the case
178+ of partial observability, for example when the state is approximated by `stacksize ` consecutive
178179frames.
179180"""
180- mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} }
181+ mutable struct NStepBatchSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
181182 n:: Int # !!! n starts from 1
182183 γ:: Float32
183- batch_size :: Int
184- stack_size :: S
185- rng:: Any
184+ batchsize :: Int
185+ stacksize :: S
186+ rng:: R
186187end
187188
188189NStepBatchSampler (t:: AbstractTraces ; kw... ) = NStepBatchSampler {keys(t)} (; kw... )
189- function NStepBatchSampler {names} (; n, γ, batch_size = 32 , stack_size = nothing , rng= Random. GLOBAL_RNG ) where {names}
190+ function NStepBatchSampler {names} (; n, γ, batchsize = 32 , stacksize = nothing , rng= Random. default_rng () ) where {names}
190191 @assert n >= 1 " n must be ≥ 1."
191- ss = stack_size == 1 ? nothing : stack_size
192- NStepBatchSampler {names, typeof(ss)} (n, γ, batch_size , ss, rng)
192+ ss = stacksize == 1 ? nothing : stacksize
193+ NStepBatchSampler {names, typeof(ss), typeof(rng) } (n, γ, batchsize , ss, rng)
193194end
194195
195- # return a boolean vector of the valid sample indices given the stack_size and the truncated n for each index.
196+ # return a boolean vector of the valid sample indices given the stacksize and the truncated n for each index.
196197function valid_range (s:: NStepBatchSampler , eb:: EpisodesBuffer )
197198 range = copy (eb. sampleable_inds)
198199 ns = Vector {Int} (undef, length (eb. sampleable_inds))
199- stack_size = isnothing (s. stack_size ) ? 1 : s. stack_size
200+ stacksize = isnothing (s. stacksize ) ? 1 : s. stacksize
200201 for idx in eachindex (range)
201202 step_number = eb. step_numbers[idx]
202- range[idx] = step_number >= stack_size && eb. sampleable_inds[idx]
203+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
203204 ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
204205 end
205206 return range, ns
@@ -211,19 +212,19 @@ end
211212
212213function StatsBase. sample (s:: NStepBatchSampler , t:: EpisodesBuffer , :: Val{names} ) where names
213214 weights, ns = valid_range (s, t)
214- inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batch_size )
215+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize )
215216 fetch (s, t, Val (names), inds, ns)
216217end
217218
218219function fetch (s:: NStepBatchSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
219220 NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
220221end
221222
222- # state and next_state have specialized fetch methods due to stack_size
223+ # state and next_state have specialized fetch methods due to stacksize
223224fetch (:: NStepBatchSampler{names, Nothing} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[inds]
224- fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stack_size + 1 : 0 , x in inds]]
225+ fetch (s:: NStepBatchSampler{names, Int} , trace:: AbstractTrace , :: Val{:state} , inds, ns) where {names} = trace[[x + i for i in - s. stacksize + 1 : 0 , x in inds]]
225226fetch (:: NStepBatchSampler{names, Nothing} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[inds .+ ns .- 1 ]
226- fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stack_size + 1 : 0 , (idx,x) in enumerate (inds)]]
227+ fetch (s:: NStepBatchSampler{names, Int} , trace:: RelativeTrace{1,0} , :: Val{:next_state} , inds, ns) where {names} = trace[[x + ns[idx] - 1 + i for i in - s. stacksize + 1 : 0 , (idx,x) in enumerate (inds)]]
227228
228229# reward due to discounting
229230function fetch (s:: NStepBatchSampler , trace:: AbstractTrace , :: Val{:reward} , inds, ns)
@@ -247,7 +248,7 @@ function StatsBase.sample(s::NStepBatchSampler{names}, e::EpisodesBuffer{<:Any,
247248 w = StatsBase. FrequencyWeights (p)
248249 valids, ns = valid_range (s,e)
249250 w .*= valids[1 : end - 1 ]
250- inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size )
251+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize )
251252 merge (
252253 (key= t. keys[inds], priority= p[inds]),
253254 fetch (s, e, Val (names), inds, ns)
@@ -297,3 +298,74 @@ function StatsBase.sample(::EpisodesSampler, t::EpisodesBuffer, names)
297298
298299 return [make_episode (t, r, names) for r in ranges]
299300end
301+
302+ # ####MultiStepSampler
303+
304+ """
305+ MultiStepSampler{names}(batchsize, stacksize, n, rng)
306+
307+ Sampler that fetches steps `[x, x+1, ..., x + n -1]` for each trace of each sampled index
308+ `x`. The samples are returned in an array of batchsize elements. For each element, n is
309+ truncated by the end of its episode. This means that the dimensions of each sample are not
310+ the same.
311+ """
312+ struct MultiStepSampler{names, S <: Union{Nothing,Int} , R <: AbstractRNG }
313+ n:: Int
314+ batchsize:: Int
315+ stacksize:: Int
316+ rng:: R
317+ end
318+
319+ MultiStepSampler (t:: AbstractTraces ; kw... ) = MultiStepSampler {keys(t)} (; kw... )
320+ function MultiStepSampler {names} (; n, batchsize= 32 , stacksize= nothing , rng= Random. default_rng ()) where {names}
321+ @assert n >= 1 " n must be ≥ 1."
322+ ss = stacksize == 1 ? nothing : stacksize
323+ MultiStepSampler {names, typeof(ss), typeof(rng)} (n, batchsize, ss, rng)
324+ end
325+
326+ function valid_range (s:: MultiStepSampler , eb:: EpisodesBuffer )
327+ range = copy (eb. sampleable_inds)
328+ ns = Vector {Int} (undef, length (eb. sampleable_inds))
329+ stacksize = isnothing (s. stacksize) ? 1 : s. stacksize
330+ for idx in eachindex (range)
331+ step_number = eb. step_numbers[idx]
332+ range[idx] = step_number >= stacksize && eb. sampleable_inds[idx]
333+ ns[idx] = min (s. n, eb. episodes_lengths[idx] - step_number + 1 )
334+ end
335+ return range, ns
336+ end
337+
338+ function StatsBase. sample (s:: MultiStepSampler{names} , ts) where {names}
339+ StatsBase. sample (s, ts, Val (names))
340+ end
341+
342+ function StatsBase. sample (s:: MultiStepSampler , t:: EpisodesBuffer , :: Val{names} ) where names
343+ weights, ns = valid_range (s, t)
344+ inds = StatsBase. sample (s. rng, 1 : length (t), StatsBase. FrequencyWeights (weights[1 : end - 1 ]), s. batchsize)
345+ fetch (s, t, Val (names), inds, ns)
346+ end
347+
348+ function fetch (s:: MultiStepSampler , ts:: EpisodesBuffer , :: Val{names} , inds, ns) where names
349+ NamedTuple {names} (map (name -> collect (fetch (s, ts[name], Val (name), inds, ns[inds])), names))
350+ end
351+
352+ function fetch (:: MultiStepSampler , trace, :: Val , inds, ns)
353+ [trace[idx: (idx + ns[i] - 1 )] for (i,idx) in enumerate (inds)]
354+ end
355+
356+ function fetch (s:: MultiStepSampler{names, Int} , trace:: AbstractTrace , :: Union{Val{:state}, Val{:next_state}} , inds, ns) where {names}
357+ [trace[[idx + i + n - 1 for i in - s. stacksize+ 1 : 0 , n in 1 : ns[j]]] for (j,idx) in enumerate (inds)]
358+ end
359+
360+ function StatsBase. sample (s:: MultiStepSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
361+ t = e. traces
362+ p = collect (deepcopy (t. priorities))
363+ w = StatsBase. FrequencyWeights (p)
364+ valids, ns = valid_range (s,e)
365+ w .*= valids[1 : end - 1 ]
366+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batchsize)
367+ merge (
368+ (key= t. keys[inds], priority= p[inds]),
369+ fetch (s, e, Val (names), inds, ns)
370+ )
371+ end
0 commit comments