@@ -72,10 +72,11 @@ StatsBase.sample(s::BatchSampler{nothing}, t::CircularPrioritizedTraces) = Stats
7272
7373function StatsBase. sample (s:: BatchSampler , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} , names)
7474 t = e. traces
75- st = deepcopy (t. priorities)
76- st .*= e. sampleable_inds[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
77- inds, priorities = rand (s. rng, st, s. batch_size)
78- NamedTuple {(:key, :priority, names...)} ((t. keys[inds], priorities, map (x -> collect (t. traces[Val (x)][inds]), names)... ))
75+ p = collect (deepcopy (t. priorities))
76+ w = StatsBase. FrequencyWeights (p)
77+ w .*= e. sampleable_inds[1 : end - 1 ]
78+ inds = StatsBase. sample (s. rng, eachindex (w), w, s. batch_size)
79+ NamedTuple {(:key, :priority, names...)} ((t. keys[inds], p[inds], map (x -> collect (t. traces[Val (x)][inds]), names)... ))
7980end
8081
8182function StatsBase. sample (s:: BatchSampler , t:: CircularPrioritizedTraces , names)
@@ -242,12 +243,13 @@ fetch(::NStepBatchSampler, trace::AbstractTrace, ::Val, inds, ns) = trace[inds]
242243
243244function StatsBase. sample (s:: NStepBatchSampler{names} , e:: EpisodesBuffer{<:Any, <:Any, <:CircularPrioritizedTraces} ) where {names}
244245 t = e. traces
245- st = deepcopy (t. priorities)
246+ p = collect (deepcopy (t. priorities))
247+ w = StatsBase. FrequencyWeights (p)
246248 valids, ns = valid_range (s,e)
247- st .*= valids[1 : end - 1 ] # temporary sumtree that puts 0 priority to non sampleable indices.
248- inds, priorities = rand (s. rng, st , s. batch_size)
249+ w .*= valids[1 : end - 1 ]
250+ inds = StatsBase . sample (s. rng, eachindex (w), w , s. batch_size)
249251 merge (
250- (key= t. keys[inds], priority= priorities ),
252+ (key= t. keys[inds], priority= p[inds] ),
251253 fetch (s, e, Val (names), inds, ns)
252254 )
253255end
0 commit comments