diff --git a/acme/agents/jax/iql/README.md b/acme/agents/jax/iql/README.md
new file mode 100644
index 0000000000..2964f4b1cb
--- /dev/null
+++ b/acme/agents/jax/iql/README.md
@@ -0,0 +1,85 @@
+# Implicit Q-Learning (IQL)
+
+This directory contains an implementation of Implicit Q-Learning (IQL), an
+offline reinforcement learning algorithm.
+
+## Overview
+
+IQL is designed for learning from fixed datasets without online interaction.
+Unlike many offline RL methods, IQL never queries the values of out-of-sample
+actions, which helps prevent overestimation and distributional-shift issues.
+
+## Key Features
+
+- **Expectile Regression**: Uses expectile regression to learn a value function
+  that approximates an upper expectile of the Q-values, implicitly estimating
+  the value of the best in-dataset actions.
+
+- **No Out-of-Sample Queries**: Never evaluates actions outside the dataset,
+  avoiding distributional-shift problems.
+
+- **Advantage-Weighted Regression**: Extracts the policy using advantage-
+  weighted behavioral cloning, which maximizes Q-values while staying close to
+  the data distribution.
+
+## Algorithm Components
+
+1. **Value Function (V)**: Trained with expectile regression to estimate state
+   values as an upper expectile of the Q-values.
+
+2. **Q-Function**: Trained with standard TD learning, using the value function
+   for next-state values.
+
+3. **Policy**: Trained with advantage-weighted regression, i.e. behavioral
+   cloning weighted by exponentiated advantages.
+
+## Usage
+
+```python
+from acme.agents.jax import iql
+from acme import specs
+import jax
+
+# Create networks.
+environment_spec = specs.make_environment_spec(environment)
+networks = iql.make_networks(environment_spec)
+
+# Configure IQL.
+config = iql.IQLConfig(
+    expectile=0.7,  # Higher values track the maximum Q-value more closely.
+    temperature=3.0,  # Higher values favor high-advantage actions more.
+    batch_size=256,
+)
+
+# Create builder.
+builder = iql.IQLBuilder(config)
+
+# Create learner.
+learner = builder.make_learner(
+    random_key=jax.random.PRNGKey(0),
+    networks=networks,
+    dataset=dataset_iterator,
+    logger_fn=logger_factory,
+    environment_spec=environment_spec,
+)
+```
+
+## Hyperparameters
+
+- **expectile** (default: 0.7): Expectile used for the value function. Values
+  > 0.5 give upper expectiles; higher values (e.g., 0.9) approximate the
+  maximum Q-value over dataset actions more closely.
+
+- **temperature** (default: 3.0): Inverse temperature for advantage weighting.
+  Higher values give more weight to high-advantage actions.
+
+- **tau** (default: 0.005): Target network update coefficient (Polyak
+  averaging).
+
+- **discount** (default: 0.99): Discount factor for TD updates.
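+
+To make the roles of `expectile` and `temperature` concrete, the snippet below
+computes the three IQL losses on toy arrays. It is an illustrative sketch only:
+the names defined here are local to the example and are not part of this
+package's API.
+
+```python
+import jax.numpy as jnp
+
+expectile = 0.7    # tau in the paper.
+temperature = 3.0  # beta in the paper.
+discount = 0.99
+
+
+def expectile_loss(diff, expectile):
+  # Asymmetric squared error: negative residuals (V above Q) are penalised
+  # less than positive ones when expectile > 0.5.
+  weight = jnp.where(diff > 0, expectile, 1 - expectile)
+  return weight * diff**2
+
+
+# 1) Value function: expectile regression of V(s) towards Q(s, a).
+q_sa = jnp.array([1.0, 2.0, 3.0])
+v_s = jnp.array([1.5, 1.5, 1.5])
+value_loss = expectile_loss(q_sa - v_s, expectile).mean()
+
+# 2) Q-function: TD regression towards r + discount * V(s').
+reward = jnp.array([0.1, 0.0, 0.5])
+v_next = jnp.array([1.0, 2.0, 0.0])
+critic_loss = ((q_sa - (reward + discount * v_next))**2).mean()
+
+# 3) Policy: advantage-weighted log-likelihood of dataset actions.
+log_prob = jnp.array([-1.0, -0.5, -2.0])  # log pi(a|s) for dataset actions.
+weights = jnp.minimum(jnp.exp(temperature * (q_sa - v_s)), 100.0)
+policy_loss = -(weights * log_prob).mean()
+
+print(value_loss, critic_loss, policy_loss)
+```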
+
+## References
+
+Kostrikov, I., Nair, A., & Levine, S. (2021). Offline Reinforcement Learning
+with Implicit Q-Learning. arXiv preprint arXiv:2110.06169.
+
+https://arxiv.org/abs/2110.06169
diff --git a/acme/agents/jax/iql/__init__.py b/acme/agents/jax/iql/__init__.py
new file mode 100644
index 0000000000..c391e02e91
--- /dev/null
+++ b/acme/agents/jax/iql/__init__.py
@@ -0,0 +1,29 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Implicit Q-Learning (IQL) agent implementation."""
+
+from acme.agents.jax.iql.builder import IQLBuilder
+from acme.agents.jax.iql.config import IQLConfig
+from acme.agents.jax.iql.learning import IQLLearner
+from acme.agents.jax.iql.networks import IQLNetworks
+from acme.agents.jax.iql.networks import make_networks
+
+__all__ = [
+    'IQLBuilder',
+    'IQLConfig',
+    'IQLLearner',
+    'IQLNetworks',
+    'make_networks',
+]
diff --git a/acme/agents/jax/iql/agent_test.py b/acme/agents/jax/iql/agent_test.py
new file mode 100644
index 0000000000..a0adf1dc22
--- /dev/null
+++ b/acme/agents/jax/iql/agent_test.py
@@ -0,0 +1,110 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for IQL agent."""
+
+from absl.testing import absltest
+from acme import specs
+from acme.agents.jax import iql
+from acme.testing import fakes
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import optax
+
+
+class IQLTest(absltest.TestCase):
+  """Basic tests for IQL agent components."""
+
+  def test_iql_networks_creation(self):
+    """Tests that IQL networks can be created."""
+    # Create a simple environment spec.
+    env = fakes.ContinuousEnvironment(
+        episode_length=10,
+        action_dim=2,
+        observation_dim=4,
+        bounded=True)
+    env_spec = specs.make_environment_spec(env)
+
+    # Create networks.
+    networks = iql.make_networks(env_spec)
+
+    # Check that all networks are created.
+    self.assertIsNotNone(networks.policy_network)
+    self.assertIsNotNone(networks.q_network)
+    self.assertIsNotNone(networks.value_network)
+    self.assertIsNotNone(networks.log_prob)
+    self.assertIsNotNone(networks.sample)
+    self.assertIsNotNone(networks.sample_eval)
+
+  def test_iql_config(self):
+    """Tests IQL config creation with default values."""
+    config = iql.IQLConfig()
+
+    self.assertEqual(config.batch_size, 256)
+    self.assertEqual(config.expectile, 0.7)
+    self.assertEqual(config.temperature, 3.0)
+    self.assertEqual(config.discount, 0.99)
+
+  def test_iql_builder(self):
+    """Tests that the IQL builder can be created."""
+    config = iql.IQLConfig(batch_size=64)
+    builder = iql.IQLBuilder(config)
+
+    self.assertIsNotNone(builder)
+
+  def test_iql_learner_creation(self):
+    """Tests that the IQL learner can be created and run."""
+    # Create environment.
+    env = fakes.ContinuousEnvironment(
+        episode_length=10,
+        action_dim=2,
+        observation_dim=4,
+        bounded=True)
+    env_spec = specs.make_environment_spec(env)
+
+    # Create networks.
+    networks = iql.make_networks(env_spec)
+
+    # Create fake dataset.
+    dataset = fakes.transition_iterator(env)(batch_size=32)
+
+    # Create learner.
+    config = iql.IQLConfig(batch_size=32)
+    learner = iql.IQLLearner(
+        batch_size=config.batch_size,
+        networks=networks,
+        random_key=jax.random.PRNGKey(0),
+        demonstrations=dataset,
+        policy_optimizer=optax.adam(config.policy_learning_rate),
+        value_optimizer=optax.adam(config.value_learning_rate),
+        critic_optimizer=optax.adam(config.critic_learning_rate),
+        tau=config.tau,
+        expectile=config.expectile,
+        temperature=config.temperature,
+        discount=config.discount,
+        counter=counting.Counter(),
+        logger=loggers.NoOpLogger())
+
+    # Run a few training steps.
+    for _ in range(5):
+      learner.step()
+
+    # Check that parameters can be retrieved.
+    policy_params = learner.get_variables(['policy'])[0]
+    self.assertIsNotNone(policy_params)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/acme/agents/jax/iql/builder.py b/acme/agents/jax/iql/builder.py
new file mode 100644
index 0000000000..e16509af57
--- /dev/null
+++ b/acme/agents/jax/iql/builder.py
@@ -0,0 +1,143 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IQL Builder."""
+from typing import Iterator, Optional
+
+from acme import core
+from acme import specs
+from acme import types
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.agents.jax import actors
+from acme.agents.jax import builders
+from acme.agents.jax.iql import config as iql_config
+from acme.agents.jax.iql import learning
+from acme.agents.jax.iql import networks as iql_networks
+from acme.jax import networks as networks_lib
+from acme.jax import variable_utils
+from acme.utils import counting
+from acme.utils import loggers
+import optax
+
+
+class IQLBuilder(builders.OfflineBuilder[iql_networks.IQLNetworks,
+                                         actor_core_lib.FeedForwardPolicy,
+                                         types.Transition]):
+  """IQL Builder.
+
+  Constructs the components needed for an Implicit Q-Learning agent,
+  including the learner, policy, and actor.
+  """
+
+  def __init__(self, config: iql_config.IQLConfig):
+    """Creates an IQL builder.
+
+    Args:
+      config: Configuration with IQL hyperparameters.
+    """
+    self._config = config
+
+  def make_learner(
+      self,
+      random_key: networks_lib.PRNGKey,
+      networks: iql_networks.IQLNetworks,
+      dataset: Iterator[types.Transition],
+      logger_fn: loggers.LoggerFactory,
+      environment_spec: specs.EnvironmentSpec,
+      *,
+      counter: Optional[counting.Counter] = None,
+  ) -> core.Learner:
+    """Creates an IQL learner.
+
+    Args:
+      random_key: Random number generator key.
+      networks: IQL networks (policy, Q-function, value function).
+      dataset: Iterator over the offline dataset.
+      logger_fn: Factory for creating loggers.
+      environment_spec: Environment specification.
+      counter: Counter for tracking training progress.
+
+    Returns:
+      IQL learner instance.
+    """
+    del environment_spec
+
+    return learning.IQLLearner(
+        batch_size=self._config.batch_size,
+        networks=networks,
+        random_key=random_key,
+        demonstrations=dataset,
+        policy_optimizer=optax.adam(self._config.policy_learning_rate),
+        value_optimizer=optax.adam(self._config.value_learning_rate),
+        critic_optimizer=optax.adam(self._config.critic_learning_rate),
+        tau=self._config.tau,
+        expectile=self._config.expectile,
+        temperature=self._config.temperature,
+        discount=self._config.discount,
+        num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
+        logger=logger_fn('learner'),
+        counter=counter)
+
+  def make_actor(
+      self,
+      random_key: networks_lib.PRNGKey,
+      policy: actor_core_lib.FeedForwardPolicy,
+      environment_spec: specs.EnvironmentSpec,
+      variable_source: Optional[core.VariableSource] = None,
+  ) -> core.Actor:
+    """Creates an actor for policy evaluation.
+
+    Args:
+      random_key: Random number generator key.
+      policy: Policy function to execute.
+      environment_spec: Environment specification.
+      variable_source: Source for policy parameters.
+
+    Returns:
+      Actor instance.
+    """
+    del environment_spec
+    assert variable_source is not None
+    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(policy)
+    variable_client = variable_utils.VariableClient(
+        variable_source, 'policy', device='cpu')
+    return actors.GenericActor(
+        actor_core, random_key, variable_client, backend='cpu')
+
+  def make_policy(
+      self,
+      networks: iql_networks.IQLNetworks,
+      environment_spec: specs.EnvironmentSpec,
+      evaluation: bool) -> actor_core_lib.FeedForwardPolicy:
+    """Constructs the policy function.
+
+    Args:
+      networks: IQL networks.
+      environment_spec: Environment specification.
+      evaluation: Whether this is for evaluation (deterministic) or training.
+
+    Returns:
+      Policy function that maps (params, key, observation) -> action.
+    """
+    del environment_spec
+
+    def policy(
+        params: networks_lib.Params,
+        key: networks_lib.PRNGKey,
+        observation: networks_lib.Observation) -> networks_lib.Action:
+      """Samples an action; deterministic when `evaluation` is True."""
+      dist_params = networks.policy_network.apply(params, observation)
+      if evaluation:
+        return networks.sample_eval(dist_params, key)
+      return networks.sample(dist_params, key)
+
+    return policy
diff --git a/acme/agents/jax/iql/config.py b/acme/agents/jax/iql/config.py
new file mode 100644
index 0000000000..3b6edd7772
--- /dev/null
+++ b/acme/agents/jax/iql/config.py
@@ -0,0 +1,47 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Configuration options for Implicit Q-Learning (IQL)."""
+import dataclasses
+
+
+@dataclasses.dataclass
+class IQLConfig:
+  """Configuration options for IQL.
+
+  Attributes:
+    batch_size: Batch size for training.
+    value_learning_rate: Learning rate for the value function optimizer.
+    critic_learning_rate: Learning rate for the Q-function optimizer.
+    policy_learning_rate: Learning rate for the policy optimizer.
+    tau: Target network update coefficient (Polyak averaging).
+    expectile: Expectile parameter for the value function (tau in the paper).
+      Higher values (e.g., 0.9) track the maximum Q-value over dataset
+      actions more closely.
+    temperature: Inverse temperature (beta) for advantage-weighted regression.
+      Higher values give more weight to high-advantage actions.
+    discount: Discount factor for TD updates.
+    num_sgd_steps_per_step: Number of gradient updates per learner step.
+    num_bc_iters: Number of behavioral cloning iterations for policy warmup.
+  """
+  batch_size: int = 256
+  value_learning_rate: float = 3e-4
+  critic_learning_rate: float = 3e-4
+  policy_learning_rate: float = 3e-4
+  tau: float = 0.005
+  expectile: float = 0.7
+  temperature: float = 3.0
+  discount: float = 0.99
+  num_sgd_steps_per_step: int = 1
+  num_bc_iters: int = 0
diff --git a/acme/agents/jax/iql/learning.py b/acme/agents/jax/iql/learning.py
new file mode 100644
index 0000000000..b5c0163f2d
--- /dev/null
+++ b/acme/agents/jax/iql/learning.py
@@ -0,0 +1,390 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""IQL learner implementation."""
+
+from typing import Iterator, NamedTuple, Optional
+
+import acme
+from acme import types
+from acme.agents.jax.iql import networks as iql_networks
+from acme.jax import networks as networks_lib
+from acme.utils import counting
+from acme.utils import loggers
+import jax
+import jax.numpy as jnp
+import optax
+
+
+class TrainingState(NamedTuple):
+  """Contains training state for the IQL learner.
+
+  Attributes:
+    policy_optimizer_state: Optimizer state for the policy network.
+    value_optimizer_state: Optimizer state for the value function network.
+    critic_optimizer_state: Optimizer state for the Q-function network.
+    policy_params: Parameters of the policy network.
+    value_params: Parameters of the value function network.
+    critic_params: Parameters of the Q-function network.
+    target_critic_params: Target network parameters for the Q-function.
+    key: Random number generator key.
+    steps: Number of training steps completed.
+  """
+  policy_optimizer_state: optax.OptState
+  value_optimizer_state: optax.OptState
+  critic_optimizer_state: optax.OptState
+  policy_params: networks_lib.Params
+  value_params: networks_lib.Params
+  critic_params: networks_lib.Params
+  target_critic_params: networks_lib.Params
+  key: networks_lib.PRNGKey
+  steps: int = 0
+
+
+class IQLLearner(acme.Learner):
+  """IQL learner.
+
+  Learning component of the Implicit Q-Learning algorithm from
+  Kostrikov et al., 2021: https://arxiv.org/abs/2110.06169
+
+  IQL is an offline RL algorithm that avoids querying values of out-of-sample
+  actions by using expectile regression for the value function and advantage-
+  weighted behavioral cloning for policy extraction.
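+
+  Each call to `step` updates, in order: the value function (expectile
+  regression towards Q), the Q-function (TD regression towards
+  r + discount * V(s')), the target Q-network (Polyak averaging), and the
+  policy (advantage-weighted regression).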
+  """
+
+  _state: TrainingState
+
+  def __init__(
+      self,
+      batch_size: int,
+      networks: iql_networks.IQLNetworks,
+      random_key: networks_lib.PRNGKey,
+      demonstrations: Iterator[types.Transition],
+      policy_optimizer: optax.GradientTransformation,
+      value_optimizer: optax.GradientTransformation,
+      critic_optimizer: optax.GradientTransformation,
+      tau: float = 0.005,
+      expectile: float = 0.7,
+      temperature: float = 3.0,
+      discount: float = 0.99,
+      num_sgd_steps_per_step: int = 1,
+      counter: Optional[counting.Counter] = None,
+      logger: Optional[loggers.Logger] = None):
+    """Initializes the IQL learner.
+
+    Args:
+      batch_size: Batch size for training.
+      networks: IQL networks (policy, Q-function, value function).
+      random_key: Random number generator key.
+      demonstrations: Iterator over offline training data.
+      policy_optimizer: Optimizer for the policy network.
+      value_optimizer: Optimizer for the value function network.
+      critic_optimizer: Optimizer for the Q-function network.
+      tau: Target network update coefficient (Polyak averaging).
+      expectile: Expectile parameter for the value function (0.5 = mean,
+        >0.5 = upper expectile).
+      temperature: Inverse temperature for advantage-weighted regression.
+      discount: Discount factor for TD updates.
+      num_sgd_steps_per_step: Number of gradient updates per step.
+      counter: Counter for tracking training progress.
+      logger: Logger for metrics.
+    """
+    self._batch_size = batch_size
+    self._networks = networks
+    self._demonstrations = demonstrations
+    self._tau = tau
+    self._expectile = expectile
+    self._temperature = temperature
+    self._discount = discount
+    self._num_sgd_steps_per_step = num_sgd_steps_per_step
+    self._counter = counter or counting.Counter()
+    self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)
+
+    # Initialize network parameters. The FeedForwardNetwork wrappers built in
+    # networks.py close over dummy inputs, so init only takes a key.
+    key_policy, key_value, key_critic, key = jax.random.split(random_key, 4)
+
+    policy_params = networks.policy_network.init(key_policy)
+    value_params = networks.value_network.init(key_value)
+    critic_params = networks.q_network.init(key_critic)
+    target_critic_params = critic_params
+
+    # Initialize optimizers.
+    policy_optimizer_state = policy_optimizer.init(policy_params)
+    value_optimizer_state = value_optimizer.init(value_params)
+    critic_optimizer_state = critic_optimizer.init(critic_params)
+
+    # Store state.
+    self._state = TrainingState(
+        policy_optimizer_state=policy_optimizer_state,
+        value_optimizer_state=value_optimizer_state,
+        critic_optimizer_state=critic_optimizer_state,
+        policy_params=policy_params,
+        value_params=value_params,
+        critic_params=critic_params,
+        target_critic_params=target_critic_params,
+        key=key,
+        steps=0)
+
+    # Store optimizers.
+    self._policy_optimizer = policy_optimizer
+    self._value_optimizer = value_optimizer
+    self._critic_optimizer = critic_optimizer
+
+    # Define update functions.
+    def expectile_loss(diff: jnp.ndarray, expectile: float) -> jnp.ndarray:
+      """Asymmetric squared loss for expectile regression.
+
+      Args:
+        diff: Difference between target and prediction.
+        expectile: Expectile parameter (0.5 = MSE, >0.5 = upper expectile).
+
+      Returns:
+        Expectile loss value.
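+
+      Example (illustrative): with expectile=0.7, a positive residual of 1.0
+      is weighted by 0.7 while a negative residual of -1.0 is weighted by
+      0.3, so the fitted value settles above the mean of the targets.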
+      """
+      weight = jnp.where(diff > 0, expectile, (1 - expectile))
+      return weight * (diff ** 2)
+
+    def value_loss_fn(
+        value_params: networks_lib.Params,
+        critic_params: networks_lib.Params,
+        transitions: types.Transition) -> jnp.ndarray:
+      """Computes the value function loss using expectile regression.
+
+      The value function is trained to approximate an upper expectile of the
+      Q-values, which implicitly estimates the value of the best actions.
+
+      Args:
+        value_params: Value function parameters.
+        critic_params: Q-function parameters (frozen during the value update).
+        transitions: Batch of transitions.
+
+      Returns:
+        Scalar loss value.
+      """
+      # Compute current Q-values.
+      q_values = networks.q_network.apply(
+          critic_params, transitions.observation, transitions.action)
+      q_values = jnp.min(q_values, axis=-1)  # Take minimum over the ensemble.
+
+      # Compute value predictions.
+      v_pred = networks.value_network.apply(
+          value_params, transitions.observation)
+      v_pred = jnp.squeeze(v_pred, axis=-1)
+
+      # Expectile regression loss.
+      diff = q_values - v_pred
+      loss = expectile_loss(diff, self._expectile).mean()
+
+      return loss
+
+    def critic_loss_fn(
+        critic_params: networks_lib.Params,
+        value_params: networks_lib.Params,
+        target_critic_params: networks_lib.Params,
+        transitions: types.Transition) -> jnp.ndarray:
+      """Computes the Q-function loss using TD learning.
+
+      The Q-function is trained with standard temporal difference learning,
+      but uses the value function (instead of max Q) for the next-state value.
+
+      Args:
+        critic_params: Q-function parameters.
+        value_params: Value function parameters (frozen during the Q update).
+        target_critic_params: Target Q-function parameters.
+        transitions: Batch of transitions.
+
+      Returns:
+        Scalar loss value.
+      """
+      # Compute next-state values.
+      next_v = networks.value_network.apply(
+          value_params, transitions.next_observation)
+      next_v = jnp.squeeze(next_v, axis=-1)
+
+      # Compute TD targets.
+      target_q = (
+          transitions.reward + self._discount * transitions.discount * next_v)
+
+      # Compute current Q predictions.
+      q_pred = networks.q_network.apply(
+          critic_params, transitions.observation, transitions.action)
+
+      # MSE loss.
+      loss = ((q_pred - jnp.expand_dims(target_q, -1)) ** 2).mean()
+
+      return loss
+
+    def policy_loss_fn(
+        policy_params: networks_lib.Params,
+        value_params: networks_lib.Params,
+        critic_params: networks_lib.Params,
+        transitions: types.Transition,
+        key: networks_lib.PRNGKey) -> jnp.ndarray:
+      """Computes the policy loss using advantage-weighted regression.
+
+      The policy is trained to maximize Q-values weighted by advantage,
+      which is equivalent to behavioral cloning with advantage weights.
+
+      Args:
+        policy_params: Policy parameters.
+        value_params: Value function parameters (frozen).
+        critic_params: Q-function parameters (frozen).
+        transitions: Batch of transitions.
+        key: Random key (unused but kept for compatibility).
+
+      Returns:
+        Scalar loss value.
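+
+      Example (illustrative): with temperature=3.0, an advantage of 1.0 gives
+      a weight of exp(3.0), roughly 20.1, while an advantage of 0 gives a
+      weight of 1.0; weights are clipped at 100 below for numerical stability.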
+      """
+      # Compute advantages.
+      v_values = networks.value_network.apply(
+          value_params, transitions.observation)
+      v_values = jnp.squeeze(v_values, axis=-1)
+
+      q_values = networks.q_network.apply(
+          critic_params, transitions.observation, transitions.action)
+      q_values = jnp.min(q_values, axis=-1)
+
+      advantages = q_values - v_values
+
+      # Compute log probabilities of the dataset actions.
+      dist_params = networks.policy_network.apply(
+          policy_params, transitions.observation)
+      log_probs = networks.log_prob(dist_params, transitions.action)
+
+      # Advantage-weighted regression.
+      weights = jnp.exp(advantages * self._temperature)
+      weights = jnp.minimum(weights, 100.0)  # Clip for numerical stability.
+
+      loss = -(weights * log_probs).mean()
+
+      return loss
+
+    # JIT-compile the loss functions (used for periodic logging).
+    self._value_loss_fn = jax.jit(value_loss_fn)
+    self._critic_loss_fn = jax.jit(critic_loss_fn)
+    self._policy_loss_fn = jax.jit(policy_loss_fn)
+
+    def value_update_step(
+        state: TrainingState,
+        transitions: types.Transition) -> TrainingState:
+      """Performs one gradient step for the value function."""
+      _, value_grads = jax.value_and_grad(value_loss_fn)(
+          state.value_params, state.critic_params, transitions)
+
+      value_updates, value_optimizer_state = self._value_optimizer.update(
+          value_grads, state.value_optimizer_state)
+      value_params = optax.apply_updates(state.value_params, value_updates)
+
+      return state._replace(
+          value_params=value_params,
+          value_optimizer_state=value_optimizer_state)
+
+    def critic_update_step(
+        state: TrainingState,
+        transitions: types.Transition) -> TrainingState:
+      """Performs one gradient step for the Q-function."""
+      _, critic_grads = jax.value_and_grad(critic_loss_fn)(
+          state.critic_params, state.value_params,
+          state.target_critic_params, transitions)
+
+      critic_updates, critic_optimizer_state = self._critic_optimizer.update(
+          critic_grads, state.critic_optimizer_state)
+      critic_params = optax.apply_updates(state.critic_params, critic_updates)
+
+      # Update the target network with Polyak averaging.
+      target_critic_params = jax.tree_util.tree_map(
+          lambda x, y: x * (1 - self._tau) + y * self._tau,
+          state.target_critic_params, critic_params)
+
+      return state._replace(
+          critic_params=critic_params,
+          critic_optimizer_state=critic_optimizer_state,
+          target_critic_params=target_critic_params)
+
+    def policy_update_step(
+        state: TrainingState,
+        transitions: types.Transition) -> TrainingState:
+      """Performs one gradient step for the policy."""
+      _, policy_grads = jax.value_and_grad(policy_loss_fn)(
+          state.policy_params, state.value_params,
+          state.critic_params, transitions, state.key)
+
+      policy_updates, policy_optimizer_state = self._policy_optimizer.update(
+          policy_grads, state.policy_optimizer_state)
+      policy_params = optax.apply_updates(state.policy_params, policy_updates)
+
+      return state._replace(
+          policy_params=policy_params,
+          policy_optimizer_state=policy_optimizer_state)
+
+    self._value_update_step = jax.jit(value_update_step)
+    self._critic_update_step = jax.jit(critic_update_step)
+    self._policy_update_step = jax.jit(policy_update_step)
+
+  def step(self):
+    """Performs a single learner step (multiple gradient updates)."""
+    for _ in range(self._num_sgd_steps_per_step):
+      # Sample a batch.
+      transitions = next(self._demonstrations)
+
+      # Update the value function.
+      self._state = self._value_update_step(self._state, transitions)
+
+      # Update the Q-function.
+      self._state = self._critic_update_step(self._state, transitions)
+
+      # Update the policy.
+      self._state = self._policy_update_step(self._state, transitions)
+
+      # Increment the step counter.
+      self._state = self._state._replace(steps=self._state.steps + 1)
+
+    # Update counters and log.
+    counts = self._counter.increment(steps=1)
+
+    # Periodically log metrics.
+    if self._state.steps % 100 == 0:
+      # Compute losses on a fresh batch for logging.
+      transitions = next(self._demonstrations)
+      value_loss = self._value_loss_fn(
+          self._state.value_params, self._state.critic_params, transitions)
+      critic_loss = self._critic_loss_fn(
+          self._state.critic_params, self._state.value_params,
+          self._state.target_critic_params, transitions)
+      policy_loss = self._policy_loss_fn(
+          self._state.policy_params, self._state.value_params,
+          self._state.critic_params, transitions, self._state.key)
+
+      self._logger.write({
+          'value_loss': float(value_loss),
+          'critic_loss': float(critic_loss),
+          'policy_loss': float(policy_loss),
+          **counts
+      })
+
+  def get_variables(self, names: list[str]) -> list[networks_lib.Params]:
+    """Returns network parameters."""
+    variables = {
+        'policy': self._state.policy_params,
+        'critic': self._state.critic_params,
+        'value': self._state.value_params,
+    }
+    return [variables[name] for name in names]
+
+  def save(self) -> TrainingState:
+    """Returns the current training state for checkpointing."""
+    return self._state
+
+  def restore(self, state: TrainingState):
+    """Restores the training state from a checkpoint."""
+    self._state = state
diff --git a/acme/agents/jax/iql/networks.py b/acme/agents/jax/iql/networks.py
new file mode 100644
index 0000000000..7864226e31
--- /dev/null
+++ b/acme/agents/jax/iql/networks.py
@@ -0,0 +1,81 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Network definitions for the IQL agent."""
+import dataclasses
+from typing import Optional
+
+from acme import specs
+from acme.agents.jax import sac
+from acme.jax import networks as networks_lib
+from acme.jax import utils
+import haiku as hk
+
+
+@dataclasses.dataclass
+class IQLNetworks:
+  """Networks and pure functions for the IQL agent.
+
+  Attributes:
+    policy_network: Policy network that outputs action distribution parameters.
+    q_network: Q-function network that estimates state-action values.
+    value_network: Value function network that estimates state values.
+    log_prob: Function to compute the log probability of actions.
+    sample: Function to sample actions from the policy.
+    sample_eval: Function to sample actions for evaluation (typically
+      deterministic).
+    environment_specs: Environment specifications.
+  """
+  policy_network: networks_lib.FeedForwardNetwork
+  q_network: networks_lib.FeedForwardNetwork
+  value_network: networks_lib.FeedForwardNetwork
+  log_prob: networks_lib.LogProbFn
+  sample: Optional[networks_lib.SampleFn]
+  sample_eval: Optional[networks_lib.SampleFn]
+  environment_specs: specs.EnvironmentSpec
+
+
+def make_networks(
+    spec: specs.EnvironmentSpec,
+    hidden_layer_sizes: tuple[int, ...] = (256, 256),
+    **kwargs) -> IQLNetworks:
+  """Creates networks for the IQL agent.
+
+  Args:
+    spec: Environment specification.
+    hidden_layer_sizes: Sizes of hidden layers for all networks.
+    **kwargs: Additional arguments passed to SAC network creation.
+
+  Returns:
+    IQLNetworks containing policy, Q-function, and value function networks.
+  """
+  # Reuse the SAC networks for the policy and the Q-function.
+  sac_networks = sac.make_networks(
+      spec,
+      hidden_layer_sizes=hidden_layer_sizes,
+      **kwargs)
+
+  # Create the value network (state -> scalar), wrapped so that init only
+  # needs a random key, matching the SAC-style networks above.
+  def _value_fn(obs):
+    network = networks_lib.LayerNormMLP(
+        layer_sizes=list(hidden_layer_sizes) + [1],
+        activate_final=False)
+    return network(obs)
+
+  value = hk.without_apply_rng(hk.transform(_value_fn))
+  dummy_obs = utils.add_batch_dim(utils.zeros_like(spec.observations))
+  value_network = networks_lib.FeedForwardNetwork(
+      lambda key: value.init(key, dummy_obs), value.apply)
+
+  return IQLNetworks(
+      policy_network=sac_networks.policy_network,
+      q_network=sac_networks.q_network,
+      value_network=value_network,
+      log_prob=sac_networks.log_prob,
+      sample=sac_networks.sample,
+      sample_eval=sac_networks.sample_eval,
+      environment_specs=spec)
diff --git a/examples/offline/run_iql_jax.py b/examples/offline/run_iql_jax.py
new file mode 100644
index 0000000000..43a34c62a4
--- /dev/null
+++ b/examples/offline/run_iql_jax.py
@@ -0,0 +1,129 @@
+# Copyright 2018 DeepMind Technologies Limited. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Example IQL agent running on D4RL locomotion datasets."""
+
+from absl import app
+from absl import flags
+import acme
+from acme import specs
+from acme.agents.jax import actor_core as actor_core_lib
+from acme.agents.jax import actors
+from acme.agents.jax import iql
+from acme.datasets import tfds
+from acme.examples.offline import helpers as gym_helpers
+from acme.jax import variable_utils
+from acme.utils import loggers
+import haiku as hk
+import jax
+import optax
+
+# Agent flags.
+flags.DEFINE_integer('batch_size', 256, 'Batch size.')
+flags.DEFINE_integer('evaluate_every', 20, 'Evaluation period.')
+flags.DEFINE_integer('evaluation_episodes', 10, 'Evaluation episodes.')
+flags.DEFINE_integer(
+    'num_demonstrations', None,
+    'Number of demonstration episodes to load. If None, loads the full '
+    'dataset.')
+flags.DEFINE_integer('seed', 0, 'Random seed for learner and evaluator.')
+
+# IQL-specific flags.
+flags.DEFINE_float('policy_learning_rate', 3e-4, 'Policy learning rate.')
+flags.DEFINE_float('value_learning_rate', 3e-4,
+                   'Value function learning rate.')
+flags.DEFINE_float('critic_learning_rate', 3e-4, 'Q-function learning rate.')
+flags.DEFINE_float(
+    'expectile', 0.7,
+    'Expectile for the value function. Higher values track the maximum '
+    'Q-value over dataset actions more closely.')
+flags.DEFINE_float(
+    'temperature', 3.0,
+    'Inverse temperature for advantage weighting. Higher values give more '
+    'weight to high-advantage actions.')
+flags.DEFINE_float('tau', 0.005, 'Target network update coefficient.')
+flags.DEFINE_float('discount', 0.99, 'Discount factor.')
+
+# Environment flags.
+flags.DEFINE_string('env_name', 'HalfCheetah-v2',
+                    'Gym mujoco environment name.')
+flags.DEFINE_string(
+    'dataset_name', 'd4rl_mujoco_halfcheetah/v2-medium',
+    'D4RL dataset name. Can be any locomotion dataset from '
+    'https://www.tensorflow.org/datasets/catalog/overview#d4rl.')
+
+FLAGS = flags.FLAGS
+
+
+def main(_):
+  key = jax.random.PRNGKey(FLAGS.seed)
+  key_demonstrations, key_learner = jax.random.split(key, 2)
+
+  # Create the environment and grab its spec.
+  environment = gym_helpers.make_environment(task=FLAGS.env_name)
+  environment_spec = specs.make_environment_spec(environment)
+
+  # Load the demonstrations dataset.
+  transitions_iterator = tfds.get_tfds_dataset(
+      FLAGS.dataset_name, FLAGS.num_demonstrations)
+  demonstrations = tfds.JaxInMemoryRandomSampleIterator(
+      transitions_iterator,
+      key=key_demonstrations,
+      batch_size=FLAGS.batch_size)
+
+  # Create the networks.
+  networks = iql.make_networks(environment_spec)
+
+  # Create the learner.
+  learner = iql.IQLLearner(
+      batch_size=FLAGS.batch_size,
+      networks=networks,
+      random_key=key_learner,
+      demonstrations=demonstrations,
+      policy_optimizer=optax.adam(FLAGS.policy_learning_rate),
+      value_optimizer=optax.adam(FLAGS.value_learning_rate),
+      critic_optimizer=optax.adam(FLAGS.critic_learning_rate),
+      tau=FLAGS.tau,
+      expectile=FLAGS.expectile,
+      temperature=FLAGS.temperature,
+      discount=FLAGS.discount,
+      num_sgd_steps_per_step=1)
+
+  def evaluator_network(
+      params: hk.Params,
+      key: jax.Array,
+      observation: jax.Array) -> jax.Array:
+    """Evaluation policy (deterministic)."""
+    dist_params = networks.policy_network.apply(params, observation)
+    return networks.sample_eval(dist_params, key)
+
+  # Create the evaluator actor.
+  actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
+      evaluator_network)
+  variable_client = variable_utils.VariableClient(
+      learner, 'policy', device='cpu')
+  evaluator = actors.GenericActor(
+      actor_core, key, variable_client, backend='cpu')
+
+  # Create the evaluation loop.
+  eval_loop = acme.EnvironmentLoop(
+      environment=environment,
+      actor=evaluator,
+      logger=loggers.TerminalLogger('evaluation', time_delta=0.))
+
+  # Training loop: alternate learner steps with periodic evaluation.
+  print(f'Training IQL on {FLAGS.dataset_name}...')
+  print(f'Hyperparameters: expectile={FLAGS.expectile}, '
+        f'temperature={FLAGS.temperature}')
+
+  while True:
+    for _ in range(FLAGS.evaluate_every):
+      learner.step()
+    eval_loop.run(FLAGS.evaluation_episodes)
+
+
+if __name__ == '__main__':
+  app.run(main)