+import ReinforcementLearningTrajectories.fetch
 @testset "Samplers" begin
     @testset "BatchSampler" begin
         sz = 32
⋮
 
     #! format: off
     @testset "NStepSampler" begin
-        γ = 0.9
+        γ = 0.99
         n_stack = 2
         n_horizon = 3
-        batch_size = 4
-
-        t1 = MultiplexTraces{(:state, :next_state)}(1:10) +
-            MultiplexTraces{(:action, :next_action)}(iseven.(1:10)) +
-            Traces(
-                reward=1:9,
-                terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-            )
-
-        s1 = NStepBatchSampler(n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
-
-        xs = RLTrajectories.StatsBase.sample(s1, t1)
-
-        @test size(xs.state) == (n_stack, batch_size)
-        @test size(xs.next_state) == (n_stack, batch_size)
-        @test size(xs.action) == (batch_size,)
-        @test size(xs.reward) == (batch_size,)
-        @test size(xs.terminal) == (batch_size,)
-
-
-        state_size = (2,3)
-        n_state = reduce(*, state_size)
-        total_length = 10
-        t2 = MultiplexTraces{(:state, :next_state)}(
-            reshape(1:n_state * total_length, state_size..., total_length)
-        ) +
-            MultiplexTraces{(:action, :next_action)}(iseven.(1:total_length)) +
-            Traces(
-                reward=1:total_length-1,
-                terminal=Bool[0, 0, 0, 1, 0, 0, 0, 0, 1],
-            )
-
-        xs2 = RLTrajectories.StatsBase.sample(s1, t2)
-
-        @test size(xs2.state) == (state_size..., n_stack, batch_size)
-        @test size(xs2.next_state) == (state_size..., n_stack, batch_size)
-        @test size(xs2.action) == (batch_size,)
-        @test size(xs2.reward) == (batch_size,)
-        @test size(xs2.terminal) == (batch_size,)
-
-        inds = [3, 5, 7]
-        xs3 = RLTrajectories.StatsBase.sample(s1, t2, Val(SS′ART), inds)
-
-        @test xs3.state == cat(
-            (
-                reshape(n_state * (i-n_stack)+1: n_state * i, state_size..., n_stack)
-                for i in inds
-            )...
-            ;dims=length(state_size) + 2
-        )
-
-        @test xs3.next_state == xs3.state .+ (n_state * n_horizon)
-        @test xs3.action == iseven.(inds)
-        @test xs3.terminal == [any(t2[:terminal][i: i+n_horizon-1]) for i in inds]
-
-        # manual calculation
-        @test xs3.reward[1] ≈ 3 + γ * 4 # terminated at step 4
-        @test xs3.reward[2] ≈ 5 + γ * (6 + γ * 7)
-        @test xs3.reward[3] ≈ 7 + γ * (8 + γ * 9)
-    end
-    #! format: on
-
-    @testset "Trajectory with CircularPrioritizedTraces and NStepBatchSampler" begin
-        n=1
-        γ=0.99f0
-
-        t = Trajectory(
-            container=CircularPrioritizedTraces(
-                CircularArraySARTSTraces(
-                    capacity=5,
-                    state=Float32 => (4,),
-                );
-                default_priority=100.0f0
-            ),
-            sampler=NStepBatchSampler{SS′ART}(
-                n=n,
-                γ=γ,
-                batch_size=32,
-            ),
-            controller=InsertSampleRatioController(
-                threshold=100,
-                n_inserted=-1
-            )
-        )
+        batch_size = 1000
+        eb = EpisodesBuffer(CircularArraySARTSATraces(capacity=10))
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, stack_size=n_stack, batch_size=batch_size)
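+        # Each sampled index yields a stack of `n_stack` consecutive frames and an
+        # up-to-`n_horizon`-step discounted return with discount γ (verified below).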
 
-        push!(t, (state = 1, action = true))
-        for i = 1:9
-            push!(t, (state = i+1, action = true, reward = i, terminal = false))
+        push!(eb, (state = 1, action = 1))
+        for i = 1:5
+            push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
         end
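+        # episode 1: states 1:6, rewards 1:5, terminated at the 5th step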
-
-        b = RLTrajectories.StatsBase.sample(t)
-        @test haskey(b, :priority)
-        @test sum(b.action .== 0) == 0
-    end
-
-
-    @testset "Trajectory with CircularArraySARTSTraces and NStepBatchSampler" begin
-        n=1
-        γ=0.99f0
-
-        t = Trajectory(
-            container=CircularArraySARTSTraces(
-                capacity=5,
-                state=Float32 => (4,),
-            ),
-            sampler=NStepBatchSampler{SS′ART}(
-                n=n,
-                γ=γ,
-                batch_size=32,
-            ),
-            controller=InsertSampleRatioController(
-                threshold=100,
-                n_inserted=-1
-            )
-        )
-
-        push!(t, (state = 1, action = true))
-        for i = 1:9
-            push!(t, (state = i+1, action = true, reward = i, terminal = false))
+        push!(eb, (state = 7, action = 7))
+        for i = 8:11
+            push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+        end
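+        # episode 2 (still running): states 7:11, rewards 7:10, no terminal yet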
+        weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+        @test weights == [0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0]
+        # the -1 arises because ep_lengths[6] already belongs to the 2nd episode
+        # while step_numbers[6] still belongs to the 1st
+        @test ns == [3, 3, 3, 2, 1, -1, 3, 3, 2, 1, 0]
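+        # ns[i] is the usable horizon at index i: n_horizon truncated to the number
+        # of steps remaining before the episode boundary.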
+        inds = [i for i in eachindex(weights) if weights[i] == 1]
+        batch = sample(s1, eb)
+        for key in keys(eb)
+            @test haskey(batch, key)
         end
+        # state: sampled as a stack of `n_stack` consecutive frames
+        states = ReinforcementLearningTrajectories.fetch(s1, eb[:state], Val(:state), inds, ns[inds])
+        @test states == [1 2 3 4 7 8 9;
+                         2 3 4 5 8 9 10]
+        @test all(in(eachcol(states)), unique(eachcol(batch[:state])))
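+        # e.g. the first valid index (2) yields the column [1, 2]: the frame at the
+        # index itself plus the `n_stack - 1` frames preceding it.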
+        # next_state: sampled as a stack of `n_stack` frames, ns steps forward
+        next_states = ReinforcementLearningTrajectories.fetch(s1, eb[:next_state], Val(:next_state), inds, ns[inds])
+        @test next_states == [4 5 5 5 10 10 10;
+                              5 6 6 6 11 11 11]
+        @test all(in(eachcol(next_states)), unique(eachcol(batch[:next_state])))
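+        # e.g. index 2 with ns = 3 yields the stack [4, 5], three steps ahead;
+        # indices 3:5 all clip to [5, 6], the end of episode 1.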
+        # action: sampled at the index itself, no shift
+        actions = ReinforcementLearningTrajectories.fetch(s1, eb[:action], Val(:action), inds, ns[inds])
+        @test actions == inds
+        @test all(in(actions), unique(batch[:action]))
+        # next_action: a multiplexed trace, so it is automatically sampled ns steps forward
+        next_actions = ReinforcementLearningTrajectories.fetch(s1, eb[:next_action], Val(:next_action), inds, ns[inds])
+        @test next_actions == [5, 6, 6, 6, 11, 11, 11]
+        @test all(in(next_actions), unique(batch[:next_action]))
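+        # e.g. index 2 with ns = 3 fetches the action at index 5, i.e. action 5.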
+        # reward: the discounted n-step return
+        rewards = ReinforcementLearningTrajectories.fetch(s1, eb[:reward], Val(:reward), inds, ns[inds])
+        @test rewards ≈ [2 + 0.99*3 + 0.99^2*4, 3 + 0.99*4 + 0.99^2*5, 4 + 0.99*5, 5, 8 + 0.99*9 + 0.99^2*10, 9 + 0.99*10, 10]
+        @test all(in(rewards), unique(batch[:reward]))
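+        # i.e. rewards[j] = sum(γ^k * r[i+k] for k in 0:ns[i]-1) at index i,
+        # truncated at the episode boundary.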
+        # terminal: whether the episode terminates within the n-step window
+        terminals = ReinforcementLearningTrajectories.fetch(s1, eb[:terminal], Val(:terminal), inds, ns[inds])
+        @test terminals == [0, 1, 1, 1, 0, 0, 0]
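+        # 1 exactly when the (truncated) window ends on the terminal step 5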
+
+        ### CircularPrioritizedTraces and NStepBatchSampler
+        γ = 0.99
+        n_horizon = 3
+        batch_size = 4
+        eb = EpisodesBuffer(CircularPrioritizedTraces(CircularArraySARTSATraces(capacity=10), default_priority = 10f0))
+        s1 = NStepBatchSampler(eb, n=n_horizon, γ=γ, batch_size=batch_size)
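+        # Same two-episode setup as above, but every insertion now receives the
+        # default priority (10f0), so batches also carry :key and :priority.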
 
-        b = RLTrajectories.StatsBase.sample(t)
-        @test sum(b.action .== 0) == 0
+        push!(eb, (state = 1, action = 1))
+        for i = 1:5
+            push!(eb, (state = i+1, action = i+1, reward = i, terminal = i == 5))
+        end
+        push!(eb, (state = 7, action = 7))
+        for i = 8:11
+            push!(eb, (state = i, action = i, reward = i-1, terminal = false))
+        end
+        weights, ns = ReinforcementLearningTrajectories.valid_range(s1, eb)
+        inds = [i for i in eachindex(weights) if weights[i] == 1]
+        batch = sample(s1, eb)
+        for key in (keys(eb)..., :key, :priority)
+            @test haskey(batch, key)
+        end
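+        # No priority has been updated yet, so every sampled :priority should
+        # presumably still equal the 10f0 default.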
     end
 
     @testset "EpisodesSampler" begin