Skip to content

Commit b190c60

Browse files
authored
Merge branch 'main' into nstep
2 parents 2595265 + c89ed6f commit b190c60

File tree

4 files changed

+96
-5
lines changed

4 files changed

+96
-5
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1616
[compat]
1717
Adapt = "3"
1818
CircularArrayBuffers = "0.1"
19+
DataStructures = "0.18"
1920
ElasticArrays = "1"
2021
MacroTools = "0.5"
2122
OnlineStats = "1"
2223
StackViews = "0.1"
23-
julia = "1.9"
24-
DataStructures = "0.18"
2524
StatsBase = "0.34"
25+
julia = "1.9"
2626

2727
[extras]
2828
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
29+
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
2930
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3031

3132
[targets]
32-
test = ["Test", "CUDA"]
33+
test = ["Test", "CUDA", "StableRNGs"]

src/common/sum_tree.jl

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,28 @@ function Base.empty!(t::SumTree)
131131
t
132132
end
133133

134+
"""
135+
correct_sample(t::SumTree, leaf_ind)
136+
Check whether the sampled leaf is valid and if not return another valid leaf close to it. Used to correct samples with zero priority which may occur due to numerical errors with floats.
137+
"""
138+
function correct_sample(t::SumTree, leaf_ind)
139+
p = t.tree[leaf_ind]
140+
# walk backwards until p != 0 or until leftmost leaf reached
141+
tmp_ind = leaf_ind
142+
while iszero(p) && (tmp_ind-1)*2 > length(t.tree)
143+
tmp_ind -= 1
144+
p = t.tree[tmp_ind]
145+
end
146+
# walk forwards until p != 0 or until rightmost leaf reached
147+
iszero(p) && (tmp_ind = leaf_ind)
148+
while iszero(p) && (tmp_ind - t.nparents) <= t.length
149+
tmp_ind += 1
150+
p = t.tree[tmp_ind]
151+
end
152+
return p, tmp_ind
153+
end
154+
155+
134156
function Base.get(t::SumTree, v)
135157
parent_ind = 1
136158
leaf_ind = parent_ind
@@ -152,7 +174,7 @@ function Base.get(t::SumTree, v)
152174
if leaf_ind <= t.nparents
153175
leaf_ind += t.capacity
154176
end
155-
p = t.tree[leaf_ind]
177+
p, leaf_ind = correct_sample(t, leaf_ind)
156178
ind = leaf_ind - t.nparents
157179
real_ind = ind >= t.first ? ind - t.first + 1 : ind + t.capacity - t.first + 1
158180
real_ind, p
@@ -172,4 +194,4 @@ function Random.rand(rng::AbstractRNG, t::SumTree{T}, n::Int) where {T}
172194
inds, priorities
173195
end
174196

175-
Random.rand(t::SumTree, n::Int) = rand(Random.GLOBAL_RNG, t, n)
197+
Random.rand(t::SumTree, n::Int) = rand(Random.GLOBAL_RNG, t, n)

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
using ReinforcementLearningTrajectories
22
using CircularArrayBuffers, DataStructures
3+
using StableRNGs
34
using Test
45
import ReinforcementLearningTrajectories.StatsBase.sample
56
using CUDA
67
using Adapt
8+
using Random
9+
import ReinforcementLearningTrajectories.StatsBase.sample
10+
import StatsBase.countmap
711

812
struct TestAdaptor end
913

@@ -13,6 +17,7 @@ Adapt.adapt_storage(to::TestAdaptor, x) = CUDA.functional() ? CUDA.cu(x) : x
1317

1418
@testset "ReinforcementLearningTrajectories.jl" begin
1519
include("traces.jl")
20+
include("sum_tree.jl")
1621
include("common.jl")
1722
include("samplers.jl")
1823
include("controllers.jl")

test/sum_tree.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
function gen_rand_sumtree(n, seed, type::DataType=Float32)
2+
rng = StableRNG(seed)
3+
a = SumTree(type, n)
4+
append!(a, rand(rng, type, n))
5+
return a
6+
end
7+
8+
function gen_sumtree_with_zeros(n, seed, type::DataType=Float32)
9+
a = gen_rand_sumtree(n, seed, type)
10+
b = rand(StableRNG(seed), Bool, n)
11+
return copy_multiply(a, b)
12+
end
13+
14+
function copy_multiply(stree, m)
15+
new_tree = deepcopy(stree)
16+
new_tree .*= m
17+
return new_tree
18+
end
19+
20+
function sumtree_nozero(t::SumTree, rng::AbstractRNG, iters=1)
21+
for _ in iters
22+
(_, p) = rand(rng, t)
23+
p == 0 && return false
24+
end
25+
return true
26+
end
27+
sumtree_nozero(n::Integer, seed::Integer, iters=1) = sumtree_nozero(gen_sumtree_with_zeros(n, seed), StableRNG(seed), iters)
28+
sumtree_nozero(n, seeds::AbstractVector, iters=1) = all(sumtree_nozero(n, seed, iters) for seed in seeds)
29+
30+
31+
function sumtree_distribution!(indices, priorities, t::SumTree, rng::AbstractRNG, iters=1000*t.length)
32+
for i = 1:iters
33+
indices[i], priorities[i] = rand(rng, t)
34+
end
35+
imap = countmap(indices)
36+
est_pdf = Dict(k=>v/length(indices) for (k, v) in imap)
37+
ex_pdf = Dict(k=>v/t.tree[1] for (k, v) in Dict(1:length(t) .=> t))
38+
abserrs = [est_pdf[k] - ex_pdf[k] for k in keys(est_pdf)]
39+
return abserrs
40+
end
41+
sumtree_distribution!(indices, priorities, n, seed, iters=1000*n) = sumtree_distribution!(indices, priorities, gen_rand_sumtree(n, seed), StableRNG(seed), iters)
42+
function sumtree_distribution(n, seeds::AbstractVector, iters=1000*n)
43+
p = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
44+
i = [zeros(Float32, iters) for _ = 1:Threads.nthreads()]
45+
results = Vector{Vector{Float64}}(undef, length(seeds))
46+
Threads.@threads for ix = 1:length(seeds)
47+
results[ix] = sumtree_distribution!(i[Threads.threadid()], p[Threads.threadid()], gen_rand_sumtree(n, seeds[ix]), StableRNG(seeds[ix]), iters)
48+
end
49+
return results
50+
end
51+
52+
@testset "SumTree" begin
53+
n = 1024
54+
seeds = 1:100
55+
nozero_iters=1024
56+
distr_iters=1024*10_000
57+
abstol = 0.05
58+
maxerr=0.01
59+
60+
@test sumtree_nozero(n, seeds, nozero_iters)
61+
@test all(x->all(x .< maxerr) && sum(abs2, x) < abstol,
62+
sumtree_distribution(n, seeds, distr_iters))
63+
end

0 commit comments

Comments
 (0)