POBAX: Partially Observable Benchmarks in JAX

POBAX is a reinforcement learning benchmark that tests all forms of partial observability.

POBAX has been accepted to RLC 2025. Check out our paper!

The benchmark is entirely written in JAX, allowing for fast, GPU-scalable experimentation.

Environments

POBAX includes environments (as well as recommended hyperparameter settings) across diverse forms of partial observability. We list our environments from smallest to largest (in terms of neural network size requirements for PPO RNN):

| Environment | Category | IDs | Description |
| --- | --- | --- | --- |
| Simple Chain | Object uncertainty & tracking | `simple_chain` | Diagnostic POMDP for testing algorithms. |
| T-Maze | Object uncertainty & tracking | `tmaze_10` | Bakker's classic memory-testing environment (hallway disambiguation). |
| RockSample | Episode nonstationarity, object uncertainty & tracking | `rocksample_11_11`, `rocksample_15_15` | The classic rock-collecting POMDP, where an agent needs to uncover and collect rocks. |
| Battleship | Object uncertainty & tracking | `battleship_10` | Single-player Battleship (10x10). |
| Masked Mujoco | Moment features | `{env_name}-{F/P/V}-v0` | Mujoco with state features masked out. `env_name` can be Walker, Ant, Hopper, or HalfCheetah; F/P/V stand for the fully observable, position-only, and velocity-only versions, respectively. |
| DMLab Minigrid | Spatial uncertainty | `Navix-DMLab-Maze-{01/02/03}-v0` | MiniGrid versions of the DeepMind Lab mazes; 01/02/03 refer to the mazes in ascending difficulty. |
| Visual Continuous Control | Visual occlusion and moment features | `{env_name}-pixels` | Pixel-based versions of Mujoco control. Requires the Madrona_MJX package. `env_name` can be ant, halfcheetah, hopper, or walker2d. |
| No-Inventory Crafter | Object uncertainty & tracking, spatial uncertainty | `craftax-pixels` | Crafter without the inventory. Requires the Craftax package. |
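The parameterized IDs above compose directly. As a minimal sketch (assuming the same get_env interface shown under Basic Usage below), a position-only masked Walker environment could be created like so:

import jax
from pobax.envs import get_env

# Position-only Walker, following the {env_name}-{F/P/V}-v0 pattern above
env_key = jax.random.PRNGKey(0)
env, env_params = get_env("Walker-P-v0", env_key)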

Experimental results for memory-based deep reinforcement learning algorithms are shown below and in our paper.

POBAX experimental results

Basic Usage

import jax
from pobax.envs import get_env

rand_key = jax.random.PRNGKey(2025)
env_key, rand_key = jax.random.split(rand_key)

# Creates a vectorized environment
env, env_params = get_env("rocksample_11_11", env_key)

# Reset 10 environments
reset_key, rand_key = jax.random.split(rand_key)
reset_keys = jax.random.split(reset_key, 10)

obs, env_state = env.reset(reset_keys, env_params)

# Take steps in all environments
step_key, action_key, rand_key = jax.random.split(rand_key, 3)
step_keys = jax.random.split(step_key, 10)
action_keys = jax.random.split(action_key, 10)

actions = jax.vmap(env.action_space(env_params).sample)(action_keys)

obs, env_state, reward, done, info = env.step(step_keys, env_state, actions, env_params)
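Because the environments are pure JAX, whole rollouts can be jit-compiled and scanned. Here is a minimal sketch continuing from the snippet above; the jax.lax.scan loop is our own illustration, not part of the pobax API:

# Roll out 100 random-policy steps across the 10 vectorized environments
@jax.jit
def rollout(rand_key, env_state, obs):
    def step_fn(carry, _):
        rand_key, env_state, obs = carry
        step_key, action_key, rand_key = jax.random.split(rand_key, 3)
        step_keys = jax.random.split(step_key, 10)
        action_keys = jax.random.split(action_key, 10)
        actions = jax.vmap(env.action_space(env_params).sample)(action_keys)
        obs, env_state, reward, done, info = env.step(step_keys, env_state, actions, env_params)
        return (rand_key, env_state, obs), reward

    _, rewards = jax.lax.scan(step_fn, (rand_key, env_state, obs), None, length=100)
    return rewards  # shape: (100, 10)

rewards = rollout(rand_key, env_state, obs)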

Installation

The latest pobax version can be installed via PyPI:

pip install pobax

To develop the pobax package, create a fork and clone your forked repo:

git clone git@github.com:{FORKED_USER}/pobax.git
cd pobax
pip install -e .

Installing Madrona_MJX (Optional)

POBAX's pixel-based continuous control environments (ant-pixels, halfcheetah-pixels, hopper-pixels, walker2d-pixels) require the Madrona_MJX renderer for GPU-accelerated rendering.

Installation:

git clone https://github.com/KevinGuo27/madrona_mjx.git
cd madrona_mjx
pip install -e .

Requirements:

  • CUDA 12.6.3 or compatible versions
  • GPU support

Note: Madrona_MJX currently does not support jax.vmap, so experiments must run with a single seed at a time. See scripts/hyperparams/visual_mujoco/ant/best/ant_ppo_madrona_best.py for an example configuration.

Compilation: The first time you run a Madrona_MJX environment, the renderer will compile (takes ~4 minutes on an RTX 3090). You'll see outputs like this:

Using raytracer
Compiling .../madrona/src/mw/device/bvh.cpp
Compiling .../madrona/src/mw/device/memory.cpp
Compiling .../madrona/src/mw/device/host_print.cpp
Compiling .../madrona/src/mw/device/bvh_raycast.cpp
Compiling GPU engine code:
Initialization finished

Here's an example of how to run a pixel-based Madrona_MJX ant environment:

python -m pobax.algos.ppo --env ant_pixels --action_concat --lambda0 0.7 --lambda1 0.95 --hidden_size 512 --total_steps 5000000 --n_seeds 1 --platform gpu --debug --study_name ant_ppo_madrona_best

Agents

POBAX includes agents loosely based on the PureJAXRL framework, all built around proximal policy optimization (PPO); the primary algorithm is recurrent PPO (PPO with an RNN).

A memoryless version of the recurrent PPO algorithm is also included via the --memoryless flag.

Here's an example of how to run a recurrent PPO agent on T-Maze:

python -m pobax.algos.ppo --env tmaze_5 --debug

Here's a small example of how to sweep hyperparameters with a recurrent PPO agent on RockSample(11, 11):

python -m pobax.algos.ppo --env rocksample_11_11 --num_envs 16 --entropy_coeff 0.2 --lambda0 0.7 --lr 0.0025 0.00025 --total_steps 5000000 --seed 2024 --platform cpu --n_seeds 5 --debug

This script runs an experiment over 5 seeds for 5M steps on CPU, with entropy coefficient 0.2, GAE lambda = 0.7, and 16 parallel environments per run, while sweeping the learning rate over 0.0025 and 0.00025. For more information on running experiments with POBAX, check out the EXPERIMENTS.md file.

Hyperparameters and their descriptions can be found in pobax/config.py. Any hyperparameter that has a list type can be swept.
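For illustration, a list-typed hyperparameter marks a sweep axis, while a scalar stays fixed. The sketch below is hypothetical; the real field names and defaults live in pobax/config.py:

from dataclasses import dataclass, field
from typing import List

# Hypothetical sketch of sweepable vs. fixed hyperparameters --
# see pobax/config.py for the actual definitions.
@dataclass
class PPOHyperparams:
    lr: List[float] = field(default_factory=lambda: [2.5e-3, 2.5e-4])  # list type: swept
    entropy_coeff: float = 0.2                                         # scalar: fixed per run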

Citation

@article{tao2025pobax,
  author = {Tao, Ruo Yu and Guo, Kaicheng and Allen, Cameron and Konidaris, George},
  title = {Benchmarking Partial Observability in Reinforcement Learning with a Suite of Memory-Improvable Domains},
  booktitle = {Proceedings of the Second Reinforcement Learning Conference},
  journal = {The Reinforcement Learning Journal},
  url = {http://github.com/taodav/pobax},
  year = {2025},
}
