Skip to content
This repository was archived by the owner on May 21, 2022. It is now read-only.

Commit 77bea11

Browse files
committed
Use unsafe_gettpl! to speed up access to results of env.step()
1 parent ad5d3f2 commit 77bea11

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/OpenAIGym.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ mutable struct GymEnv{T} <: AbstractGymEnv
3030
pyreset::PyObject # the python env.reset function
3131
pystate::PyObject # the state array object referenced by the PyArray state.o
3232
pystepres::PyObject # used to make stepping the env slightly more efficient
33+
pytplres::PyObject # used to make stepping the env slightly more efficient
3334
info::PyObject # store it as a PyObject for speed, since often unused
3435
state::T
3536
reward::Float64
@@ -141,13 +142,10 @@ function Reinforce.step!(env::GymEnv{T}, a) where T <: PyArray
141142
pyact = pyaction(a)
142143
pycall!(env.pystepres, env.pystep, PyObject, pyact)
143144

144-
env.pystate, r, env.done, env.info =
145-
convert(Tuple{PyObject, Float64, Bool, PyObject}, env.pystepres)
146-
145+
unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
147146
setdata!(env.state, env.pystate)
148147

149-
env.total_reward += r
150-
return (r, env.state)
148+
return gymstep!(env)
151149
end
152150

153151
"""
@@ -157,11 +155,16 @@ function Reinforce.step!(env::GymEnv{T}, a) where T
157155
pyact = pyaction(a)
158156
pycall!(env.pystepres, env.pystep, PyObject, pyact)
159157

160-
env.pystate, r, env.done, env.info =
161-
convert(Tuple{PyObject, Float64, Bool, PyObject}, env.pystepres)
162-
158+
unsafe_gettpl!(env.pystate, env.pystepres, PyObject, 0)
163159
env.state = convert(T, env.pystate)
164160

161+
return gymstep!(env)
162+
end
163+
164+
@inline function gymstep!(env)  # common tail of both step! methods: unpack reward, done and info from env.pystepres
165+
r = unsafe_gettpl!(env.pytplres, env.pystepres, Float64, 1)  # reward = tuple entry 1 (entry 0, the state, is read by the caller); NOTE(review): env.pytplres looks like a reusable scratch PyObject -- confirm
166+
env.done = unsafe_gettpl!(env.pytplres, env.pystepres, Bool, 2)  # done flag = tuple entry 2, converted to Bool
167+
unsafe_gettpl!(env.info, env.pystepres, PyObject, 3)  # info = tuple entry 3; env.info is passed as the first argument, presumably updated in place -- confirm
165168
env.total_reward += r  # add this step's reward to the running total
166169
return (r, env.state)  # env.state has already been refreshed by the calling step! method
167170
end

0 commit comments

Comments
 (0)