85 changes: 85 additions & 0 deletions acme/agents/jax/iql/README.md
@@ -0,0 +1,85 @@
# Implicit Q-Learning (IQL)

This directory contains an implementation of Implicit Q-Learning (IQL), an
offline reinforcement learning algorithm.

## Overview

IQL is designed for learning from fixed datasets without online interaction.
Unlike many offline RL methods, IQL avoids querying the values of out-of-sample
actions, which helps prevent value overestimation and mitigates distributional shift.

## Key Features

- **Expectile Regression**: Uses expectile regression to learn a value function
that approximates an upper expectile of the Q-values over dataset actions,
implicitly estimating the value of the best in-distribution actions.

- **No Out-of-Sample Queries**: Never evaluates actions outside the dataset,
avoiding distributional shift problems.

- **Advantage-Weighted Regression**: Extracts the policy using
advantage-weighted behavioral cloning, which maximizes Q-values while staying
close to the data distribution.

## Algorithm Components

1. **Value Function (V)**: Trained with expectile regression to estimate state
values as an upper expectile of Q-values.

2. **Q-Function**: Trained with standard TD learning using the value function
for next state values.

3. **Policy**: Trained with advantage-weighted regression, i.e. behavioral
cloning in which each action's log-likelihood is weighted by its exponentiated
advantage (all three losses are sketched below).

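The three components above map onto three small losses plus a soft
target-network update. The following is a minimal, illustrative sketch rather
than the code in `learning.py`: the function names, the batched arrays (`q`,
`v`, `next_v`, `reward`, `discount_t`, `log_prob`), and the weight clip of 100
are assumptions made for exposition.

```python
import jax
import jax.numpy as jnp


def value_loss(q, v, expectile=0.7):
  """Expectile regression: V(s) tracks an upper expectile of Q(s, a)."""
  diff = q - v
  weight = jnp.where(diff > 0, expectile, 1.0 - expectile)
  return jnp.mean(weight * diff**2)


def critic_loss(q, reward, discount_t, next_v):
  """TD loss; bootstrapping through V(s') avoids out-of-sample actions."""
  target = reward + discount_t * next_v  # discount_t = discount * (1 - done)
  return jnp.mean((q - target) ** 2)


def policy_loss(log_prob, q, v, temperature=3.0):
  """Advantage-weighted regression: cloning weighted by exp(temperature * A)."""
  weights = jnp.minimum(jnp.exp(temperature * (q - v)), 100.0)  # clip for stability
  return -jnp.mean(jax.lax.stop_gradient(weights) * log_prob)


def polyak_update(online_params, target_params, tau=0.005):
  """Soft target-network update controlled by `tau`."""
  return jax.tree_util.tree_map(
      lambda o, t: tau * o + (1.0 - tau) * t, online_params, target_params)
```
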
## Usage

```python
from acme.agents.jax import iql
from acme import specs
import jax

# Create networks
environment_spec = specs.make_environment_spec(environment)
networks = iql.make_networks(environment_spec)

# Configure IQL
config = iql.IQLConfig(
expectile=0.7,  # Values > 0.5 push V toward the best in-distribution Q-values
temperature=3.0, # Higher values give more weight to high-advantage actions
batch_size=256,
)

# Create builder
builder = iql.IQLBuilder(config)

# Create learner
learner = builder.make_learner(
random_key=jax.random.PRNGKey(0),
networks=networks,
dataset=dataset_iterator,
logger_fn=logger_factory,
environment_spec=environment_spec,
)
```
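
Once the learner exists, offline training is simply repeated calls to
`learner.step()`, and an evaluation actor can be assembled from the same
builder. The snippet below is an illustrative sketch that reuses `builder`,
`networks`, `environment_spec`, and `learner` from above; `num_steps` is a
hypothetical training budget.

```python
# Run offline gradient steps over the fixed dataset.
num_steps = 10_000  # hypothetical budget; choose per dataset
for _ in range(num_steps):
  learner.step()

# Build a deterministic evaluation policy and an actor that pulls the
# latest policy parameters from the learner.
eval_policy = builder.make_policy(networks, environment_spec, evaluation=True)
eval_actor = builder.make_actor(
    random_key=jax.random.PRNGKey(1),
    policy=eval_policy,
    environment_spec=environment_spec,
    variable_source=learner)
```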

## Hyperparameters

- **expectile** (default: 0.7): Expectile used for the value-function loss.
Values > 0.5 give upper expectiles; higher values (e.g., 0.9) push V(s) closer
to the maximum Q-value over in-distribution actions.

- **temperature** (default: 3.0): Inverse temperature for advantage weighting;
the policy-extraction weights are exp(temperature * advantage), so higher
values give more weight to high-advantage actions.

- **tau** (default: 0.005): Target network update coefficient (Polyak
averaging).

- **discount** (default: 0.99): Discount factor for TD updates.

## References

Kostrikov, I., Nair, A., & Levine, S. (2021). Offline Reinforcement Learning
with Implicit Q-Learning. arXiv preprint arXiv:2110.06169.

https://arxiv.org/abs/2110.06169
29 changes: 29 additions & 0 deletions acme/agents/jax/iql/__init__.py
@@ -0,0 +1,29 @@
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implicit Q-Learning (IQL) agent implementation."""

from acme.agents.jax.iql.builder import IQLBuilder
from acme.agents.jax.iql.config import IQLConfig
from acme.agents.jax.iql.learning import IQLLearner
from acme.agents.jax.iql.networks import IQLNetworks
from acme.agents.jax.iql.networks import make_networks

__all__ = [
'IQLBuilder',
'IQLConfig',
'IQLLearner',
'IQLNetworks',
'make_networks',
]
110 changes: 110 additions & 0 deletions acme/agents/jax/iql/agent_test.py
@@ -0,0 +1,110 @@
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for IQL agent."""

from absl.testing import absltest
from acme import specs
from acme.agents.jax import iql
from acme.testing import fakes
from acme.utils import counting
from acme.utils import loggers
import jax
import numpy as np
import optax


class IQLTest(absltest.TestCase):
"""Basic tests for IQL agent components."""

def test_iql_networks_creation(self):
"""Test that IQL networks can be created."""
# Create a simple environment spec
env = fakes.ContinuousEnvironment(
episode_length=10,
action_dim=2,
observation_dim=4,
bounded=True)
env_spec = specs.make_environment_spec(env)

# Create networks
networks = iql.make_networks(env_spec)

# Check that all networks are created
self.assertIsNotNone(networks.policy_network)
self.assertIsNotNone(networks.q_network)
self.assertIsNotNone(networks.value_network)
self.assertIsNotNone(networks.log_prob)
self.assertIsNotNone(networks.sample)
self.assertIsNotNone(networks.sample_eval)

def test_iql_config(self):
"""Test IQL config creation with default values."""
config = iql.IQLConfig()

self.assertEqual(config.batch_size, 256)
self.assertEqual(config.expectile, 0.7)
self.assertEqual(config.temperature, 3.0)
self.assertEqual(config.discount, 0.99)

def test_iql_builder(self):
"""Test that IQL builder can be created."""
config = iql.IQLConfig(batch_size=64)
builder = iql.IQLBuilder(config)

self.assertIsNotNone(builder)

def test_iql_learner_creation(self):
"""Test that IQL learner can be created and run."""
# Create environment
env = fakes.ContinuousEnvironment(
episode_length=10,
action_dim=2,
observation_dim=4,
bounded=True)
env_spec = specs.make_environment_spec(env)

# Create networks
networks = iql.make_networks(env_spec)

# Create fake dataset
dataset = fakes.transition_iterator(env)(batch_size=32)

# Create learner
config = iql.IQLConfig(batch_size=32)
learner = iql.IQLLearner(
batch_size=config.batch_size,
networks=networks,
random_key=jax.random.PRNGKey(0),
demonstrations=dataset,
policy_optimizer=optax.adam(config.policy_learning_rate),
value_optimizer=optax.adam(config.value_learning_rate),
critic_optimizer=optax.adam(config.critic_learning_rate),
tau=config.tau,
expectile=config.expectile,
temperature=config.temperature,
discount=config.discount,
counter=counting.Counter(),
logger=loggers.NoOpLogger())

# Run a few training steps
for _ in range(5):
learner.step()

# Check that parameters can be retrieved
policy_params = learner.get_variables(['policy'])[0]
self.assertIsNotNone(policy_params)


if __name__ == '__main__':
absltest.main()
143 changes: 143 additions & 0 deletions acme/agents/jax/iql/builder.py
@@ -0,0 +1,143 @@
# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""IQL Builder."""
from typing import Iterator, Optional

from acme import core
from acme import specs
from acme import types
from acme.agents.jax import actor_core as actor_core_lib
from acme.agents.jax import actors
from acme.agents.jax import builders
from acme.agents.jax.iql import config as iql_config
from acme.agents.jax.iql import learning
from acme.agents.jax.iql import networks as iql_networks
from acme.jax import networks as networks_lib
from acme.jax import variable_utils
from acme.utils import counting
from acme.utils import loggers
import optax


class IQLBuilder(builders.OfflineBuilder[iql_networks.IQLNetworks,
actor_core_lib.FeedForwardPolicy,
types.Transition]):
"""IQL Builder.

Constructs the components needed for an Implicit Q-Learning agent,
including the learner, policy, and actor.
"""

def __init__(self, config: iql_config.IQLConfig):
"""Creates an IQL builder.

Args:
config: Configuration with IQL hyperparameters.
"""
self._config = config

def make_learner(
self,
random_key: networks_lib.PRNGKey,
networks: iql_networks.IQLNetworks,
dataset: Iterator[types.Transition],
logger_fn: loggers.LoggerFactory,
environment_spec: specs.EnvironmentSpec,
*,
counter: Optional[counting.Counter] = None,
) -> core.Learner:
"""Creates an IQL learner.

Args:
random_key: Random number generator key.
networks: IQL networks (policy, Q-function, value function).
dataset: Iterator over offline dataset.
logger_fn: Factory for creating loggers.
environment_spec: Environment specification.
counter: Counter for tracking training progress.

Returns:
IQL learner instance.
"""
del environment_spec

return learning.IQLLearner(
batch_size=self._config.batch_size,
networks=networks,
random_key=random_key,
demonstrations=dataset,
policy_optimizer=optax.adam(self._config.policy_learning_rate),
value_optimizer=optax.adam(self._config.value_learning_rate),
critic_optimizer=optax.adam(self._config.critic_learning_rate),
tau=self._config.tau,
expectile=self._config.expectile,
temperature=self._config.temperature,
discount=self._config.discount,
num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
logger=logger_fn('learner'),
counter=counter)

def make_actor(
self,
random_key: networks_lib.PRNGKey,
policy: actor_core_lib.FeedForwardPolicy,
environment_spec: specs.EnvironmentSpec,
variable_source: Optional[core.VariableSource] = None,
) -> core.Actor:
"""Creates an actor for policy evaluation.

Args:
random_key: Random number generator key.
policy: Policy function to execute.
environment_spec: Environment specification.
variable_source: Source for policy parameters.

Returns:
Actor instance.
"""
del environment_spec
assert variable_source is not None
actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
variable_client = variable_utils.VariableClient(
variable_source, 'policy', device='cpu')
return actors.GenericActor(
actor_core, random_key, variable_client, backend='cpu')

def make_policy(
self,
networks: iql_networks.IQLNetworks,
environment_spec: specs.EnvironmentSpec,
evaluation: bool) -> actor_core_lib.FeedForwardPolicy:
"""Constructs the policy function.

Args:
networks: IQL networks.
environment_spec: Environment specification.
evaluation: Whether this is for evaluation (deterministic) or training.

Returns:
Policy function that maps (params, key, observation) -> action.
"""
del environment_spec, evaluation
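# The `evaluation` flag is unused: the returned policy always acts
# deterministically via `sample_eval`.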

def policy(
params: networks_lib.Params,
key: networks_lib.PRNGKey,
observation: networks_lib.Observation) -> networks_lib.Action:
"""Evaluation policy (deterministic)."""
dist_params = networks.policy_network.apply(params, observation)
return networks.sample_eval(dist_params, key)

return policy