From 993e059ec4516c51a2f2c818f03a453a0e25f0a1 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 17 Apr 2023 17:56:58 +0200 Subject: [PATCH 01/25] First draft with somehwat working airl configured using Hydra. --- setup.py | 1 + src/imitation_cli/__init__.py | 0 src/imitation_cli/airl.py | 114 ++++++++++ src/imitation_cli/config/airl_sweep.yaml | 15 ++ .../config/environment/cartpole.yaml | 5 + .../config/environment/myenv.yaml | 5 + .../config/environment/retro.yaml | 5 + .../config/environment/seals.yaml | 5 + src/imitation_cli/env_test.py | 73 +++++++ src/imitation_cli/utils/__init__.py | 0 .../utils/activation_function.py | 56 +++++ src/imitation_cli/utils/environment.py | 51 +++++ src/imitation_cli/utils/feature_extractor.py | 42 ++++ src/imitation_cli/utils/optimizer.py | 42 ++++ src/imitation_cli/utils/policy.py | 199 ++++++++++++++++++ src/imitation_cli/utils/reward_network.py | 100 +++++++++ src/imitation_cli/utils/rl_algorithm.py | 95 +++++++++ src/imitation_cli/utils/schedule.py | 38 ++++ src/imitation_cli/utils/trajectories.py | 60 ++++++ 19 files changed, 906 insertions(+) create mode 100644 src/imitation_cli/__init__.py create mode 100644 src/imitation_cli/airl.py create mode 100644 src/imitation_cli/config/airl_sweep.yaml create mode 100644 src/imitation_cli/config/environment/cartpole.yaml create mode 100644 src/imitation_cli/config/environment/myenv.yaml create mode 100644 src/imitation_cli/config/environment/retro.yaml create mode 100644 src/imitation_cli/config/environment/seals.yaml create mode 100644 src/imitation_cli/env_test.py create mode 100644 src/imitation_cli/utils/__init__.py create mode 100644 src/imitation_cli/utils/activation_function.py create mode 100644 src/imitation_cli/utils/environment.py create mode 100644 src/imitation_cli/utils/feature_extractor.py create mode 100644 src/imitation_cli/utils/optimizer.py create mode 100644 src/imitation_cli/utils/policy.py create mode 100644 src/imitation_cli/utils/reward_network.py create mode 100644 src/imitation_cli/utils/rl_algorithm.py create mode 100644 src/imitation_cli/utils/schedule.py create mode 100644 src/imitation_cli/utils/trajectories.py diff --git a/setup.py b/setup.py index 73ffd00ac..32a34614d 100644 --- a/setup.py +++ b/setup.py @@ -209,6 +209,7 @@ def get_local_version(version: "ScmVersion", time_format="%Y%m%d") -> str: "tensorboard>=1.14", "huggingface_sb3>=2.2.1", "datasets>=2.8.0", + "hydra-core>=1.3.2", ], tests_require=TESTS_REQUIRE, extras_require={ diff --git a/src/imitation_cli/__init__.py b/src/imitation_cli/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py new file mode 100644 index 000000000..b0c285bb0 --- /dev/null +++ b/src/imitation_cli/airl.py @@ -0,0 +1,114 @@ +import dataclasses +import logging +from typing import Optional + +import hydra +import numpy as np +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + +from imitation.algorithms.adversarial.airl import AIRL +from imitation.data import rollout +from imitation_cli.utils import environment as gym_env +from imitation_cli.utils import ( + optimizer, + policy, + reward_network, + rl_algorithm, + trajectories, +) + + +@dataclasses.dataclass +class AIRLConfig: + defaults: list = dataclasses.field( + default_factory=lambda: [ + {"environment": "gym_env"}, + {"reward_net": "shaped"}, + # {"gen_algo": "ppo"}, + "_self_", + ] + ) + environment: gym_env.Config = MISSING + expert_trajs: trajectories.Config = MISSING + 
total_timesteps: int = int(1e6) + checkpoint_interval: int = 0 + gen_algo: rl_algorithm.Config = rl_algorithm.PPO() + reward_net: reward_network.Config = MISSING + seed: int = 0 + demo_batch_size: int = 64 + n_disc_updates_per_round: int = 2 + disc_opt_cls: optimizer.Config = optimizer.Adam + gen_train_timesteps: Optional[int] = None + gen_replay_buffer_capacity: Optional[int] = None + init_tensorboard: bool = False + init_tensorboard_graph: bool = False + debug_use_ground_truth: bool = False + allow_variable_horizon: bool = True # TODO: true just for debugging + + +cs = ConfigStore.instance() +cs.store(name="airl", node=AIRLConfig) +policy.register_configs("expert_trajs/expert_policy") +rl_algorithm.register_configs("gen_algo") +trajectories.register_configs("expert_trajs") +gym_env.register_configs("environment") +reward_network.register_configs("reward_net") + + +@hydra.main( + version_base=None, + config_path="config", + config_name="airl", +) +def run_airl(cfg: AIRLConfig) -> None: + + rng = np.random.default_rng(cfg.seed) + expert_trajs = trajectories.get_trajectories(cfg.expert_trajs, rng) + print(len(expert_trajs)) + + venv = gym_env.make_venv(cfg.environment, rng) + + reward_net = reward_network.make_reward_net(cfg.reward_net) + + gen_algo = rl_algorithm.make_rl_algo(cfg.gen_algo, rng) + + trainer = AIRL( + venv=venv, + demonstrations=expert_trajs, + gen_algo=gen_algo, + reward_net=reward_net, + demo_batch_size=cfg.demo_batch_size, + n_disc_updates_per_round=cfg.n_disc_updates_per_round, + disc_opt_cls=optimizer.make_optimizer(cfg.disc_opt_cls), + gen_train_timesteps=cfg.gen_train_timesteps, + gen_replay_buffer_capacity=cfg.gen_replay_buffer_capacity, + init_tensorboard=cfg.init_tensorboard, + init_tensorboard_graph=cfg.init_tensorboard_graph, + debug_use_ground_truth=cfg.debug_use_ground_truth, + allow_variable_horizon=cfg.allow_variable_horizon, + ) + + def callback(round_num: int, /) -> None: + if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: + logging.log( + logging.INFO, + f"Saving checkpoint at round {round_num}. TODO implement this", + ) + + trainer.train(cfg.total_timesteps, callback) + # TODO: implement evaluation + # imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train) + + # Save final artifacts. + if cfg.checkpoint_interval >= 0: + logging.log(logging.INFO, f"Saving final checkpoint. 
TODO implement this") + + return { + # "imit_stats": imit_stats, + "expert_stats": rollout.rollout_stats(expert_trajs), + } + + +if __name__ == "__main__": + run_airl() diff --git a/src/imitation_cli/config/airl_sweep.yaml b/src/imitation_cli/config/airl_sweep.yaml new file mode 100644 index 000000000..38c069a0e --- /dev/null +++ b/src/imitation_cli/config/airl_sweep.yaml @@ -0,0 +1,15 @@ +defaults: + - airl + - expert_trajs: generated + - gen_algo: on_disk + - _self_ + +gen_algo: + environment: + max_episode_steps: 32 + +hydra: + mode: MULTIRUN + sweeper: + params: + environment: glob(*,exclude=gym_env) diff --git a/src/imitation_cli/config/environment/cartpole.yaml b/src/imitation_cli/config/environment/cartpole.yaml new file mode 100644 index 000000000..3a2db8f84 --- /dev/null +++ b/src/imitation_cli/config/environment/cartpole.yaml @@ -0,0 +1,5 @@ +defaults: + - gym_env + +env_name: CartPole-v0 +max_episode_steps: 500 diff --git a/src/imitation_cli/config/environment/myenv.yaml b/src/imitation_cli/config/environment/myenv.yaml new file mode 100644 index 000000000..3d92be40c --- /dev/null +++ b/src/imitation_cli/config/environment/myenv.yaml @@ -0,0 +1,5 @@ +defaults: + - gym_env + +env_name: my_own_env +max_episode_steps: 42 diff --git a/src/imitation_cli/config/environment/retro.yaml b/src/imitation_cli/config/environment/retro.yaml new file mode 100644 index 000000000..000c3ba46 --- /dev/null +++ b/src/imitation_cli/config/environment/retro.yaml @@ -0,0 +1,5 @@ +defaults: + - gym_env + +env_name: Retro-v0 +max_episode_steps: 4500 diff --git a/src/imitation_cli/config/environment/seals.yaml b/src/imitation_cli/config/environment/seals.yaml new file mode 100644 index 000000000..e30789de6 --- /dev/null +++ b/src/imitation_cli/config/environment/seals.yaml @@ -0,0 +1,5 @@ +defaults: + - gym_env + +env_name: Seals-v0 +max_episode_steps: 1000 diff --git a/src/imitation_cli/env_test.py b/src/imitation_cli/env_test.py new file mode 100644 index 000000000..eede51828 --- /dev/null +++ b/src/imitation_cli/env_test.py @@ -0,0 +1,73 @@ +import dataclasses + +import hydra +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclasses.dataclass +class EnvironmentConfig: + gym_id: str = MISSING # The environment to train on + n_envs: int = 8 # number of environments in VecEnv + parallel: bool = True # Use SubprocVecEnv rather than DummyVecEnv + max_episode_steps: int = MISSING # Set to positive int to limit episode horizons + env_make_kwargs: dict = dataclasses.field( + default_factory=dict + ) # The kwargs passed to `spec.make`. 
+ + +@dataclasses.dataclass +class RetroEnvironmentConfig(EnvironmentConfig): + gym_id: str = "Retro-v0" + max_episode_steps: int = 4500 + + +@dataclasses.dataclass +class SealEnvironmentConfig(EnvironmentConfig): + gym_id: str = "Seal-v0" + max_episode_steps: int = 1000 + aaa: int = 555 + + +@dataclasses.dataclass +class PolicyConfig: + env: EnvironmentConfig + type: str = MISSING + + +@dataclasses.dataclass +class PPOPolicyConfig(PolicyConfig): + type: str = "ppo" + + +@dataclasses.dataclass +class RandomPolicyConfig(PolicyConfig): + type: str = "random" + + +@dataclasses.dataclass +class Config: + env: EnvironmentConfig + policy: PolicyConfig + + +cs = ConfigStore.instance() +cs.store(name="config", node=Config) + +cs.store(group="env", name="retro", node=RetroEnvironmentConfig) +cs.store(group="env", name="seal", node=SealEnvironmentConfig) + +cs.store(group="policy", name="ppo", node=PPOPolicyConfig(env="${env}")) +cs.store(group="policy", name="random", node=RandomPolicyConfig(env="${env}")) + +# cs.store(group="policy/env", name="retro", node=RetroEnvironmentConfig) +# cs.store(group="policy/env", name="seal", node=SealEnvironmentConfig) + + +@hydra.main(version_base=None, config_name="config") +def main(cfg: Config): + print(cfg) + + +if __name__ == "__main__": + main() diff --git a/src/imitation_cli/utils/__init__.py b/src/imitation_cli/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/imitation_cli/utils/activation_function.py b/src/imitation_cli/utils/activation_function.py new file mode 100644 index 000000000..8c164a805 --- /dev/null +++ b/src/imitation_cli/utils/activation_function.py @@ -0,0 +1,56 @@ +import dataclasses + +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + # Note: we don't define _target_ here so in the subclasses it can be defined last. + # This is the same pattern we use as in schedule.py. 
+ pass + + +@dataclasses.dataclass +class TanH(Config): + _target_: str = "imitation_cli.utils.activation_function.TanH.make" + + @staticmethod + def make(): + import torch + + return torch.nn.Tanh + + +@dataclasses.dataclass +class ReLU(Config): + _target_: str = "imitation_cli.utils.activation_function.ReLU.make" + + @staticmethod + def make(): + import torch + + return torch.nn.ReLU + + +@dataclasses.dataclass +class LeakyReLU(Config): + _target_: str = "imitation_cli.utils.activation_function.LeakyReLU.make" + + @staticmethod + def make(): + import torch + + return torch.nn.LeakyReLU + + +def make_activation_function(cfg: Config): + return call(cfg) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="tanh", node=TanH) + cs.store(group=group, name="relu", node=ReLU) + cs.store(group=group, name="leaky_relu", node=LeakyReLU) diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py new file mode 100644 index 000000000..6de80932a --- /dev/null +++ b/src/imitation_cli/utils/environment.py @@ -0,0 +1,51 @@ +import dataclasses + +import numpy as np +from hydra.core.config_store import ConfigStore +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + env_name: str = MISSING # The environment to train on + n_envs: int = 8 # number of environments in VecEnv + parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv TODO: when setting this to true this is really slow for some reason + max_episode_steps: int = MISSING # Set to positive int to limit episode horizons + env_make_kwargs: dict = dataclasses.field( + default_factory=dict + ) # The kwargs passed to `spec.make`. + + +def make_venv( + environment_config: Config, + rnd: np.random.Generator, + log_dir=None, + **kwargs, +): + from imitation.util import util + + return util.make_vec_env( + **environment_config, + rng=rnd, + log_dir=log_dir, + **kwargs, + ) + + +def make_rollout_venv( + environment_config: Config, + rnd: np.random.Generator, +): + from imitation.data import wrappers + + return make_venv( + environment_config, + rnd, + log_dir=None, + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], + ) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="gym_env", node=Config) diff --git a/src/imitation_cli/utils/feature_extractor.py b/src/imitation_cli/utils/feature_extractor.py new file mode 100644 index 000000000..f62ca2336 --- /dev/null +++ b/src/imitation_cli/utils/feature_extractor.py @@ -0,0 +1,42 @@ +import dataclasses + +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + + +@dataclasses.dataclass +class FlattenExtractorConfig(Config): + _target_: str = "imitation_cli.utils.feature_extractor.FlattenExtractorConfig.make" + + @staticmethod + def make(): + import stable_baselines3 + + return stable_baselines3.common.torch_layers.FlattenExtractor + + +@dataclasses.dataclass +class NatureCNNConfig(Config): + _target_: str = "imitation_cli.utils.feature_extractor.NatureCNNConfig.make" + + @staticmethod + def make(): + import stable_baselines3 + + return stable_baselines3.common.torch_layers.NatureCNN + + +def make_feature_extractor(cfg: Config): + return call(cfg) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="flatten", node=FlattenExtractorConfig) + cs.store(group=group, name="nature_cnn", 
node=NatureCNNConfig) diff --git a/src/imitation_cli/utils/optimizer.py b/src/imitation_cli/utils/optimizer.py new file mode 100644 index 000000000..21bef99a6 --- /dev/null +++ b/src/imitation_cli/utils/optimizer.py @@ -0,0 +1,42 @@ +import dataclasses + +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + + +@dataclasses.dataclass +class Adam(Config): + _target_: str = "imitation_cli.utils.optimizer.Adam.make" + + @staticmethod + def make(): + import torch + + return torch.optim.Adam + + +@dataclasses.dataclass +class SGD(Config): + _target_: str = "imitation_cli.utils.optimizer.SGD.make" + + @staticmethod + def make(): + import torch + + return torch.optim.SGD + + +def make_optimizer(cfg: Config): + return call(cfg) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="adam", node=Adam) + cs.store(group=group, name="sgd", node=SGD) diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py new file mode 100644 index 000000000..8dbdd0433 --- /dev/null +++ b/src/imitation_cli/utils/policy.py @@ -0,0 +1,199 @@ +import dataclasses +import pathlib +from typing import Any, Dict, List, Optional + +import numpy as np +import stable_baselines3 as sb3 +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING +from stable_baselines3.common.torch_layers import FlattenExtractor + +from imitation_cli.utils import activation_function +from imitation_cli.utils import environment as gym_env +from imitation_cli.utils import feature_extractor, optimizer, schedule + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + environment: gym_env.Config = "${environment}" + + +@dataclasses.dataclass +class Random(Config): + _target_: str = "imitation_cli.utils.policy.Random.make" + + @staticmethod + def make(environment: gym_env.Config, rng: np.random.Generator): + from imitation.policies import base + + env = gym_env.make_venv(environment, rng) + return base.RandomPolicy(env.observation_space, env.action_space) + + +@dataclasses.dataclass +class ZeroPolicy(Config): + _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" + + @staticmethod + def make(environment: gym_env.Config, rng: np.random.Generator): + from imitation.policies import base + + env = gym_env.make_venv(environment, rng) + return base.ZeroPolicy(env.observation_space, env.action_space) + + +@dataclasses.dataclass +class ActorCriticPolicy(Config): + _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" + _recursive_: bool = False + lr_schedule: schedule.Config = MISSING + net_arch: Optional[Dict[str, List[int]]] = None + activation_fn: activation_function.Config = activation_function.TanH() + ortho_init: bool = True + use_sde: bool = False + log_std_init: float = 0.0 + full_std: bool = True + use_expln: bool = False + squash_output: bool = False + features_extractor_class: feature_extractor.Config = ( + feature_extractor.FlattenExtractorConfig() + ) + features_extractor_kwargs: Optional[Dict[str, Any]] = None + share_features_extractor: bool = True + normalize_images: bool = True + optimizer_class: optimizer.Config = optimizer.Adam() + optimizer_kwargs: Optional[Dict[str, Any]] = None + + @staticmethod + def make_args( + lr_schedule: schedule.Config, + activation_fn: activation_function.Config, + features_extractor_class: feature_extractor.Config, + optimizer_class: 
optimizer.Config, + **kwargs, + ): + activation_fn = activation_function.make_activation_function(activation_fn) + lr_schedule = schedule.make_schedule(lr_schedule) + features_extractor_class = feature_extractor.make_feature_extractor( + features_extractor_class + ) + optimizer_class = optimizer.make_optimizer(optimizer_class) + + del kwargs["environment"] + del kwargs["_target_"] + del kwargs["_recursive_"] + + return dict( + lr_schedule=lr_schedule, + activation_fn=activation_fn, + features_extractor_class=features_extractor_class, + optimizer_class=optimizer_class, + **kwargs, + ) + + @staticmethod + def make( + environment: gym_env.Config, + rng: np.random.Generator, + lr_schedule: schedule.Config, + activation_fn: activation_function.Config, + features_extractor_class: feature_extractor.Config, + optimizer_class: optimizer.Config, + **kwargs, + ): + env = gym_env.make_venv(environment, rng) + activation_fn = activation_function.make_activation_function(activation_fn) + lr_schedule = schedule.make_schedule(lr_schedule) + features_extractor_class = feature_extractor.make_feature_extractor( + features_extractor_class + ) + optimizer_class = optimizer.make_optimizer(optimizer_class) + + return sb3.common.policies.ActorCriticPolicy( + lr_schedule=lr_schedule, + activation_fn=activation_fn, + features_extractor_class=features_extractor_class, + optimizer_class=optimizer_class, + observation_space=env.observation_space, + action_space=env.action_space, + **kwargs, + ) + + +@dataclasses.dataclass +class Loaded(Config): + type: str = "PPO" # The SB3 policy class. Only SAC and PPO supported as of now + + @staticmethod + def type_to_class(type: str): + import stable_baselines3 as sb3 + + type = type.lower() + if type == "ppo": + return sb3.PPO + if type == "ppo": + return sb3.SAC + raise ValueError(f"Unknown policy type {type}") + + +@dataclasses.dataclass +class PolicyOnDisk(Loaded): + _target_: str = "imitation_cli.utils.policy.PolicyOnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make( + environment: gym_env.Config, + path: pathlib.Path, + type: str, + rng: np.random.Generator, + ): + from imitation.policies import serialize + + env = gym_env.make_venv(environment, rng) + return serialize.load_stable_baselines_model( + Loaded.type_to_class(type), path, env + ).policy + + +@dataclasses.dataclass +class PolicyFromHuggingface(Loaded): + _target_: str = "imitation_cli.utils.policy.PolicyFromHuggingface.make" + organization: str = "HumanCompatibleAI" + + @staticmethod + def make( + type: str, + environment: gym_env.Config, + organization: str, + rng: np.random.Generator, + ): + import huggingface_sb3 as hfsb3 + + from imitation.policies import serialize + + model_name = hfsb3.ModelName( + type.lower(), hfsb3.EnvironmentName(environment.gym_id) + ) + repo_id = hfsb3.ModelRepoId(organization, model_name) + filename = hfsb3.load_from_hub(repo_id, model_name.filename) + env = gym_env.make_venv(environment, rng) + model = serialize.load_stable_baselines_model( + Loaded.type_to_class(type), filename, env + ) + return model.policy + + +def make_policy(cfg: Config, rng: np.random.Generator): + return call(cfg, rng=rng) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="random", node=Random) + cs.store(group=group, name="zero", node=ZeroPolicy) + cs.store(group=group, name="on_disk", node=PolicyOnDisk) + cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface) + cs.store(group=group, name="actor_critic", 
node=ActorCriticPolicy) diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py new file mode 100644 index 000000000..b58fe47b0 --- /dev/null +++ b/src/imitation_cli/utils/reward_network.py @@ -0,0 +1,100 @@ +import dataclasses +from typing import Optional + +import numpy as np +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + +import imitation_cli.utils.environment as gym_env +from imitation.rewards import reward_nets +from imitation.util import networks + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + environment: gym_env.Config = "${environment}" + + +@dataclasses.dataclass +class BasicRewardNet(Config): + _target_: str = "imitation_cli.utils.reward_network.BasicRewardNet.make" + use_state: bool = True + use_action: bool = True + use_next_state: bool = False + use_done: bool = False + normalize_input_layer: bool = True + + @staticmethod + def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): + env = gym_env.make_venv(environment, rnd=np.random.default_rng()) + reward_net = reward_nets.BasicRewardNet( + env.observation_space, + env.action_space, + **kwargs, + ) + if normalize_input_layer: + reward_net = reward_nets.NormalizedRewardNet( + reward_net, + networks.RunningNorm, + ) + return reward_net + + +@dataclasses.dataclass +class BasicShapedRewardNet(BasicRewardNet): + _target_: str = "imitation_cli.utils.reward_network.BasicShapedRewardNet.make" + discount_factor: float = 0.99 + + @staticmethod + def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): + env = gym_env.make_venv(environment, rnd=np.random.default_rng()) + reward_net = reward_nets.BasicShapedRewardNet( + env.observation_space, + env.action_space, + **kwargs, + ) + if normalize_input_layer: + reward_net = reward_nets.NormalizedRewardNet( + reward_net, + networks.RunningNorm, + ) + return reward_net + + +@dataclasses.dataclass +class RewardEnsemble(Config): + _target_: str = "imitation_cli.utils.reward_network.RewardEnsemble.make" + ensemble_size: int = MISSING + ensemble_member_config: BasicRewardNet = MISSING + add_std_alpha: Optional[float] = None + + @staticmethod + def make( + environment: gym_env.Config, + ensemble_member_config: BasicRewardNet, + add_std_alpha: Optional[float], + ): + env = gym_env.make_venv(environment, rnd=np.random.default_rng()) + members = [call(ensemble_member_config)] + reward_net = reward_nets.RewardEnsemble( + env.observation_space, env.action_space, members + ) + if add_std_alpha is not None: + reward_net = reward_nets.AddSTDRewardWrapper( + reward_net, + default_alpha=add_std_alpha, + ) + return reward_net + + +def make_reward_net(config: Config): + return call(config) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="basic", node=BasicRewardNet) + cs.store(group=group, name="shaped", node=BasicShapedRewardNet) + cs.store(group=group, name="ensemble", node=RewardEnsemble) diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py new file mode 100644 index 000000000..a8bbeb43b --- /dev/null +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -0,0 +1,95 @@ +import dataclasses +import pathlib +from typing import Optional + +import numpy as np +import stable_baselines3 as sb3 +from hydra.utils import call +from omegaconf import MISSING + +from imitation_cli.utils import environment as gym_env +from imitation_cli.utils import policy as 
policy_conf +from imitation_cli.utils import schedule + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + environment: str = "${environment}" + + +@dataclasses.dataclass +class PPO(Config): + _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" + _recursive_: bool = False + + policy: policy_conf.ActorCriticPolicy = policy_conf.ActorCriticPolicy() + environment: gym_env.Config = "${environment}" + learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) + n_steps: int = 2048 + batch_size: int = 64 + n_epochs: int = 10 + gamma: float = 0.99 + gae_lambda: float = 0.95 + clip_range: schedule.Config = schedule.FixedSchedule(0.2) + clip_range_vf: Optional[schedule.Config] = None + normalize_advantage: bool = True + ent_coef: float = 0.0 + vf_coef: float = 0.5 + max_grad_norm: float = 0.5 + use_sde: bool = False + sde_sample_freq: int = -1 + target_kl: Optional[float] = None + tensorboard_log: Optional[str] = None + verbose: int = 0 + seed: int = "${seed}" + device: str = "auto" + + @staticmethod + def make( + environment: gym_env.Config, + policy: policy_conf.ActorCriticPolicy, + rng: np.random.Generator, + learning_rate: schedule.Config, + clip_range: schedule.Config, + **kwargs, + ): + if "lr_schedule" not in policy: + policy["lr_schedule"] = learning_rate + + policy_kwargs = policy_conf.ActorCriticPolicy.make_args(**policy) + del policy_kwargs["use_sde"] + del policy_kwargs["lr_schedule"] + return sb3.PPO( + policy=sb3.common.policies.ActorCriticPolicy, + policy_kwargs=policy_kwargs, + env=gym_env.make_venv(environment, rng), + learning_rate=schedule.make_schedule(learning_rate), + clip_range=schedule.make_schedule(clip_range), + **kwargs, + ) + + +@dataclasses.dataclass +class PPOOnDisk(PPO): + _target_: str = "imitation_cli.utils.rl_algorithm.PPOOnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make(path: pathlib.Path, environment: gym_env.Config, rng: np.random.Generator): + from imitation.policies import serialize + + env = gym_env.make_venv(environment, rng) + return serialize.load_stable_baselines_model(sb3.PPO, path, env) + + +def make_rl_algo(algo_conf: Config, rng: np.random.Generator): + return call(algo_conf, rng=rng) + + +def register_configs(group: str = "rl_algorithm"): + from hydra.core.config_store import ConfigStore + + cs = ConfigStore.instance() + cs.store(name="ppo", group=group, node=PPO) + cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk) diff --git a/src/imitation_cli/utils/schedule.py b/src/imitation_cli/utils/schedule.py new file mode 100644 index 000000000..0d9a0f51e --- /dev/null +++ b/src/imitation_cli/utils/schedule.py @@ -0,0 +1,38 @@ +import dataclasses + +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + + +@dataclasses.dataclass +class Config: + # Note: we don't define _target_ here so in the subclasses it can be defined last. + # This way we can instantiate a fixed schedule with `FixedSchedule(0.1)`. + # If we defined _target_ here, then we would have to instantiate a fixed schedule + # with `FixedSchedule(val=0.1)`. Otherwise we would set _target_ to 0.1. 
+ pass + + +@dataclasses.dataclass +class FixedSchedule(Config): + val: float = MISSING + _target_: str = "stable_baselines3.common.utils.constant_fn" + + +@dataclasses.dataclass +class LinearSchedule(Config): + start: float = MISSING + end: float = MISSING + end_fraction: float = MISSING + _target_: str = "stable_baselines3.common.utils.get_linear_fn" + + +def make_schedule(cfg: Config): + return call(cfg) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="fixed", node=FixedSchedule) + cs.store(group=group, name="linear", node=LinearSchedule) diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py new file mode 100644 index 000000000..ebbff3caf --- /dev/null +++ b/src/imitation_cli/utils/trajectories.py @@ -0,0 +1,60 @@ +import dataclasses +import pathlib + +import numpy as np +from hydra.core.config_store import ConfigStore +from hydra.utils import call +from omegaconf import MISSING + +from imitation_cli.utils import environment, policy + + +@dataclasses.dataclass +class Config: + _target_: str = MISSING + + +@dataclasses.dataclass +class OnDisk(Config): + _target_: str = "imitation_cli.utils.trajectories.OnDisk.make" + path: pathlib.Path = MISSING + + @staticmethod + def make(path: pathlib.Path, rng: np.random.Generator): + from imitation.data import serialize + + serialize.load(path) + + +@dataclasses.dataclass +class Generated(Config): + _target_: str = "imitation_cli.utils.trajectories.Generated.make" + _recursive_: bool = False # This way the expert_policy is not aut-filled. + total_timesteps: int = int(10) # TODO: this is low for debugging + expert_policy: policy.Config = policy.Config(environment="${environment}") + + @staticmethod + def make( + total_timesteps: int, expert_policy: policy.Config, rng: np.random.Generator + ): + from imitation.data import rollout + + expert = policy.make_policy(expert_policy, rng) + venv = environment.make_venv(expert_policy.environment, rng) + return rollout.generate_trajectories( + expert, + venv, + rollout.make_sample_until(min_timesteps=total_timesteps), + rng, + deterministic_policy=True, + ) + + +def get_trajectories(config: Config, rng: np.random.Generator): + return call(config, rng=rng) + + +def register_configs(group: str): + cs = ConfigStore.instance() + cs.store(group=group, name="on_disk", node=OnDisk) + cs.store(group=group, name="generated", node=Generated) From 1328c2a4cdb63288c06adacc6b7191e72724d39a Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 17 Apr 2023 19:08:45 +0200 Subject: [PATCH 02/25] Use recursive calls and introduce random number generator config. 
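In practice this means most of the make_*()/get_*() helper functions go away: every config dataclass carries a _target_ that points at a factory, and hydra.utils.call() instantiates nested configs recursively. The new randomness.Config node interpolates the top-level ${seed}, so a component can declare an rng field and receive a seeded numpy Generator without the seed being threaded through by hand. A minimal, self-contained sketch of that pattern (illustrative only, not code from this diff; RngConfig and RootConfig are stand-ins for randomness.Config and the run config):

    import dataclasses

    import numpy as np
    from hydra.utils import call
    from omegaconf import OmegaConf


    @dataclasses.dataclass
    class RngConfig:
        # _target_ names a factory; call() resolves ${seed} against the root config.
        _target_: str = "numpy.random.default_rng"
        seed: int = "${seed}"


    @dataclasses.dataclass
    class RootConfig:
        seed: int = 0
        rng: RngConfig = dataclasses.field(default_factory=RngConfig)


    cfg = OmegaConf.structured(RootConfig(seed=42))
    rng = call(cfg.rng)  # resolves the interpolation and builds numpy.random.default_rng(42)
    assert isinstance(rng, np.random.Generator)
    print(rng.integers(0, 10))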
--- src/imitation_cli/airl.py | 20 ++---- src/imitation_cli/utils/environment.py | 34 ++++------ src/imitation_cli/utils/feature_extractor.py | 4 -- src/imitation_cli/utils/optimizer.py | 4 -- src/imitation_cli/utils/policy.py | 66 +++++--------------- src/imitation_cli/utils/randomness.py | 17 +++++ src/imitation_cli/utils/reward_network.py | 13 ++-- src/imitation_cli/utils/rl_algorithm.py | 20 ++---- src/imitation_cli/utils/schedule.py | 4 -- src/imitation_cli/utils/trajectories.py | 23 ++++--- 10 files changed, 70 insertions(+), 135 deletions(-) create mode 100644 src/imitation_cli/utils/randomness.py diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index b0c285bb0..30c56c47c 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -3,20 +3,13 @@ from typing import Optional import hydra -import numpy as np from hydra.core.config_store import ConfigStore +from hydra.utils import call from omegaconf import MISSING from imitation.algorithms.adversarial.airl import AIRL from imitation.data import rollout -from imitation_cli.utils import environment as gym_env -from imitation_cli.utils import ( - optimizer, - policy, - reward_network, - rl_algorithm, - trajectories, -) +from imitation_cli.utils import environment as gym_env, optimizer, policy, reward_network, rl_algorithm, trajectories @dataclasses.dataclass @@ -63,15 +56,14 @@ class AIRLConfig: ) def run_airl(cfg: AIRLConfig) -> None: - rng = np.random.default_rng(cfg.seed) - expert_trajs = trajectories.get_trajectories(cfg.expert_trajs, rng) + expert_trajs = call(cfg.expert_trajs) print(len(expert_trajs)) - venv = gym_env.make_venv(cfg.environment, rng) + venv = call(cfg.environment) reward_net = reward_network.make_reward_net(cfg.reward_net) - gen_algo = rl_algorithm.make_rl_algo(cfg.gen_algo, rng) + gen_algo = call(cfg.gen_algo) trainer = AIRL( venv=venv, @@ -80,7 +72,7 @@ def run_airl(cfg: AIRLConfig) -> None: reward_net=reward_net, demo_batch_size=cfg.demo_batch_size, n_disc_updates_per_round=cfg.n_disc_updates_per_round, - disc_opt_cls=optimizer.make_optimizer(cfg.disc_opt_cls), + disc_opt_cls=call(cfg.disc_opt_cls), gen_train_timesteps=cfg.gen_train_timesteps, gen_replay_buffer_capacity=cfg.gen_replay_buffer_capacity, init_tensorboard=cfg.init_tensorboard, diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py index 6de80932a..b4ac26401 100644 --- a/src/imitation_cli/utils/environment.py +++ b/src/imitation_cli/utils/environment.py @@ -1,12 +1,15 @@ import dataclasses -import numpy as np from hydra.core.config_store import ConfigStore +from hydra.utils import call from omegaconf import MISSING +from imitation_cli.utils import randomness + @dataclasses.dataclass class Config: + _target_: str = "imitation_cli.utils.environment.Config.make" env_name: str = MISSING # The environment to train on n_envs: int = 8 # number of environments in VecEnv parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv TODO: when setting this to true this is really slow for some reason @@ -14,35 +17,22 @@ class Config: env_make_kwargs: dict = dataclasses.field( default_factory=dict ) # The kwargs passed to `spec.make`. 
+ rng: randomness.Config = randomness.Config + @staticmethod + def make(log_dir=None, **kwargs): + from imitation.util import util -def make_venv( - environment_config: Config, - rnd: np.random.Generator, - log_dir=None, - **kwargs, -): - from imitation.util import util - - return util.make_vec_env( - **environment_config, - rng=rnd, - log_dir=log_dir, - **kwargs, - ) + return util.make_vec_env(log_dir=log_dir, **kwargs) -def make_rollout_venv( - environment_config: Config, - rnd: np.random.Generator, -): +def make_rollout_venv(environment_config: Config): from imitation.data import wrappers - return make_venv( + return call( environment_config, - rnd, log_dir=None, - post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)] ) diff --git a/src/imitation_cli/utils/feature_extractor.py b/src/imitation_cli/utils/feature_extractor.py index f62ca2336..ddbf95c55 100644 --- a/src/imitation_cli/utils/feature_extractor.py +++ b/src/imitation_cli/utils/feature_extractor.py @@ -32,10 +32,6 @@ def make(): return stable_baselines3.common.torch_layers.NatureCNN -def make_feature_extractor(cfg: Config): - return call(cfg) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="flatten", node=FlattenExtractorConfig) diff --git a/src/imitation_cli/utils/optimizer.py b/src/imitation_cli/utils/optimizer.py index 21bef99a6..f368b8bda 100644 --- a/src/imitation_cli/utils/optimizer.py +++ b/src/imitation_cli/utils/optimizer.py @@ -32,10 +32,6 @@ def make(): return torch.optim.SGD -def make_optimizer(cfg: Config): - return call(cfg) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="adam", node=Adam) diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 8dbdd0433..2ce9ab706 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -25,11 +25,9 @@ class Random(Config): _target_: str = "imitation_cli.utils.policy.Random.make" @staticmethod - def make(environment: gym_env.Config, rng: np.random.Generator): + def make(environment: gym_env.Config): from imitation.policies import base - - env = gym_env.make_venv(environment, rng) - return base.RandomPolicy(env.observation_space, env.action_space) + return base.RandomPolicy(environment.observation_space, environment.action_space) @dataclasses.dataclass @@ -37,18 +35,16 @@ class ZeroPolicy(Config): _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" @staticmethod - def make(environment: gym_env.Config, rng: np.random.Generator): + def make(environment: gym_env.Config): from imitation.policies import base - env = gym_env.make_venv(environment, rng) - return base.ZeroPolicy(env.observation_space, env.action_space) + return base.ZeroPolicy(environment.observation_space, environment.action_space) @dataclasses.dataclass class ActorCriticPolicy(Config): _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" - _recursive_: bool = False - lr_schedule: schedule.Config = MISSING + lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) # TODO: make sure this is copied from the rl_algorithm instead net_arch: Optional[Dict[str, List[int]]] = None activation_fn: activation_function.Config = activation_function.TanH() ortho_init: bool = True @@ -68,56 +64,30 @@ class ActorCriticPolicy(Config): @staticmethod def make_args( - lr_schedule: schedule.Config, activation_fn: activation_function.Config, features_extractor_class: 
feature_extractor.Config, optimizer_class: optimizer.Config, **kwargs, ): - activation_fn = activation_function.make_activation_function(activation_fn) - lr_schedule = schedule.make_schedule(lr_schedule) - features_extractor_class = feature_extractor.make_feature_extractor( - features_extractor_class - ) - optimizer_class = optimizer.make_optimizer(optimizer_class) - - del kwargs["environment"] del kwargs["_target_"] - del kwargs["_recursive_"] + del kwargs["environment"] + + kwargs["activation_fn"] = call(activation_fn) + kwargs["features_extractor_class"] = call(features_extractor_class) + kwargs["optimizer_class"] = call(optimizer_class) return dict( - lr_schedule=lr_schedule, - activation_fn=activation_fn, - features_extractor_class=features_extractor_class, - optimizer_class=optimizer_class, **kwargs, ) @staticmethod def make( environment: gym_env.Config, - rng: np.random.Generator, - lr_schedule: schedule.Config, - activation_fn: activation_function.Config, - features_extractor_class: feature_extractor.Config, - optimizer_class: optimizer.Config, **kwargs, ): - env = gym_env.make_venv(environment, rng) - activation_fn = activation_function.make_activation_function(activation_fn) - lr_schedule = schedule.make_schedule(lr_schedule) - features_extractor_class = feature_extractor.make_feature_extractor( - features_extractor_class - ) - optimizer_class = optimizer.make_optimizer(optimizer_class) - return sb3.common.policies.ActorCriticPolicy( - lr_schedule=lr_schedule, - activation_fn=activation_fn, - features_extractor_class=features_extractor_class, - optimizer_class=optimizer_class, - observation_space=env.observation_space, - action_space=env.action_space, + observation_space=environment.observation_space, + action_space=environment.action_space, **kwargs, ) @@ -148,13 +118,11 @@ def make( environment: gym_env.Config, path: pathlib.Path, type: str, - rng: np.random.Generator, ): from imitation.policies import serialize - env = gym_env.make_venv(environment, rng) return serialize.load_stable_baselines_model( - Loaded.type_to_class(type), path, env + Loaded.type_to_class(type), path, environment ).policy @@ -168,7 +136,6 @@ def make( type: str, environment: gym_env.Config, organization: str, - rng: np.random.Generator, ): import huggingface_sb3 as hfsb3 @@ -179,17 +146,12 @@ def make( ) repo_id = hfsb3.ModelRepoId(organization, model_name) filename = hfsb3.load_from_hub(repo_id, model_name.filename) - env = gym_env.make_venv(environment, rng) model = serialize.load_stable_baselines_model( - Loaded.type_to_class(type), filename, env + Loaded.type_to_class(type), filename, environment ) return model.policy -def make_policy(cfg: Config, rng: np.random.Generator): - return call(cfg, rng=rng) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="random", node=Random) diff --git a/src/imitation_cli/utils/randomness.py b/src/imitation_cli/utils/randomness.py new file mode 100644 index 000000000..0461621ef --- /dev/null +++ b/src/imitation_cli/utils/randomness.py @@ -0,0 +1,17 @@ +import dataclasses + + +@dataclasses.dataclass +class Config: + _target_: str = "imitation_cli.utils.randomness.Config.make" + seed: int = "${seed}" + + @staticmethod + def make(seed: int): + import numpy as np + import torch + + np.random.seed(seed) + torch.manual_seed(seed) + + return np.random.default_rng(seed) \ No newline at end of file diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index b58fe47b0..a1f7bcaf2 100644 --- 
a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -28,10 +28,9 @@ class BasicRewardNet(Config): @staticmethod def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): - env = gym_env.make_venv(environment, rnd=np.random.default_rng()) reward_net = reward_nets.BasicRewardNet( - env.observation_space, - env.action_space, + environment.observation_space, + environment.action_space, **kwargs, ) if normalize_input_layer: @@ -49,10 +48,9 @@ class BasicShapedRewardNet(BasicRewardNet): @staticmethod def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): - env = gym_env.make_venv(environment, rnd=np.random.default_rng()) reward_net = reward_nets.BasicShapedRewardNet( - env.observation_space, - env.action_space, + environment.observation_space, + environment.action_space, **kwargs, ) if normalize_input_layer: @@ -76,10 +74,9 @@ def make( ensemble_member_config: BasicRewardNet, add_std_alpha: Optional[float], ): - env = gym_env.make_venv(environment, rnd=np.random.default_rng()) members = [call(ensemble_member_config)] reward_net = reward_nets.RewardEnsemble( - env.observation_space, env.action_space, members + environment.observation_space, environment.action_space, members ) if add_std_alpha is not None: reward_net = reward_nets.AddSTDRewardWrapper( diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index a8bbeb43b..7bbb199fc 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -22,8 +22,7 @@ class Config: class PPO(Config): _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" _recursive_: bool = False - - policy: policy_conf.ActorCriticPolicy = policy_conf.ActorCriticPolicy() + policy: policy_conf.ActorCriticPolicy = policy_conf.ActorCriticPolicy environment: gym_env.Config = "${environment}" learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) n_steps: int = 2048 @@ -49,23 +48,19 @@ class PPO(Config): def make( environment: gym_env.Config, policy: policy_conf.ActorCriticPolicy, - rng: np.random.Generator, learning_rate: schedule.Config, clip_range: schedule.Config, **kwargs, ): - if "lr_schedule" not in policy: - policy["lr_schedule"] = learning_rate - policy_kwargs = policy_conf.ActorCriticPolicy.make_args(**policy) del policy_kwargs["use_sde"] del policy_kwargs["lr_schedule"] return sb3.PPO( policy=sb3.common.policies.ActorCriticPolicy, policy_kwargs=policy_kwargs, - env=gym_env.make_venv(environment, rng), - learning_rate=schedule.make_schedule(learning_rate), - clip_range=schedule.make_schedule(clip_range), + env=call(environment), + learning_rate=call(learning_rate), + clip_range=call(clip_range), **kwargs, ) @@ -79,12 +74,7 @@ class PPOOnDisk(PPO): def make(path: pathlib.Path, environment: gym_env.Config, rng: np.random.Generator): from imitation.policies import serialize - env = gym_env.make_venv(environment, rng) - return serialize.load_stable_baselines_model(sb3.PPO, path, env) - - -def make_rl_algo(algo_conf: Config, rng: np.random.Generator): - return call(algo_conf, rng=rng) + return serialize.load_stable_baselines_model(sb3.PPO, path, environment) def register_configs(group: str = "rl_algorithm"): diff --git a/src/imitation_cli/utils/schedule.py b/src/imitation_cli/utils/schedule.py index 0d9a0f51e..b9e3e17e3 100644 --- a/src/imitation_cli/utils/schedule.py +++ b/src/imitation_cli/utils/schedule.py @@ -28,10 +28,6 @@ class LinearSchedule(Config): _target_: str = 
"stable_baselines3.common.utils.get_linear_fn" -def make_schedule(cfg: Config): - return call(cfg) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="fixed", node=FixedSchedule) diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index ebbff3caf..5886db6b2 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -1,12 +1,11 @@ import dataclasses import pathlib -import numpy as np from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import environment, policy +from imitation_cli.utils import policy, randomness @dataclasses.dataclass @@ -20,7 +19,7 @@ class OnDisk(Config): path: pathlib.Path = MISSING @staticmethod - def make(path: pathlib.Path, rng: np.random.Generator): + def make(path: pathlib.Path): from imitation.data import serialize serialize.load(path) @@ -29,31 +28,31 @@ def make(path: pathlib.Path, rng: np.random.Generator): @dataclasses.dataclass class Generated(Config): _target_: str = "imitation_cli.utils.trajectories.Generated.make" - _recursive_: bool = False # This way the expert_policy is not aut-filled. + _recursive_: bool = False total_timesteps: int = int(10) # TODO: this is low for debugging expert_policy: policy.Config = policy.Config(environment="${environment}") + rng: randomness.Config = randomness.Config() @staticmethod def make( - total_timesteps: int, expert_policy: policy.Config, rng: np.random.Generator + total_timesteps: int, + expert_policy: policy.Config, + rng: randomness.Config, ): from imitation.data import rollout - expert = policy.make_policy(expert_policy, rng) - venv = environment.make_venv(expert_policy.environment, rng) + expert = call(expert_policy) + env = call(expert_policy.environment) + rng = call(rng) return rollout.generate_trajectories( expert, - venv, + env, rollout.make_sample_until(min_timesteps=total_timesteps), rng, deterministic_policy=True, ) -def get_trajectories(config: Config, rng: np.random.Generator): - return call(config, rng=rng) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="on_disk", node=OnDisk) From a72a25d034e851e9415ccb6b1bdf8e538c22db80 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 18 Apr 2023 17:09:40 +0200 Subject: [PATCH 03/25] Split up AILR config in AIRL creation and AIRL run config. Also move most of the interpolation magic into the main script and rename some util modules to make clear what they do. 
--- src/imitation_cli/airl.py | 86 ++++++++----------- .../config/environment/cartpole.yaml | 5 -- .../config/environment/myenv.yaml | 5 -- .../config/environment/retro.yaml | 5 -- .../config/environment/seals.yaml | 5 -- ...nction.py => activation_function_class.py} | 12 +-- ...xtractor.py => feature_extractor_class.py} | 5 +- .../{optimizer.py => optimizer_class.py} | 5 +- src/imitation_cli/utils/policy.py | 50 +++++------ src/imitation_cli/utils/reward_network.py | 25 +++--- src/imitation_cli/utils/rl_algorithm.py | 20 ++--- src/imitation_cli/utils/schedule.py | 1 - src/imitation_cli/utils/trajectories.py | 4 +- 13 files changed, 91 insertions(+), 137 deletions(-) delete mode 100644 src/imitation_cli/config/environment/cartpole.yaml delete mode 100644 src/imitation_cli/config/environment/myenv.yaml delete mode 100644 src/imitation_cli/config/environment/retro.yaml delete mode 100644 src/imitation_cli/config/environment/seals.yaml rename src/imitation_cli/utils/{activation_function.py => activation_function_class.py} (72%) rename src/imitation_cli/utils/{feature_extractor.py => feature_extractor_class.py} (79%) rename src/imitation_cli/utils/{optimizer.py => optimizer_class.py} (79%) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 30c56c47c..be5d479ec 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -7,31 +7,21 @@ from hydra.utils import call from omegaconf import MISSING -from imitation.algorithms.adversarial.airl import AIRL from imitation.data import rollout -from imitation_cli.utils import environment as gym_env, optimizer, policy, reward_network, rl_algorithm, trajectories +from imitation_cli.utils import environment as gym_env, optimizer_class, policy, reward_network, rl_algorithm, trajectories +from imitation_cli.utils import policy as policy_conf @dataclasses.dataclass class AIRLConfig: - defaults: list = dataclasses.field( - default_factory=lambda: [ - {"environment": "gym_env"}, - {"reward_net": "shaped"}, - # {"gen_algo": "ppo"}, - "_self_", - ] - ) - environment: gym_env.Config = MISSING - expert_trajs: trajectories.Config = MISSING - total_timesteps: int = int(1e6) - checkpoint_interval: int = 0 - gen_algo: rl_algorithm.Config = rl_algorithm.PPO() + _target_: str = "imitation.algorithms.adversarial.airl.AIRL" + venv: gym_env.Config = "${venv}" + demonstrations: trajectories.Config = "${demonstrations}" + gen_algo: rl_algorithm.Config = MISSING reward_net: reward_network.Config = MISSING - seed: int = 0 demo_batch_size: int = 64 n_disc_updates_per_round: int = 2 - disc_opt_cls: optimizer.Config = optimizer.Adam + disc_opt_cls: optimizer_class.Config = optimizer_class.Adam gen_train_timesteps: Optional[int] = None gen_replay_buffer_capacity: Optional[int] = None init_tensorboard: bool = False @@ -40,46 +30,42 @@ class AIRLConfig: allow_variable_horizon: bool = True # TODO: true just for debugging +@dataclasses.dataclass +class AIRLRunConfig: + defaults: list = dataclasses.field( + default_factory=lambda: [ + {"venv": "gym_env"}, + {"airl/reward_net": "shaped"}, + {"airl/gen_algo": "ppo"}, + "_self_", + ] + ) + seed: int = 0 + venv: gym_env.Config = MISSING + demonstrations: trajectories.Config = MISSING + airl: AIRLConfig = AIRLConfig() + total_timesteps: int = int(1e6) + checkpoint_interval: int = 0 + + cs = ConfigStore.instance() cs.store(name="airl", node=AIRLConfig) -policy.register_configs("expert_trajs/expert_policy") -rl_algorithm.register_configs("gen_algo") -trajectories.register_configs("expert_trajs") 
-gym_env.register_configs("environment") -reward_network.register_configs("reward_net") +cs.store(name="airl_run", node=AIRLRunConfig) +trajectories.register_configs("demonstrations") +gym_env.register_configs("venv") +policy.register_configs("demonstrations/expert_policy", dict(environment="${venv}")) # Make sure the expert generating the demonstrations uses the same env as the main env +rl_algorithm.register_configs("airl/gen_algo", dict(environment="${venv}", policy=policy_conf.ActorCriticPolicy(environment="${venv}"))) # The generation algo and its policy should use the main env by default +reward_network.register_configs("airl/reward_net", dict(environment="${venv}")) # The reward network should be tailored to the default environment by default @hydra.main( version_base=None, config_path="config", - config_name="airl", + config_name="airl_run", ) -def run_airl(cfg: AIRLConfig) -> None: - - expert_trajs = call(cfg.expert_trajs) - print(len(expert_trajs)) - - venv = call(cfg.environment) - - reward_net = reward_network.make_reward_net(cfg.reward_net) - - gen_algo = call(cfg.gen_algo) - - trainer = AIRL( - venv=venv, - demonstrations=expert_trajs, - gen_algo=gen_algo, - reward_net=reward_net, - demo_batch_size=cfg.demo_batch_size, - n_disc_updates_per_round=cfg.n_disc_updates_per_round, - disc_opt_cls=call(cfg.disc_opt_cls), - gen_train_timesteps=cfg.gen_train_timesteps, - gen_replay_buffer_capacity=cfg.gen_replay_buffer_capacity, - init_tensorboard=cfg.init_tensorboard, - init_tensorboard_graph=cfg.init_tensorboard_graph, - debug_use_ground_truth=cfg.debug_use_ground_truth, - allow_variable_horizon=cfg.allow_variable_horizon, - ) +def run_airl(cfg: AIRLRunConfig) -> None: + + trainer = call(cfg.airl) def callback(round_num: int, /) -> None: if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: @@ -98,7 +84,7 @@ def callback(round_num: int, /) -> None: return { # "imit_stats": imit_stats, - "expert_stats": rollout.rollout_stats(expert_trajs), + "expert_stats": rollout.rollout_stats(cfg.airl.demonstrations), } diff --git a/src/imitation_cli/config/environment/cartpole.yaml b/src/imitation_cli/config/environment/cartpole.yaml deleted file mode 100644 index 3a2db8f84..000000000 --- a/src/imitation_cli/config/environment/cartpole.yaml +++ /dev/null @@ -1,5 +0,0 @@ -defaults: - - gym_env - -env_name: CartPole-v0 -max_episode_steps: 500 diff --git a/src/imitation_cli/config/environment/myenv.yaml b/src/imitation_cli/config/environment/myenv.yaml deleted file mode 100644 index 3d92be40c..000000000 --- a/src/imitation_cli/config/environment/myenv.yaml +++ /dev/null @@ -1,5 +0,0 @@ -defaults: - - gym_env - -env_name: my_own_env -max_episode_steps: 42 diff --git a/src/imitation_cli/config/environment/retro.yaml b/src/imitation_cli/config/environment/retro.yaml deleted file mode 100644 index 000c3ba46..000000000 --- a/src/imitation_cli/config/environment/retro.yaml +++ /dev/null @@ -1,5 +0,0 @@ -defaults: - - gym_env - -env_name: Retro-v0 -max_episode_steps: 4500 diff --git a/src/imitation_cli/config/environment/seals.yaml b/src/imitation_cli/config/environment/seals.yaml deleted file mode 100644 index e30789de6..000000000 --- a/src/imitation_cli/config/environment/seals.yaml +++ /dev/null @@ -1,5 +0,0 @@ -defaults: - - gym_env - -env_name: Seals-v0 -max_episode_steps: 1000 diff --git a/src/imitation_cli/utils/activation_function.py b/src/imitation_cli/utils/activation_function_class.py similarity index 72% rename from src/imitation_cli/utils/activation_function.py rename to 
src/imitation_cli/utils/activation_function_class.py index 8c164a805..22f5f8a8e 100644 --- a/src/imitation_cli/utils/activation_function.py +++ b/src/imitation_cli/utils/activation_function_class.py @@ -1,8 +1,6 @@ import dataclasses from hydra.core.config_store import ConfigStore -from hydra.utils import call -from omegaconf import MISSING @dataclasses.dataclass @@ -14,7 +12,7 @@ class Config: @dataclasses.dataclass class TanH(Config): - _target_: str = "imitation_cli.utils.activation_function.TanH.make" + _target_: str = "imitation_cli.utils.activation_function_class.TanH.make" @staticmethod def make(): @@ -25,7 +23,7 @@ def make(): @dataclasses.dataclass class ReLU(Config): - _target_: str = "imitation_cli.utils.activation_function.ReLU.make" + _target_: str = "imitation_cli.utils.activation_function_class.ReLU.make" @staticmethod def make(): @@ -36,7 +34,7 @@ def make(): @dataclasses.dataclass class LeakyReLU(Config): - _target_: str = "imitation_cli.utils.activation_function.LeakyReLU.make" + _target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make" @staticmethod def make(): @@ -45,10 +43,6 @@ def make(): return torch.nn.LeakyReLU -def make_activation_function(cfg: Config): - return call(cfg) - - def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="tanh", node=TanH) diff --git a/src/imitation_cli/utils/feature_extractor.py b/src/imitation_cli/utils/feature_extractor_class.py similarity index 79% rename from src/imitation_cli/utils/feature_extractor.py rename to src/imitation_cli/utils/feature_extractor_class.py index ddbf95c55..97d55b166 100644 --- a/src/imitation_cli/utils/feature_extractor.py +++ b/src/imitation_cli/utils/feature_extractor_class.py @@ -1,7 +1,6 @@ import dataclasses from hydra.core.config_store import ConfigStore -from hydra.utils import call from omegaconf import MISSING @@ -12,7 +11,7 @@ class Config: @dataclasses.dataclass class FlattenExtractorConfig(Config): - _target_: str = "imitation_cli.utils.feature_extractor.FlattenExtractorConfig.make" + _target_: str = "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" @staticmethod def make(): @@ -23,7 +22,7 @@ def make(): @dataclasses.dataclass class NatureCNNConfig(Config): - _target_: str = "imitation_cli.utils.feature_extractor.NatureCNNConfig.make" + _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make" @staticmethod def make(): diff --git a/src/imitation_cli/utils/optimizer.py b/src/imitation_cli/utils/optimizer_class.py similarity index 79% rename from src/imitation_cli/utils/optimizer.py rename to src/imitation_cli/utils/optimizer_class.py index f368b8bda..986afa419 100644 --- a/src/imitation_cli/utils/optimizer.py +++ b/src/imitation_cli/utils/optimizer_class.py @@ -1,7 +1,6 @@ import dataclasses from hydra.core.config_store import ConfigStore -from hydra.utils import call from omegaconf import MISSING @@ -12,7 +11,7 @@ class Config: @dataclasses.dataclass class Adam(Config): - _target_: str = "imitation_cli.utils.optimizer.Adam.make" + _target_: str = "imitation_cli.utils.optimizer_class.Adam.make" @staticmethod def make(): @@ -23,7 +22,7 @@ def make(): @dataclasses.dataclass class SGD(Config): - _target_: str = "imitation_cli.utils.optimizer.SGD.make" + _target_: str = "imitation_cli.utils.optimizer_class.SGD.make" @staticmethod def make(): diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 2ce9ab706..13b66030e 100644 --- a/src/imitation_cli/utils/policy.py +++ 
b/src/imitation_cli/utils/policy.py @@ -1,23 +1,25 @@ import dataclasses import pathlib -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Mapping -import numpy as np import stable_baselines3 as sb3 from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING from stable_baselines3.common.torch_layers import FlattenExtractor -from imitation_cli.utils import activation_function -from imitation_cli.utils import environment as gym_env -from imitation_cli.utils import feature_extractor, optimizer, schedule +from imitation_cli.utils import \ + activation_function_class as activation_function_class_cfg, \ + environment as environment_cfg,\ + feature_extractor_class as feature_extractor_class_cfg,\ + optimizer_class as optimizer_class_cfg, \ + schedule @dataclasses.dataclass class Config: _target_: str = MISSING - environment: gym_env.Config = "${environment}" + environment: environment_cfg.Config = MISSING @dataclasses.dataclass @@ -25,7 +27,7 @@ class Random(Config): _target_: str = "imitation_cli.utils.policy.Random.make" @staticmethod - def make(environment: gym_env.Config): + def make(environment: environment_cfg.Config): from imitation.policies import base return base.RandomPolicy(environment.observation_space, environment.action_space) @@ -35,7 +37,7 @@ class ZeroPolicy(Config): _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" @staticmethod - def make(environment: gym_env.Config): + def make(environment: environment_cfg.Config): from imitation.policies import base return base.ZeroPolicy(environment.observation_space, environment.action_space) @@ -46,27 +48,27 @@ class ActorCriticPolicy(Config): _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) # TODO: make sure this is copied from the rl_algorithm instead net_arch: Optional[Dict[str, List[int]]] = None - activation_fn: activation_function.Config = activation_function.TanH() + activation_fn: activation_function_class_cfg.Config = activation_function_class_cfg.TanH() ortho_init: bool = True use_sde: bool = False log_std_init: float = 0.0 full_std: bool = True use_expln: bool = False squash_output: bool = False - features_extractor_class: feature_extractor.Config = ( - feature_extractor.FlattenExtractorConfig() + features_extractor_class: feature_extractor_class_cfg.Config = ( + feature_extractor_class_cfg.FlattenExtractorConfig() ) features_extractor_kwargs: Optional[Dict[str, Any]] = None share_features_extractor: bool = True normalize_images: bool = True - optimizer_class: optimizer.Config = optimizer.Adam() + optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam() optimizer_kwargs: Optional[Dict[str, Any]] = None @staticmethod def make_args( - activation_fn: activation_function.Config, - features_extractor_class: feature_extractor.Config, - optimizer_class: optimizer.Config, + activation_fn: activation_function_class_cfg.Config, + features_extractor_class: feature_extractor_class_cfg.Config, + optimizer_class: optimizer_class_cfg.Config, **kwargs, ): del kwargs["_target_"] @@ -82,7 +84,7 @@ def make_args( @staticmethod def make( - environment: gym_env.Config, + environment: environment_cfg.Config, **kwargs, ): return sb3.common.policies.ActorCriticPolicy( @@ -115,7 +117,7 @@ class PolicyOnDisk(Loaded): @staticmethod def make( - environment: gym_env.Config, + environment: environment_cfg.Config, path: pathlib.Path, type: str, ): @@ -134,7 +136,7 @@ 
class PolicyFromHuggingface(Loaded): @staticmethod def make( type: str, - environment: gym_env.Config, + environment: environment_cfg.Config, organization: str, ): import huggingface_sb3 as hfsb3 @@ -152,10 +154,10 @@ def make( return model.policy -def register_configs(group: str): +def register_configs(group: str, defaults: Mapping[str, Any] = {}): cs = ConfigStore.instance() - cs.store(group=group, name="random", node=Random) - cs.store(group=group, name="zero", node=ZeroPolicy) - cs.store(group=group, name="on_disk", node=PolicyOnDisk) - cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface) - cs.store(group=group, name="actor_critic", node=ActorCriticPolicy) + cs.store(group=group, name="random", node=Random(**defaults)) + cs.store(group=group, name="zero", node=ZeroPolicy(**defaults)) + cs.store(group=group, name="on_disk", node=PolicyOnDisk(**defaults)) + cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface(**defaults)) + cs.store(group=group, name="actor_critic", node=ActorCriticPolicy(**defaults)) diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index a1f7bcaf2..ec4df391b 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -1,12 +1,11 @@ import dataclasses -from typing import Optional +from typing import Optional, Any, Mapping -import numpy as np from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -import imitation_cli.utils.environment as gym_env +import imitation_cli.utils.environment as environment_cg from imitation.rewards import reward_nets from imitation.util import networks @@ -14,7 +13,7 @@ @dataclasses.dataclass class Config: _target_: str = MISSING - environment: gym_env.Config = "${environment}" + environment: environment_cg.Config = MISSING @dataclasses.dataclass @@ -27,7 +26,7 @@ class BasicRewardNet(Config): normalize_input_layer: bool = True @staticmethod - def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): + def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): reward_net = reward_nets.BasicRewardNet( environment.observation_space, environment.action_space, @@ -47,7 +46,7 @@ class BasicShapedRewardNet(BasicRewardNet): discount_factor: float = 0.99 @staticmethod - def make(environment: gym_env.Config, normalize_input_layer: bool, **kwargs): + def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): reward_net = reward_nets.BasicShapedRewardNet( environment.observation_space, environment.action_space, @@ -70,7 +69,7 @@ class RewardEnsemble(Config): @staticmethod def make( - environment: gym_env.Config, + environment: environment_cg.Config, ensemble_member_config: BasicRewardNet, add_std_alpha: Optional[float], ): @@ -86,12 +85,8 @@ def make( return reward_net -def make_reward_net(config: Config): - return call(config) - - -def register_configs(group: str): +def register_configs(group: str, defaults: Mapping[str, Any] = {}): cs = ConfigStore.instance() - cs.store(group=group, name="basic", node=BasicRewardNet) - cs.store(group=group, name="shaped", node=BasicShapedRewardNet) - cs.store(group=group, name="ensemble", node=RewardEnsemble) + cs.store(group=group, name="basic", node=BasicRewardNet(**defaults)) + cs.store(group=group, name="shaped", node=BasicShapedRewardNet(**defaults)) + cs.store(group=group, name="ensemble", node=RewardEnsemble(**defaults)) diff --git 
a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index 7bbb199fc..bfc1c39cc 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -1,13 +1,13 @@ import dataclasses import pathlib -from typing import Optional +from typing import Optional, Mapping, Any import numpy as np import stable_baselines3 as sb3 from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import environment as gym_env +from imitation_cli.utils import environment as environment_cfg from imitation_cli.utils import policy as policy_conf from imitation_cli.utils import schedule @@ -15,15 +15,15 @@ @dataclasses.dataclass class Config: _target_: str = MISSING - environment: str = "${environment}" + environment: environment_cfg.Config = MISSING @dataclasses.dataclass class PPO(Config): _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" + # We disable recursive instantiation, so we can just make the arguments of the policy but not the policy itself _recursive_: bool = False - policy: policy_conf.ActorCriticPolicy = policy_conf.ActorCriticPolicy - environment: gym_env.Config = "${environment}" + policy: policy_conf.ActorCriticPolicy = MISSING learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) n_steps: int = 2048 batch_size: int = 64 @@ -46,7 +46,7 @@ class PPO(Config): @staticmethod def make( - environment: gym_env.Config, + environment: environment_cfg.Config, policy: policy_conf.ActorCriticPolicy, learning_rate: schedule.Config, clip_range: schedule.Config, @@ -71,15 +71,15 @@ class PPOOnDisk(PPO): path: pathlib.Path = MISSING @staticmethod - def make(path: pathlib.Path, environment: gym_env.Config, rng: np.random.Generator): + def make(path: pathlib.Path, environment: environment_cfg.Config, rng: np.random.Generator): from imitation.policies import serialize return serialize.load_stable_baselines_model(sb3.PPO, path, environment) -def register_configs(group: str = "rl_algorithm"): +def register_configs(group: str = "rl_algorithm", defaults: Mapping[str, Any] = {}): from hydra.core.config_store import ConfigStore cs = ConfigStore.instance() - cs.store(name="ppo", group=group, node=PPO) - cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk) + cs.store(name="ppo", group=group, node=PPO(**defaults)) + cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk(**defaults)) diff --git a/src/imitation_cli/utils/schedule.py b/src/imitation_cli/utils/schedule.py index b9e3e17e3..0ab279f91 100644 --- a/src/imitation_cli/utils/schedule.py +++ b/src/imitation_cli/utils/schedule.py @@ -1,7 +1,6 @@ import dataclasses from hydra.core.config_store import ConfigStore -from hydra.utils import call from omegaconf import MISSING diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index 5886db6b2..389bfc96f 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -28,9 +28,9 @@ def make(path: pathlib.Path): @dataclasses.dataclass class Generated(Config): _target_: str = "imitation_cli.utils.trajectories.Generated.make" - _recursive_: bool = False + _recursive_: bool = False # We disable the recursive flag, so we can extract the environment from the expert policy total_timesteps: int = int(10) # TODO: this is low for debugging - expert_policy: policy.Config = policy.Config(environment="${environment}") + expert_policy: policy.Config = MISSING rng: randomness.Config = randomness.Config() @staticmethod From 
89214d90fac79e831abcafedba6420858b2511b9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Apr 2023 11:35:35 +0200 Subject: [PATCH 04/25] Make inline imports to speed up shell completion. --- src/imitation_cli/airl.py | 2 +- src/imitation_cli/utils/policy.py | 5 +++-- src/imitation_cli/utils/reward_network.py | 11 +++++++++-- src/imitation_cli/utils/rl_algorithm.py | 9 +++++---- 4 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index be5d479ec..48a509482 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -7,7 +7,6 @@ from hydra.utils import call from omegaconf import MISSING -from imitation.data import rollout from imitation_cli.utils import environment as gym_env, optimizer_class, policy, reward_network, rl_algorithm, trajectories from imitation_cli.utils import policy as policy_conf @@ -64,6 +63,7 @@ class AIRLRunConfig: config_name="airl_run", ) def run_airl(cfg: AIRLRunConfig) -> None: + from imitation.data import rollout trainer = call(cfg.airl) diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 13b66030e..1fdd2b7e4 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -2,11 +2,10 @@ import pathlib from typing import Any, Dict, List, Optional, Mapping -import stable_baselines3 as sb3 + from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -from stable_baselines3.common.torch_layers import FlattenExtractor from imitation_cli.utils import \ activation_function_class as activation_function_class_cfg, \ @@ -87,6 +86,8 @@ def make( environment: environment_cfg.Config, **kwargs, ): + import stable_baselines3 as sb3 + return sb3.common.policies.ActorCriticPolicy( observation_space=environment.observation_space, action_space=environment.action_space, diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index ec4df391b..ef296b713 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -6,8 +6,7 @@ from omegaconf import MISSING import imitation_cli.utils.environment as environment_cg -from imitation.rewards import reward_nets -from imitation.util import networks + @dataclasses.dataclass @@ -27,6 +26,9 @@ class BasicRewardNet(Config): @staticmethod def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): + from imitation.rewards import reward_nets + from imitation.util import networks + reward_net = reward_nets.BasicRewardNet( environment.observation_space, environment.action_space, @@ -47,6 +49,9 @@ class BasicShapedRewardNet(BasicRewardNet): @staticmethod def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): + from imitation.rewards import reward_nets + from imitation.util import networks + reward_net = reward_nets.BasicShapedRewardNet( environment.observation_space, environment.action_space, @@ -73,6 +78,8 @@ def make( ensemble_member_config: BasicRewardNet, add_std_alpha: Optional[float], ): + from imitation.rewards import reward_nets + members = [call(ensemble_member_config)] reward_net = reward_nets.RewardEnsemble( environment.observation_space, environment.action_space, members diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index bfc1c39cc..0b3709d24 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -2,8 +2,6 
@@ import pathlib from typing import Optional, Mapping, Any -import numpy as np -import stable_baselines3 as sb3 from hydra.utils import call from omegaconf import MISSING @@ -52,6 +50,8 @@ def make( clip_range: schedule.Config, **kwargs, ): + import stable_baselines3 as sb3 + policy_kwargs = policy_conf.ActorCriticPolicy.make_args(**policy) del policy_kwargs["use_sde"] del policy_kwargs["lr_schedule"] @@ -71,9 +71,10 @@ class PPOOnDisk(PPO): path: pathlib.Path = MISSING @staticmethod - def make(path: pathlib.Path, environment: environment_cfg.Config, rng: np.random.Generator): + def make(path: pathlib.Path, environment: environment_cfg.Config): from imitation.policies import serialize - + import stable_baselines3 as sb3 + return serialize.load_stable_baselines_model(sb3.PPO, path, environment) From edd12c48257db9548d75b7a7e89ed77bb6003d0b Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Apr 2023 15:27:56 +0200 Subject: [PATCH 05/25] Add type annotations, fix typing issues, add comments. --- src/imitation_cli/__init__.py | 1 + src/imitation_cli/airl.py | 60 ++++++++++++++----- .../utils/activation_function_class.py | 6 +- src/imitation_cli/utils/environment.py | 12 +++- .../utils/feature_extractor_class.py | 4 +- src/imitation_cli/utils/optimizer_class.py | 4 +- src/imitation_cli/utils/policy.py | 25 ++++---- src/imitation_cli/utils/randomness.py | 9 ++- src/imitation_cli/utils/reward_network.py | 23 +++++-- src/imitation_cli/utils/rl_algorithm.py | 25 +++++--- src/imitation_cli/utils/trajectories.py | 18 ++++-- 11 files changed, 132 insertions(+), 55 deletions(-) diff --git a/src/imitation_cli/__init__.py b/src/imitation_cli/__init__.py index e69de29bb..8e6132367 100644 --- a/src/imitation_cli/__init__.py +++ b/src/imitation_cli/__init__.py @@ -0,0 +1 @@ +"""Hydra configurations and scripts that form a CLI for imitation.""" diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 48a509482..a659d9f12 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -1,26 +1,32 @@ +"""Config and run configuration for AIRL.""" import dataclasses import logging -from typing import Optional +from typing import Any, Dict, Optional, Sequence, cast import hydra from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import environment as gym_env, optimizer_class, policy, reward_network, rl_algorithm, trajectories +from imitation_cli.utils import environment as gym_env +from imitation_cli.utils import optimizer_class +from imitation_cli.utils import policy from imitation_cli.utils import policy as policy_conf +from imitation_cli.utils import reward_network, rl_algorithm, trajectories @dataclasses.dataclass class AIRLConfig: + """Config for AIRL.""" + _target_: str = "imitation.algorithms.adversarial.airl.AIRL" - venv: gym_env.Config = "${venv}" - demonstrations: trajectories.Config = "${demonstrations}" + venv: gym_env.Config = MISSING + demonstrations: trajectories.Config = MISSING gen_algo: rl_algorithm.Config = MISSING reward_net: reward_network.Config = MISSING demo_batch_size: int = 64 n_disc_updates_per_round: int = 2 - disc_opt_cls: optimizer_class.Config = optimizer_class.Adam + disc_opt_cls: optimizer_class.Config = optimizer_class.Adam() gen_train_timesteps: Optional[int] = None gen_replay_buffer_capacity: Optional[int] = None init_tensorboard: bool = False @@ -31,13 +37,15 @@ class AIRLConfig: @dataclasses.dataclass class AIRLRunConfig: + """Config for running AIRL.""" + 
defaults: list = dataclasses.field( default_factory=lambda: [ {"venv": "gym_env"}, {"airl/reward_net": "shaped"}, {"airl/gen_algo": "ppo"}, "_self_", - ] + ], ) seed: int = 0 venv: gym_env.Config = MISSING @@ -48,13 +56,34 @@ class AIRLRunConfig: cs = ConfigStore.instance() -cs.store(name="airl", node=AIRLConfig) -cs.store(name="airl_run", node=AIRLRunConfig) +cs.store( + name="airl_run", + node=AIRLRunConfig( + airl=AIRLConfig( + venv="${venv}", # type: ignore + demonstrations="${demonstrations}", # type: ignore + ), + ), +) trajectories.register_configs("demonstrations") gym_env.register_configs("venv") -policy.register_configs("demonstrations/expert_policy", dict(environment="${venv}")) # Make sure the expert generating the demonstrations uses the same env as the main env -rl_algorithm.register_configs("airl/gen_algo", dict(environment="${venv}", policy=policy_conf.ActorCriticPolicy(environment="${venv}"))) # The generation algo and its policy should use the main env by default -reward_network.register_configs("airl/reward_net", dict(environment="${venv}")) # The reward network should be tailored to the default environment by default + +# Make sure the expert generating the demonstrations uses the same env as the main env +policy.register_configs( + "demonstrations/expert_policy", + dict(environment="${venv}"), +) +rl_algorithm.register_configs( + "airl/gen_algo", + dict( + environment="${venv}", + policy=policy_conf.ActorCriticPolicy(environment="${venv}"), # type: ignore + ), +) # The generation algo and its policy should use the main env by default +reward_network.register_configs( + "airl/reward_net", + dict(environment="${venv}"), +) # The reward network should be tailored to the default environment by default @hydra.main( @@ -62,8 +91,9 @@ class AIRLRunConfig: config_path="config", config_name="airl_run", ) -def run_airl(cfg: AIRLRunConfig) -> None: +def run_airl(cfg: AIRLRunConfig) -> Dict[str, Any]: from imitation.data import rollout + from imitation.data.types import TrajectoryWithRew trainer = call(cfg.airl) @@ -80,11 +110,13 @@ def callback(round_num: int, /) -> None: # Save final artifacts. if cfg.checkpoint_interval >= 0: - logging.log(logging.INFO, f"Saving final checkpoint. TODO implement this") + logging.log(logging.INFO, "Saving final checkpoint. 
TODO implement this") return { # "imit_stats": imit_stats, - "expert_stats": rollout.rollout_stats(cfg.airl.demonstrations), + "expert_stats": rollout.rollout_stats( + cast(Sequence[TrajectoryWithRew], cfg.airl.demonstrations), + ), } diff --git a/src/imitation_cli/utils/activation_function_class.py b/src/imitation_cli/utils/activation_function_class.py index 22f5f8a8e..1ba225061 100644 --- a/src/imitation_cli/utils/activation_function_class.py +++ b/src/imitation_cli/utils/activation_function_class.py @@ -15,7 +15,7 @@ class TanH(Config): _target_: str = "imitation_cli.utils.activation_function_class.TanH.make" @staticmethod - def make(): + def make() -> type: import torch return torch.nn.Tanh @@ -26,7 +26,7 @@ class ReLU(Config): _target_: str = "imitation_cli.utils.activation_function_class.ReLU.make" @staticmethod - def make(): + def make() -> type: import torch return torch.nn.ReLU @@ -37,7 +37,7 @@ class LeakyReLU(Config): _target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make" @staticmethod - def make(): + def make() -> type: import torch return torch.nn.LeakyReLU diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py index b4ac26401..6ed0c94c1 100644 --- a/src/imitation_cli/utils/environment.py +++ b/src/imitation_cli/utils/environment.py @@ -1,4 +1,10 @@ +from __future__ import annotations import dataclasses +import typing +from typing import Optional + +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv from hydra.core.config_store import ConfigStore from hydra.utils import call @@ -17,16 +23,16 @@ class Config: env_make_kwargs: dict = dataclasses.field( default_factory=dict ) # The kwargs passed to `spec.make`. - rng: randomness.Config = randomness.Config + rng: randomness.Config = randomness.Config() @staticmethod - def make(log_dir=None, **kwargs): + def make(log_dir: Optional[str]=None, **kwargs) -> VecEnv: from imitation.util import util return util.make_vec_env(log_dir=log_dir, **kwargs) -def make_rollout_venv(environment_config: Config): +def make_rollout_venv(environment_config: Config) -> VecEnv: from imitation.data import wrappers return call( diff --git a/src/imitation_cli/utils/feature_extractor_class.py b/src/imitation_cli/utils/feature_extractor_class.py index 97d55b166..33a5348d2 100644 --- a/src/imitation_cli/utils/feature_extractor_class.py +++ b/src/imitation_cli/utils/feature_extractor_class.py @@ -14,7 +14,7 @@ class FlattenExtractorConfig(Config): _target_: str = "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" @staticmethod - def make(): + def make() -> type: import stable_baselines3 return stable_baselines3.common.torch_layers.FlattenExtractor @@ -25,7 +25,7 @@ class NatureCNNConfig(Config): _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make" @staticmethod - def make(): + def make() -> type: import stable_baselines3 return stable_baselines3.common.torch_layers.NatureCNN diff --git a/src/imitation_cli/utils/optimizer_class.py b/src/imitation_cli/utils/optimizer_class.py index 986afa419..9e8ae5b9c 100644 --- a/src/imitation_cli/utils/optimizer_class.py +++ b/src/imitation_cli/utils/optimizer_class.py @@ -14,7 +14,7 @@ class Adam(Config): _target_: str = "imitation_cli.utils.optimizer_class.Adam.make" @staticmethod - def make(): + def make() -> type: import torch return torch.optim.Adam @@ -25,7 +25,7 @@ class SGD(Config): _target_: str = "imitation_cli.utils.optimizer_class.SGD.make" @staticmethod - def make(): + def 
make() -> type: import torch return torch.optim.SGD diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 1fdd2b7e4..05b1685a3 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -1,7 +1,12 @@ +from __future__ import annotations import dataclasses import pathlib +import typing from typing import Any, Dict, List, Optional, Mapping +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + from stable_baselines3.common.policies import BasePolicy from hydra.core.config_store import ConfigStore from hydra.utils import call @@ -26,7 +31,7 @@ class Random(Config): _target_: str = "imitation_cli.utils.policy.Random.make" @staticmethod - def make(environment: environment_cfg.Config): + def make(environment: VecEnv) -> BasePolicy: from imitation.policies import base return base.RandomPolicy(environment.observation_space, environment.action_space) @@ -36,7 +41,7 @@ class ZeroPolicy(Config): _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" @staticmethod - def make(environment: environment_cfg.Config): + def make(environment: VecEnv) -> BasePolicy: from imitation.policies import base return base.ZeroPolicy(environment.observation_space, environment.action_space) @@ -83,9 +88,9 @@ def make_args( @staticmethod def make( - environment: environment_cfg.Config, + environment: VecEnv, **kwargs, - ): + ) -> BasePolicy: import stable_baselines3 as sb3 return sb3.common.policies.ActorCriticPolicy( @@ -118,14 +123,14 @@ class PolicyOnDisk(Loaded): @staticmethod def make( - environment: environment_cfg.Config, - path: pathlib.Path, + environment: VecEnv, type: str, - ): + path: pathlib.Path, + ) -> BasePolicy: from imitation.policies import serialize return serialize.load_stable_baselines_model( - Loaded.type_to_class(type), path, environment + Loaded.type_to_class(type), str(path), environment ).policy @@ -136,10 +141,10 @@ class PolicyFromHuggingface(Loaded): @staticmethod def make( + environment: VecEnv, type: str, - environment: environment_cfg.Config, organization: str, - ): + ) -> BasePolicy: import huggingface_sb3 as hfsb3 from imitation.policies import serialize diff --git a/src/imitation_cli/utils/randomness.py b/src/imitation_cli/utils/randomness.py index 0461621ef..1b83104a1 100644 --- a/src/imitation_cli/utils/randomness.py +++ b/src/imitation_cli/utils/randomness.py @@ -1,13 +1,18 @@ +from __future__ import annotations import dataclasses +import typing + +if typing.TYPE_CHECKING: + import numpy as np @dataclasses.dataclass class Config: _target_: str = "imitation_cli.utils.randomness.Config.make" - seed: int = "${seed}" + seed: int = "${seed}" # type: ignore @staticmethod - def make(seed: int): + def make(seed: int) -> np.random.Generator: import numpy as np import torch diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index ef296b713..f3d00f749 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -1,6 +1,12 @@ +from __future__ import annotations import dataclasses +import typing from typing import Optional, Any, Mapping +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + from imitation.rewards.reward_nets import RewardNet + from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING @@ -8,7 +14,6 @@ import imitation_cli.utils.environment as environment_cg - @dataclasses.dataclass class Config: _target_: str = 
MISSING @@ -25,7 +30,11 @@ class BasicRewardNet(Config): normalize_input_layer: bool = True @staticmethod - def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): + def make( + environment: VecEnv, + normalize_input_layer: bool, + **kwargs + ) -> RewardNet: from imitation.rewards import reward_nets from imitation.util import networks @@ -48,7 +57,11 @@ class BasicShapedRewardNet(BasicRewardNet): discount_factor: float = 0.99 @staticmethod - def make(environment: environment_cg.Config, normalize_input_layer: bool, **kwargs): + def make( + environment: VecEnv, + normalize_input_layer: bool, + **kwargs + ) -> RewardNet: from imitation.rewards import reward_nets from imitation.util import networks @@ -74,10 +87,10 @@ class RewardEnsemble(Config): @staticmethod def make( - environment: environment_cg.Config, + environment: VecEnv, ensemble_member_config: BasicRewardNet, add_std_alpha: Optional[float], - ): + ) -> RewardNet: from imitation.rewards import reward_nets members = [call(ensemble_member_config)] diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index 0b3709d24..baa3eb552 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -1,12 +1,19 @@ +from __future__ import annotations import dataclasses import pathlib +import typing from typing import Optional, Mapping, Any +if typing.TYPE_CHECKING: + from stable_baselines3.common.vec_env import VecEnv + from stable_baselines3.common.policies import BasePolicy + from stable_baselines3 import PPO + from hydra.utils import call from omegaconf import MISSING from imitation_cli.utils import environment as environment_cfg -from imitation_cli.utils import policy as policy_conf +from imitation_cli.utils import policy as policy_cfg from imitation_cli.utils import schedule @@ -21,7 +28,7 @@ class PPO(Config): _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" # We disable recursive instantiation, so we can just make the arguments of the policy but not the policy itself _recursive_: bool = False - policy: policy_conf.ActorCriticPolicy = MISSING + policy: policy_cfg.ActorCriticPolicy = MISSING learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) n_steps: int = 2048 batch_size: int = 64 @@ -39,20 +46,20 @@ class PPO(Config): target_kl: Optional[float] = None tensorboard_log: Optional[str] = None verbose: int = 0 - seed: int = "${seed}" + seed: int = "${seed}" # type: ignore device: str = "auto" @staticmethod def make( environment: environment_cfg.Config, - policy: policy_conf.ActorCriticPolicy, + policy: policy_cfg.ActorCriticPolicy, learning_rate: schedule.Config, clip_range: schedule.Config, **kwargs, - ): + ) -> PPO: import stable_baselines3 as sb3 - policy_kwargs = policy_conf.ActorCriticPolicy.make_args(**policy) + policy_kwargs = policy_cfg.ActorCriticPolicy.make_args(**typing.cast(dict, policy)) del policy_kwargs["use_sde"] del policy_kwargs["lr_schedule"] return sb3.PPO( @@ -71,11 +78,11 @@ class PPOOnDisk(PPO): path: pathlib.Path = MISSING @staticmethod - def make(path: pathlib.Path, environment: environment_cfg.Config): + def make(environment: VecEnv, path: pathlib.Path) -> PPO: from imitation.policies import serialize import stable_baselines3 as sb3 - - return serialize.load_stable_baselines_model(sb3.PPO, path, environment) + + return serialize.load_stable_baselines_model(sb3.PPO, str(path), environment) def register_configs(group: str = "rl_algorithm", defaults: Mapping[str, Any] = {}): diff --git 
a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index 389bfc96f..7f9e7babe 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -1,5 +1,13 @@ +from __future__ import annotations import dataclasses import pathlib +import typing + +if typing.TYPE_CHECKING: + from stable_baselines3.common.policies import BasePolicy + from imitation.data.types import Trajectory + from typing import Sequence + import numpy as np from hydra.core.config_store import ConfigStore from hydra.utils import call @@ -19,10 +27,10 @@ class OnDisk(Config): path: pathlib.Path = MISSING @staticmethod - def make(path: pathlib.Path): + def make(path: pathlib.Path) -> Sequence[Trajectory]: from imitation.data import serialize - serialize.load(path) + return serialize.load(path) @dataclasses.dataclass @@ -36,9 +44,9 @@ class Generated(Config): @staticmethod def make( total_timesteps: int, - expert_policy: policy.Config, - rng: randomness.Config, - ): + expert_policy: BasePolicy, + rng: np.random.Generator, + ) -> Sequence[Trajectory]: from imitation.data import rollout expert = call(expert_policy) From c014dfb3834109b03ea9609ec319a8a02c9ddc95 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 19 Apr 2023 15:28:05 +0200 Subject: [PATCH 06/25] Remove env_test.py --- src/imitation_cli/env_test.py | 73 ----------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 src/imitation_cli/env_test.py diff --git a/src/imitation_cli/env_test.py b/src/imitation_cli/env_test.py deleted file mode 100644 index eede51828..000000000 --- a/src/imitation_cli/env_test.py +++ /dev/null @@ -1,73 +0,0 @@ -import dataclasses - -import hydra -from hydra.core.config_store import ConfigStore -from omegaconf import MISSING - - -@dataclasses.dataclass -class EnvironmentConfig: - gym_id: str = MISSING # The environment to train on - n_envs: int = 8 # number of environments in VecEnv - parallel: bool = True # Use SubprocVecEnv rather than DummyVecEnv - max_episode_steps: int = MISSING # Set to positive int to limit episode horizons - env_make_kwargs: dict = dataclasses.field( - default_factory=dict - ) # The kwargs passed to `spec.make`. 
- - -@dataclasses.dataclass -class RetroEnvironmentConfig(EnvironmentConfig): - gym_id: str = "Retro-v0" - max_episode_steps: int = 4500 - - -@dataclasses.dataclass -class SealEnvironmentConfig(EnvironmentConfig): - gym_id: str = "Seal-v0" - max_episode_steps: int = 1000 - aaa: int = 555 - - -@dataclasses.dataclass -class PolicyConfig: - env: EnvironmentConfig - type: str = MISSING - - -@dataclasses.dataclass -class PPOPolicyConfig(PolicyConfig): - type: str = "ppo" - - -@dataclasses.dataclass -class RandomPolicyConfig(PolicyConfig): - type: str = "random" - - -@dataclasses.dataclass -class Config: - env: EnvironmentConfig - policy: PolicyConfig - - -cs = ConfigStore.instance() -cs.store(name="config", node=Config) - -cs.store(group="env", name="retro", node=RetroEnvironmentConfig) -cs.store(group="env", name="seal", node=SealEnvironmentConfig) - -cs.store(group="policy", name="ppo", node=PPOPolicyConfig(env="${env}")) -cs.store(group="policy", name="random", node=RandomPolicyConfig(env="${env}")) - -# cs.store(group="policy/env", name="retro", node=RetroEnvironmentConfig) -# cs.store(group="policy/env", name="seal", node=SealEnvironmentConfig) - - -@hydra.main(version_base=None, config_name="config") -def main(cfg: Config): - print(cfg) - - -if __name__ == "__main__": - main() From 6c2b3a1a81d930c57aca4b9f5bfff1ef0f9c0bbb Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 20 Apr 2023 19:05:53 +0200 Subject: [PATCH 07/25] Add code for policy evaluation. --- .../algorithms/adversarial/common.py | 5 ++ src/imitation_cli/airl.py | 19 ++-- src/imitation_cli/utils/policy_evaluation.py | 90 +++++++++++++++++++ 3 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 src/imitation_cli/utils/policy_evaluation.py diff --git a/src/imitation/algorithms/adversarial/common.py b/src/imitation/algorithms/adversarial/common.py index 62b459a0d..8ad7d795b 100644 --- a/src/imitation/algorithms/adversarial/common.py +++ b/src/imitation/algorithms/adversarial/common.py @@ -187,6 +187,7 @@ def __init__( if self.demo_batch_size % self.demo_minibatch_size != 0: raise ValueError("Batch size must be a multiple of minibatch size.") self._demo_data_loader = None + self._demonstrations: Optional[base.AnyTransitions] = None self._endless_expert_iterator = None super().__init__( demonstrations=demonstrations, @@ -298,12 +299,16 @@ def reward_test(self) -> reward_nets.RewardNet: """Reward used to train policy at "test" time after adversarial training.""" def set_demonstrations(self, demonstrations: base.AnyTransitions) -> None: + self._demonstrations = demonstrations self._demo_data_loader = base.make_data_loader( demonstrations, self.demo_batch_size, ) self._endless_expert_iterator = util.endless_iter(self._demo_data_loader) + def get_demonstrations(self) -> Optional[base.AnyTransitions]: + return self._demonstrations + def _next_expert_batch(self) -> Mapping: assert self._endless_expert_iterator is not None return next(self._endless_expert_iterator) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index a659d9f12..2f1ffe5c2 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -12,7 +12,12 @@ from imitation_cli.utils import optimizer_class from imitation_cli.utils import policy from imitation_cli.utils import policy as policy_conf -from imitation_cli.utils import reward_network, rl_algorithm, trajectories +from imitation_cli.utils import ( + policy_evaluation, + reward_network, + rl_algorithm, + trajectories, +) @dataclasses.dataclass @@ -44,6 +49,7 @@ class 
AIRLRunConfig: {"venv": "gym_env"}, {"airl/reward_net": "shaped"}, {"airl/gen_algo": "ppo"}, + {"evaluation": "default_evaluation"}, "_self_", ], ) @@ -53,6 +59,7 @@ class AIRLRunConfig: airl: AIRLConfig = AIRLConfig() total_timesteps: int = int(1e6) checkpoint_interval: int = 0 + evaluation: policy_evaluation.Config = MISSING cs = ConfigStore.instance() @@ -84,6 +91,7 @@ class AIRLRunConfig: "airl/reward_net", dict(environment="${venv}"), ) # The reward network should be tailored to the default environment by default +policy_evaluation.register_configs("evaluation", dict(environment="${venv}")) @hydra.main( @@ -92,10 +100,11 @@ class AIRLRunConfig: config_name="airl_run", ) def run_airl(cfg: AIRLRunConfig) -> Dict[str, Any]: + from imitation.algorithms.adversarial import airl from imitation.data import rollout from imitation.data.types import TrajectoryWithRew - trainer = call(cfg.airl) + trainer: airl.AIRL = call(cfg.airl) def callback(round_num: int, /) -> None: if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: @@ -106,16 +115,16 @@ def callback(round_num: int, /) -> None: trainer.train(cfg.total_timesteps, callback) # TODO: implement evaluation - # imit_stats = policy_evaluation.eval_policy(trainer.policy, trainer.venv_train) + imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation) # Save final artifacts. if cfg.checkpoint_interval >= 0: logging.log(logging.INFO, "Saving final checkpoint. TODO implement this") return { - # "imit_stats": imit_stats, + "imit_stats": imit_stats, "expert_stats": rollout.rollout_stats( - cast(Sequence[TrajectoryWithRew], cfg.airl.demonstrations), + cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()) ), } diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py new file mode 100644 index 000000000..1412bed2f --- /dev/null +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -0,0 +1,90 @@ +"""Code to evaluate trained policies.""" +from __future__ import annotations + +import dataclasses +import typing +from typing import Any, Mapping, Union + +from hydra.utils import call +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import randomness + +if typing.TYPE_CHECKING: + from stable_baselines3.common import base_class, policies, vec_env + + +@dataclasses.dataclass +class Config: + """Configuration for evaluating a policy.""" + + environment: environment_cfg.Config = MISSING + n_episodes_eval: int = 50 + rng: randomness.Config = randomness.Config() + + +def register_configs(group: str, defaults: Mapping[str, Any] = {}) -> None: + from hydra.core.config_store import ConfigStore + + cs = ConfigStore.instance() + cs.store( + name="default_evaluation", + group=group, + node=Config(**defaults), + ) + cs.store( + name="fast_evaluation", + group=group, + node=Config(n_episodes_eval=2, **defaults), + ) + + +def eval_policy( + rl_algo: Union[base_class.BaseAlgorithm, policies.BasePolicy], + config: Config, +) -> typing.Mapping[str, float]: + """Evaluation of imitation learned policy. + + Has the side effect of setting `rl_algo`'s environment to `venv` + if it is a `BaseAlgorithm`. + + Args: + rl_algo: Algorithm to evaluate. + config: Configuration for evaluation. + + Returns: + A dictionary with two keys. 
"imit_stats" gives the return value of + `rollout_stats()` on rollouts test-reward-wrapped environment, using the final + policy (remember that the ground-truth reward can be recovered from the + "monitor_return" key). "expert_stats" gives the return value of + `rollout_stats()` on the expert demonstrations loaded from `rollout_path`. + """ + from stable_baselines3.common import base_class + + from imitation.data import rollout + + sample_until_eval = rollout.make_min_episodes(config.n_episodes_eval) + venv = call(config.environment) + rng = call(config.rng) + + if isinstance(rl_algo, base_class.BaseAlgorithm): + # Set RL algorithm's env to venv, removing any cruft wrappers that the RL + # algorithm's environment may have accumulated. + rl_algo.set_env(venv) + # Generate trajectories with the RL algorithm's env - SB3 may apply wrappers + # under the hood to get it to work with the RL algorithm (e.g. transposing + # images, so they can be fed into CNNs). + train_env = rl_algo.get_env() + assert train_env is not None + else: + train_env = venv + + train_env = typing.cast(vec_env.VecEnv, train_env) + trajs = rollout.generate_trajectories( + rl_algo, + train_env, + sample_until=sample_until_eval, + rng=rng, + ) + return rollout.rollout_stats(trajs) From fe0bc1d4807c56f6f2e686c039ee526bb3012ea8 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sat, 22 Apr 2023 06:20:20 +0200 Subject: [PATCH 08/25] Move airl configuration to it's own file and restructure the main run file --- src/imitation_cli/airl.py | 66 ++++++++----------- .../algorithm_configurations/__init__.py | 1 + .../algorithm_configurations/airl.py | 33 ++++++++++ 3 files changed, 60 insertions(+), 40 deletions(-) create mode 100644 src/imitation_cli/algorithm_configurations/__init__.py create mode 100644 src/imitation_cli/algorithm_configurations/airl.py diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 2f1ffe5c2..d6f43fee8 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -1,15 +1,15 @@ """Config and run configuration for AIRL.""" import dataclasses import logging -from typing import Any, Dict, Optional, Sequence, cast +from typing import Any, Dict, Sequence, cast import hydra from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import environment as gym_env -from imitation_cli.utils import optimizer_class +from imitation_cli.algorithm_configurations import airl as airl_cfg +from imitation_cli.utils import environment as environment_cfg from imitation_cli.utils import policy from imitation_cli.utils import policy as policy_conf from imitation_cli.utils import ( @@ -21,27 +21,7 @@ @dataclasses.dataclass -class AIRLConfig: - """Config for AIRL.""" - - _target_: str = "imitation.algorithms.adversarial.airl.AIRL" - venv: gym_env.Config = MISSING - demonstrations: trajectories.Config = MISSING - gen_algo: rl_algorithm.Config = MISSING - reward_net: reward_network.Config = MISSING - demo_batch_size: int = 64 - n_disc_updates_per_round: int = 2 - disc_opt_cls: optimizer_class.Config = optimizer_class.Adam() - gen_train_timesteps: Optional[int] = None - gen_replay_buffer_capacity: Optional[int] = None - init_tensorboard: bool = False - init_tensorboard_graph: bool = False - debug_use_ground_truth: bool = False - allow_variable_horizon: bool = True # TODO: true just for debugging - - -@dataclasses.dataclass -class AIRLRunConfig: +class RunConfig: """Config for running AIRL.""" defaults: list = 
dataclasses.field( @@ -54,32 +34,27 @@ class AIRLRunConfig: ], ) seed: int = 0 - venv: gym_env.Config = MISSING - demonstrations: trajectories.Config = MISSING - airl: AIRLConfig = AIRLConfig() + total_timesteps: int = int(1e6) checkpoint_interval: int = 0 + + venv: environment_cfg.Config = MISSING + demonstrations: trajectories.Config = MISSING + airl: airl_cfg.Config = MISSING evaluation: policy_evaluation.Config = MISSING cs = ConfigStore.instance() -cs.store( - name="airl_run", - node=AIRLRunConfig( - airl=AIRLConfig( - venv="${venv}", # type: ignore - demonstrations="${demonstrations}", # type: ignore - ), - ), -) -trajectories.register_configs("demonstrations") -gym_env.register_configs("venv") +environment_cfg.register_configs("venv") + +trajectories.register_configs("demonstrations") # Make sure the expert generating the demonstrations uses the same env as the main env policy.register_configs( "demonstrations/expert_policy", dict(environment="${venv}"), ) + rl_algorithm.register_configs( "airl/gen_algo", dict( @@ -91,6 +66,17 @@ class AIRLRunConfig: "airl/reward_net", dict(environment="${venv}"), ) # The reward network should be tailored to the default environment by default + +cs.store( + name="airl_run", + node=RunConfig( + airl=airl_cfg.Config( + venv="${venv}", # type: ignore + demonstrations="${demonstrations}", # type: ignore + ), + ), +) + policy_evaluation.register_configs("evaluation", dict(environment="${venv}")) @@ -99,7 +85,7 @@ class AIRLRunConfig: config_path="config", config_name="airl_run", ) -def run_airl(cfg: AIRLRunConfig) -> Dict[str, Any]: +def run_airl(cfg: RunConfig) -> Dict[str, Any]: from imitation.algorithms.adversarial import airl from imitation.data import rollout from imitation.data.types import TrajectoryWithRew @@ -124,7 +110,7 @@ def callback(round_num: int, /) -> None: return { "imit_stats": imit_stats, "expert_stats": rollout.rollout_stats( - cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()) + cast(Sequence[TrajectoryWithRew], trainer.get_demonstrations()), ), } diff --git a/src/imitation_cli/algorithm_configurations/__init__.py b/src/imitation_cli/algorithm_configurations/__init__.py new file mode 100644 index 000000000..dfe338fdc --- /dev/null +++ b/src/imitation_cli/algorithm_configurations/__init__.py @@ -0,0 +1 @@ +"""Structured Hydra configuration for Imitation algorithms.""" diff --git a/src/imitation_cli/algorithm_configurations/airl.py b/src/imitation_cli/algorithm_configurations/airl.py new file mode 100644 index 000000000..16dfbdfef --- /dev/null +++ b/src/imitation_cli/algorithm_configurations/airl.py @@ -0,0 +1,33 @@ +"""Config for AIRL.""" +import dataclasses +from typing import Optional + +from omegaconf import MISSING + +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import ( + optimizer_class, + reward_network, + rl_algorithm, + trajectories, +) + + +@dataclasses.dataclass +class Config: + """Config for AIRL.""" + + _target_: str = "imitation.algorithms.adversarial.airl.AIRL" + venv: environment_cfg.Config = MISSING + demonstrations: trajectories.Config = MISSING + gen_algo: rl_algorithm.Config = MISSING + reward_net: reward_network.Config = MISSING + demo_batch_size: int = 64 + n_disc_updates_per_round: int = 2 + disc_opt_cls: optimizer_class.Config = optimizer_class.Adam() + gen_train_timesteps: Optional[int] = None + gen_replay_buffer_capacity: Optional[int] = None + init_tensorboard: bool = False + init_tensorboard_graph: bool = False + debug_use_ground_truth: bool = False 
+ allow_variable_horizon: bool = True # TODO: true just for debugging From 2906d55932cff0a08a405b48c1af8961ceaf78ad Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sat, 22 Apr 2023 07:18:32 +0200 Subject: [PATCH 09/25] Add checkpoint saving. --- src/imitation_cli/airl.py | 31 +++++++++++++++----- src/imitation_cli/utils/policy_evaluation.py | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index d6f43fee8..74407596d 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -1,13 +1,17 @@ """Config and run configuration for AIRL.""" import dataclasses import logging +import pathlib from typing import Any, Dict, Sequence, cast import hydra +import torch as th from hydra.core.config_store import ConfigStore +from hydra.core.hydra_config import HydraConfig from hydra.utils import call from omegaconf import MISSING +from imitation.policies import serialize from imitation_cli.algorithm_configurations import airl as airl_cfg from imitation_cli.utils import environment as environment_cfg from imitation_cli.utils import policy @@ -67,6 +71,8 @@ class RunConfig: dict(environment="${venv}"), ) # The reward network should be tailored to the default environment by default +policy_evaluation.register_configs("evaluation", dict(environment="${venv}")) + cs.store( name="airl_run", node=RunConfig( @@ -77,8 +83,6 @@ class RunConfig: ), ) -policy_evaluation.register_configs("evaluation", dict(environment="${venv}")) - @hydra.main( version_base=None, @@ -92,20 +96,31 @@ def run_airl(cfg: RunConfig) -> Dict[str, Any]: trainer: airl.AIRL = call(cfg.airl) + checkpoints_path = HydraConfig.get().run.dir / pathlib.Path("checkpoints") + + def save(path: str): + """Save discriminator and generator.""" + # We implement this here and not in Trainer since we do not want to actually + # serialize the whole Trainer (including e.g. expert demonstrations). + save_path = checkpoints_path / path + save_path.mkdir(parents=True, exist_ok=True) + + th.save(trainer.reward_train, save_path / "reward_train.pt") + th.save(trainer.reward_test, save_path / "reward_test.pt") + serialize.save_stable_model(save_path / "gen_policy", trainer.gen_algo) + def callback(round_num: int, /) -> None: if cfg.checkpoint_interval > 0 and round_num % cfg.checkpoint_interval == 0: - logging.log( - logging.INFO, - f"Saving checkpoint at round {round_num}. TODO implement this", - ) + logging.log(logging.INFO, f"Saving checkpoint at round {round_num}") + save(f"{round_num:05d}") trainer.train(cfg.total_timesteps, callback) - # TODO: implement evaluation imit_stats = policy_evaluation.eval_policy(trainer.policy, cfg.evaluation) # Save final artifacts. if cfg.checkpoint_interval >= 0: - logging.log(logging.INFO, "Saving final checkpoint. TODO implement this") + logging.log(logging.INFO, "Saving final checkpoint.") + save("final") return { "imit_stats": imit_stats, diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py index 1412bed2f..cd6ac453f 100644 --- a/src/imitation_cli/utils/policy_evaluation.py +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -60,7 +60,7 @@ def eval_policy( "monitor_return" key). "expert_stats" gives the return value of `rollout_stats()` on the expert demonstrations loaded from `rollout_path`. 
""" - from stable_baselines3.common import base_class + from stable_baselines3.common import base_class, vec_env from imitation.data import rollout From f1366e6a36729d276922953f5302a5cecc6b13fc Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sat, 22 Apr 2023 13:01:50 +0200 Subject: [PATCH 10/25] Use the hydra chdir feature to store logs in the output directory. --- src/imitation_cli/airl.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 74407596d..efb6b8437 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -7,7 +7,6 @@ import hydra import torch as th from hydra.core.config_store import ConfigStore -from hydra.core.hydra_config import HydraConfig from hydra.utils import call from omegaconf import MISSING @@ -46,6 +45,10 @@ class RunConfig: demonstrations: trajectories.Config = MISSING airl: airl_cfg.Config = MISSING evaluation: policy_evaluation.Config = MISSING + # This ensures that the working directory is changed + # to the hydra output dir + hydra: Any = dataclasses.field( + default_factory=lambda: dict(job=dict(chdir=True))) cs = ConfigStore.instance() @@ -96,7 +99,7 @@ def run_airl(cfg: RunConfig) -> Dict[str, Any]: trainer: airl.AIRL = call(cfg.airl) - checkpoints_path = HydraConfig.get().run.dir / pathlib.Path("checkpoints") + checkpoints_path = pathlib.Path("checkpoints") def save(path: str): """Save discriminator and generator.""" From cb92c3322d3c31bf89a0c4d4a09e409db50da36b Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 08:59:39 +0200 Subject: [PATCH 11/25] Remove defaults parameter from `register_configs`, introduce a air_run.yaml and assemble the airl run config when registering in the config store. 
--- src/imitation_cli/airl.py | 56 +++++++------------- src/imitation_cli/config/airl_run.yaml | 14 +++++ src/imitation_cli/utils/policy.py | 12 ++--- src/imitation_cli/utils/policy_evaluation.py | 14 ++--- src/imitation_cli/utils/rl_algorithm.py | 8 +-- 5 files changed, 46 insertions(+), 58 deletions(-) create mode 100644 src/imitation_cli/config/airl_run.yaml diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index efb6b8437..ec0f45148 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -13,9 +13,8 @@ from imitation.policies import serialize from imitation_cli.algorithm_configurations import airl as airl_cfg from imitation_cli.utils import environment as environment_cfg -from imitation_cli.utils import policy -from imitation_cli.utils import policy as policy_conf from imitation_cli.utils import ( + policy, policy_evaluation, reward_network, rl_algorithm, @@ -27,17 +26,7 @@ class RunConfig: """Config for running AIRL.""" - defaults: list = dataclasses.field( - default_factory=lambda: [ - {"venv": "gym_env"}, - {"airl/reward_net": "shaped"}, - {"airl/gen_algo": "ppo"}, - {"evaluation": "default_evaluation"}, - "_self_", - ], - ) seed: int = 0 - total_timesteps: int = int(1e6) checkpoint_interval: int = 0 @@ -47,41 +36,34 @@ class RunConfig: evaluation: policy_evaluation.Config = MISSING # This ensures that the working directory is changed # to the hydra output dir - hydra: Any = dataclasses.field( - default_factory=lambda: dict(job=dict(chdir=True))) + hydra: Any = dataclasses.field(default_factory=lambda: dict(job=dict(chdir=True))) cs = ConfigStore.instance() - -environment_cfg.register_configs("venv") - trajectories.register_configs("demonstrations") -# Make sure the expert generating the demonstrations uses the same env as the main env -policy.register_configs( - "demonstrations/expert_policy", - dict(environment="${venv}"), -) - -rl_algorithm.register_configs( - "airl/gen_algo", - dict( - environment="${venv}", - policy=policy_conf.ActorCriticPolicy(environment="${venv}"), # type: ignore - ), -) # The generation algo and its policy should use the main env by default -reward_network.register_configs( - "airl/reward_net", - dict(environment="${venv}"), -) # The reward network should be tailored to the default environment by default - -policy_evaluation.register_configs("evaluation", dict(environment="${venv}")) +policy.register_configs("demonstrations/expert_policy") +environment_cfg.register_configs("venv") +rl_algorithm.register_configs("airl/gen_algo") +reward_network.register_configs("airl/reward_net") +policy_evaluation.register_configs("evaluation") cs.store( - name="airl_run", + name="airl_run_base", node=RunConfig( + demonstrations=trajectories.Generated( + expert_policy=policy.Random(environment="${venv}"), # type: ignore + ), airl=airl_cfg.Config( venv="${venv}", # type: ignore demonstrations="${demonstrations}", # type: ignore + reward_net=reward_network.Config(environment="${venv}"), # type: ignore + gen_algo=rl_algorithm.PPO( + environment="${venv}", # type: ignore + policy=policy.ActorCriticPolicy(environment="${venv}"), # type: ignore + ), + ), + evaluation=policy_evaluation.Config( + environment="${venv}", # type: ignore ), ), ) diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml new file mode 100644 index 000000000..3dc15897b --- /dev/null +++ b/src/imitation_cli/config/airl_run.yaml @@ -0,0 +1,14 @@ +defaults: + - airl_run_base + - venv: cartpole + - airl/reward_net: shaped + - 
airl/gen_algo: ppo + - evaluation: default_evaluation +# - venv@airl.reward_net.environment: pendulum # This is how we inject a different environment + - _self_ + +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 \ No newline at end of file diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 05b1685a3..2d431059e 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -160,10 +160,10 @@ def make( return model.policy -def register_configs(group: str, defaults: Mapping[str, Any] = {}): +def register_configs(group: str): cs = ConfigStore.instance() - cs.store(group=group, name="random", node=Random(**defaults)) - cs.store(group=group, name="zero", node=ZeroPolicy(**defaults)) - cs.store(group=group, name="on_disk", node=PolicyOnDisk(**defaults)) - cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface(**defaults)) - cs.store(group=group, name="actor_critic", node=ActorCriticPolicy(**defaults)) + cs.store(group=group, name="random", node=Random) + cs.store(group=group, name="zero", node=ZeroPolicy) + cs.store(group=group, name="on_disk", node=PolicyOnDisk) + cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface) + cs.store(group=group, name="actor_critic", node=ActorCriticPolicy) diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py index cd6ac453f..eb36c392b 100644 --- a/src/imitation_cli/utils/policy_evaluation.py +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -24,20 +24,12 @@ class Config: rng: randomness.Config = randomness.Config() -def register_configs(group: str, defaults: Mapping[str, Any] = {}) -> None: +def register_configs(group: str) -> None: from hydra.core.config_store import ConfigStore cs = ConfigStore.instance() - cs.store( - name="default_evaluation", - group=group, - node=Config(**defaults), - ) - cs.store( - name="fast_evaluation", - group=group, - node=Config(n_episodes_eval=2, **defaults), - ) + cs.store(name="default_evaluation", group=group, node=Config) + cs.store(name="fast_evaluation", group=group, node=Config(n_episodes_eval=2)) def eval_policy( diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index baa3eb552..910ee859d 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -28,7 +28,7 @@ class PPO(Config): _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" # We disable recursive instantiation, so we can just make the arguments of the policy but not the policy itself _recursive_: bool = False - policy: policy_cfg.ActorCriticPolicy = MISSING + policy: policy_cfg.ActorCriticPolicy = policy_cfg.ActorCriticPolicy() learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) n_steps: int = 2048 batch_size: int = 64 @@ -85,9 +85,9 @@ def make(environment: VecEnv, path: pathlib.Path) -> PPO: return serialize.load_stable_baselines_model(sb3.PPO, str(path), environment) -def register_configs(group: str = "rl_algorithm", defaults: Mapping[str, Any] = {}): +def register_configs(group: str = "rl_algorithm"): from hydra.core.config_store import ConfigStore cs = ConfigStore.instance() - cs.store(name="ppo", group=group, node=PPO(**defaults)) - cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk(**defaults)) + cs.store(name="ppo", group=group, node=PPO) + cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk) From b8221ab7314fd5ca42c35530bd93dd67640908e0 Mon Sep 17 00:00:00 
2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:12:07 +0200 Subject: [PATCH 12/25] Define cartpole and pendulum envs as structured configs. --- src/imitation_cli/utils/environment.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py index 6ed0c94c1..1d035b3d1 100644 --- a/src/imitation_cli/utils/environment.py +++ b/src/imitation_cli/utils/environment.py @@ -45,3 +45,6 @@ def make_rollout_venv(environment_config: Config) -> VecEnv: def register_configs(group: str): cs = ConfigStore.instance() cs.store(group=group, name="gym_env", node=Config) + cs.store(group=group, name="cartpole", node=Config(env_name="CartPole-v0", max_episode_steps=500)) + cs.store(group=group, name="pendulum", node=Config(env_name="Pendulum-v1", max_episode_steps=500)) + From 09121385630b126efbf5b4cf1153940b9abb59c9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:12:58 +0200 Subject: [PATCH 13/25] Ensure PPO on disk does not inherit from PPO and loads from an absolute path. --- src/imitation_cli/utils/rl_algorithm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index 910ee859d..9c7d153d1 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -9,7 +9,7 @@ from stable_baselines3.common.policies import BasePolicy from stable_baselines3 import PPO -from hydra.utils import call +from hydra.utils import call, to_absolute_path from omegaconf import MISSING from imitation_cli.utils import environment as environment_cfg @@ -73,7 +73,7 @@ def make( @dataclasses.dataclass -class PPOOnDisk(PPO): +class PPOOnDisk(Config): _target_: str = "imitation_cli.utils.rl_algorithm.PPOOnDisk.make" path: pathlib.Path = MISSING @@ -82,7 +82,7 @@ def make(environment: VecEnv, path: pathlib.Path) -> PPO: from imitation.policies import serialize import stable_baselines3 as sb3 - return serialize.load_stable_baselines_model(sb3.PPO, str(path), environment) + return serialize.load_stable_baselines_model(sb3.PPO, str(to_absolute_path(path)), environment) def register_configs(group: str = "rl_algorithm"): From a0c0d98ed9a9af71a2c37f9a8a25cfc0859d44e5 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:14:46 +0200 Subject: [PATCH 14/25] Remove low default number of steps for generated trajectories and move that to the config file. 
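With the default removed, Generated.total_timesteps is MISSING, so Hydra raises a missing-value error unless the run config or a command-line override supplies a number. A minimal sketch of both options (the value 100000 is purely illustrative; demonstrations is still a top-level group at this point in the series):

    # airl_run.yaml
    demonstrations:
      total_timesteps: 100000

    # or, equivalently, as a Hydra command-line override:
    #   demonstrations.total_timesteps=100000
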
--- src/imitation_cli/config/airl_run.yaml | 3 ++- src/imitation_cli/utils/trajectories.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml index 3dc15897b..1c83e3f58 100644 --- a/src/imitation_cli/config/airl_run.yaml +++ b/src/imitation_cli/config/airl_run.yaml @@ -11,4 +11,5 @@ total_timesteps: 40000 checkpoint_interval: 1 airl: - demo_batch_size: 128 \ No newline at end of file + demo_batch_size: 128 demonstrations: + total_timesteps: 10 \ No newline at end of file diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index 7f9e7babe..c942a48be 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -37,7 +37,7 @@ def make(path: pathlib.Path) -> Sequence[Trajectory]: class Generated(Config): _target_: str = "imitation_cli.utils.trajectories.Generated.make" _recursive_: bool = False # We disable the recursive flag, so we can extract the environment from the expert policy - total_timesteps: int = int(10) # TODO: this is low for debugging + total_timesteps: int = MISSING expert_policy: policy.Config = MISSING rng: randomness.Config = randomness.Config() From 1ff9aa1c3b30baeabb80ffd5dd4aa69c5026a82e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:17:30 +0200 Subject: [PATCH 15/25] Introduce default_environment to the register_configs functions, remove the global demonstrations field and rename the global venv field to environment. --- src/imitation_cli/airl.py | 29 ++++++-------------- src/imitation_cli/config/airl_run.yaml | 9 ++++-- src/imitation_cli/utils/policy.py | 15 +++++----- src/imitation_cli/utils/policy_evaluation.py | 8 +++--- src/imitation_cli/utils/reward_network.py | 11 ++++---- src/imitation_cli/utils/rl_algorithm.py | 9 +++--- src/imitation_cli/utils/trajectories.py | 8 +++--- 7 files changed, 41 insertions(+), 48 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index ec0f45148..438d9038f 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -30,8 +30,7 @@ class RunConfig: total_timesteps: int = int(1e6) checkpoint_interval: int = 0 - venv: environment_cfg.Config = MISSING - demonstrations: trajectories.Config = MISSING + environment: environment_cfg.Config = MISSING airl: airl_cfg.Config = MISSING evaluation: policy_evaluation.Config = MISSING # This ensures that the working directory is changed @@ -40,30 +39,18 @@ class RunConfig: cs = ConfigStore.instance() -trajectories.register_configs("demonstrations") -policy.register_configs("demonstrations/expert_policy") -environment_cfg.register_configs("venv") -rl_algorithm.register_configs("airl/gen_algo") -reward_network.register_configs("airl/reward_net") -policy_evaluation.register_configs("evaluation") +environment_cfg.register_configs("environment") +trajectories.register_configs("airl/demonstrations", "${environment}") +policy.register_configs("airl/demonstrations/expert_policy", "${environment}") +rl_algorithm.register_configs("airl/gen_algo", "${environment}") +reward_network.register_configs("airl/reward_net", "${environment}") +policy_evaluation.register_configs("evaluation", "${environment}") cs.store( name="airl_run_base", node=RunConfig( - demonstrations=trajectories.Generated( - expert_policy=policy.Random(environment="${venv}"), # type: ignore - ), airl=airl_cfg.Config( - venv="${venv}", # type: ignore - demonstrations="${demonstrations}", # type: 
ignore - reward_net=reward_network.Config(environment="${venv}"), # type: ignore - gen_algo=rl_algorithm.PPO( - environment="${venv}", # type: ignore - policy=policy.ActorCriticPolicy(environment="${venv}"), # type: ignore - ), - ), - evaluation=policy_evaluation.Config( - environment="${venv}", # type: ignore + venv="${environment}", # type: ignore ), ), ) diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml index 1c83e3f58..d1bdcd0fd 100644 --- a/src/imitation_cli/config/airl_run.yaml +++ b/src/imitation_cli/config/airl_run.yaml @@ -1,15 +1,18 @@ defaults: - airl_run_base - - venv: cartpole + - environment: cartpole - airl/reward_net: shaped - airl/gen_algo: ppo - evaluation: default_evaluation -# - venv@airl.reward_net.environment: pendulum # This is how we inject a different environment + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random +# - environment@airl.reward_net.environment: pendulum # This is how we inject a different environment - _self_ total_timesteps: 40000 checkpoint_interval: 1 airl: - demo_batch_size: 128 demonstrations: + demo_batch_size: 128 + demonstrations: total_timesteps: 10 \ No newline at end of file diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 2d431059e..c7ec558c7 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -1,8 +1,9 @@ from __future__ import annotations + import dataclasses import pathlib import typing -from typing import Any, Dict, List, Optional, Mapping +from typing import Any, Dict, List, Optional, Union if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -160,10 +161,10 @@ def make( return model.policy -def register_configs(group: str): +def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): cs = ConfigStore.instance() - cs.store(group=group, name="random", node=Random) - cs.store(group=group, name="zero", node=ZeroPolicy) - cs.store(group=group, name="on_disk", node=PolicyOnDisk) - cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface) - cs.store(group=group, name="actor_critic", node=ActorCriticPolicy) + cs.store(group=group, name="random", node=Random(environment=default_environment)) + cs.store(group=group, name="zero", node=ZeroPolicy(environment=default_environment)) + cs.store(group=group, name="on_disk", node=PolicyOnDisk(environment=default_environment)) + cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface(environment=default_environment)) + cs.store(group=group, name="actor_critic", node=ActorCriticPolicy(environment=default_environment)) diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py index eb36c392b..8678f27e8 100644 --- a/src/imitation_cli/utils/policy_evaluation.py +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -3,7 +3,7 @@ import dataclasses import typing -from typing import Any, Mapping, Union +from typing import Optional, Union from hydra.utils import call from omegaconf import MISSING @@ -24,12 +24,12 @@ class Config: rng: randomness.Config = randomness.Config() -def register_configs(group: str) -> None: +def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING) -> None: from hydra.core.config_store import ConfigStore cs = ConfigStore.instance() - cs.store(name="default_evaluation", group=group, node=Config) - cs.store(name="fast_evaluation", 
group=group, node=Config(n_episodes_eval=2)) + cs.store(name="default_evaluation", group=group, node=Config(environment=default_environment)) + cs.store(name="fast_evaluation", group=group, node=Config(environment=default_environment, n_episodes_eval=2)) def eval_policy( diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index f3d00f749..ede4d6b9a 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -1,7 +1,8 @@ from __future__ import annotations + import dataclasses import typing -from typing import Optional, Any, Mapping +from typing import Optional, Union if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -105,8 +106,8 @@ def make( return reward_net -def register_configs(group: str, defaults: Mapping[str, Any] = {}): +def register_configs(group: str, default_environment: Optional[Union[environment_cg.Config, str]] = MISSING): cs = ConfigStore.instance() - cs.store(group=group, name="basic", node=BasicRewardNet(**defaults)) - cs.store(group=group, name="shaped", node=BasicShapedRewardNet(**defaults)) - cs.store(group=group, name="ensemble", node=RewardEnsemble(**defaults)) + cs.store(group=group, name="basic", node=BasicRewardNet(environment=default_environment)) + cs.store(group=group, name="shaped", node=BasicShapedRewardNet(environment=default_environment)) + cs.store(group=group, name="ensemble", node=RewardEnsemble(environment=default_environment)) diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index 9c7d153d1..0dee4e83a 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -1,8 +1,9 @@ from __future__ import annotations + import dataclasses import pathlib import typing -from typing import Optional, Mapping, Any +from typing import Optional, Union if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -85,9 +86,9 @@ def make(environment: VecEnv, path: pathlib.Path) -> PPO: return serialize.load_stable_baselines_model(sb3.PPO, str(to_absolute_path(path)), environment) -def register_configs(group: str = "rl_algorithm"): +def register_configs(group: str = "rl_algorithm", default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): from hydra.core.config_store import ConfigStore cs = ConfigStore.instance() - cs.store(name="ppo", group=group, node=PPO) - cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk) + cs.store(name="ppo", group=group, node=PPO(environment=default_environment, policy=policy_cfg.ActorCriticPolicy(environment=default_environment))) + cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk(environment=default_environment)) diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index c942a48be..68aa89489 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -2,18 +2,18 @@ import dataclasses import pathlib import typing +from typing import Sequence, Optional, Union if typing.TYPE_CHECKING: from stable_baselines3.common.policies import BasePolicy from imitation.data.types import Trajectory - from typing import Sequence import numpy as np from hydra.core.config_store import ConfigStore from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import policy, randomness +from imitation_cli.utils import policy, randomness, environment as environment_cfg @dataclasses.dataclass @@ -61,7 +61,7 @@ 
def make( ) -def register_configs(group: str): +def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): cs = ConfigStore.instance() cs.store(group=group, name="on_disk", node=OnDisk) - cs.store(group=group, name="generated", node=Generated) + cs.store(group=group, name="generated", node=Generated(expert_policy=policy.Config(environment=default_environment))) From e37fd4ca7087924c05db6703098f54921d2807a5 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:23:18 +0200 Subject: [PATCH 16/25] Move registering the expert policy as a sub-call to registering the trajectories. --- src/imitation_cli/airl.py | 1 - src/imitation_cli/utils/trajectories.py | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 438d9038f..3ac2b4235 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -41,7 +41,6 @@ class RunConfig: cs = ConfigStore.instance() environment_cfg.register_configs("environment") trajectories.register_configs("airl/demonstrations", "${environment}") -policy.register_configs("airl/demonstrations/expert_policy", "${environment}") rl_algorithm.register_configs("airl/gen_algo", "${environment}") reward_network.register_configs("airl/reward_net", "${environment}") policy_evaluation.register_configs("evaluation", "${environment}") diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index 68aa89489..f4f0b1cf8 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -64,4 +64,5 @@ def make( def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): cs = ConfigStore.instance() cs.store(group=group, name="on_disk", node=OnDisk) - cs.store(group=group, name="generated", node=Generated(expert_policy=policy.Config(environment=default_environment))) + cs.store(group=group, name="generated", node=Generated) + policy.register_configs(group=group + "/expert_policy", default_environment=default_environment) From 59c2b08d715014b0c00bff876b1a3b8e37a3b6ce Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:32:57 +0200 Subject: [PATCH 17/25] Update the airl_sweep.yaml --- src/imitation_cli/config/airl_sweep.yaml | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/imitation_cli/config/airl_sweep.yaml b/src/imitation_cli/config/airl_sweep.yaml index 38c069a0e..2c345316a 100644 --- a/src/imitation_cli/config/airl_sweep.yaml +++ b/src/imitation_cli/config/airl_sweep.yaml @@ -1,15 +1,24 @@ defaults: - - airl - - expert_trajs: generated - - gen_algo: on_disk + - airl_run_base + - environment: gym_env + - airl/reward_net: shaped + - airl/gen_algo: ppo + - evaluation: default_evaluation + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random - _self_ -gen_algo: - environment: - max_episode_steps: 32 +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 + demonstrations: + total_timesteps: 10 hydra: mode: MULTIRUN sweeper: params: environment: glob(*,exclude=gym_env) + airl/reward_net: glob(*) From 244eed06e2ea6e9727a009b1838c135e6bf22215 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:43:41 +0200 Subject: [PATCH 18/25] Update the airl_sweep.yaml --- .../{airl_sweep.yaml => airl_sweep_env_and_rewardnet.yaml} | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) rename 
src/imitation_cli/config/{airl_sweep.yaml => airl_sweep_env_and_rewardnet.yaml} (82%) diff --git a/src/imitation_cli/config/airl_sweep.yaml b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml similarity index 82% rename from src/imitation_cli/config/airl_sweep.yaml rename to src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml index 2c345316a..fb947ee61 100644 --- a/src/imitation_cli/config/airl_sweep.yaml +++ b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml @@ -20,5 +20,5 @@ hydra: mode: MULTIRUN sweeper: params: - environment: glob(*,exclude=gym_env) - airl/reward_net: glob(*) + environment: cartpole,pendulum + airl/reward_net: basic,shaped,ensemble From ff0028fa72305faa7899d349518be9d377b48bc3 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 24 Apr 2023 16:45:53 +0200 Subject: [PATCH 19/25] Add airl_optuna.yaml --- src/imitation_cli/config/airl_optuna.yaml | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 src/imitation_cli/config/airl_optuna.yaml diff --git a/src/imitation_cli/config/airl_optuna.yaml b/src/imitation_cli/config/airl_optuna.yaml new file mode 100644 index 000000000..cf4aceb97 --- /dev/null +++ b/src/imitation_cli/config/airl_optuna.yaml @@ -0,0 +1,25 @@ +defaults: + - airl_run_base + - environment: gym_env + - airl/reward_net: shaped + - airl/gen_algo: ppo + - evaluation: default_evaluation + - airl/demonstrations: generated + - airl/demonstrations/expert_policy: random + - override hydra/sweeper: optuna + - _self_ + +total_timesteps: 40000 +checkpoint_interval: 1 + +airl: + demo_batch_size: 128 + demonstrations: + total_timesteps: 10 + +hydra: + mode: MULTIRUN + sweeper: + params: + environment: cartpole,pendulum + airl/reward_net: basic,shaped,ensemble From ca632a8ef1b053fd70442b1dbf5bb224651882fb Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 25 Apr 2023 15:54:12 +0200 Subject: [PATCH 20/25] Formatting, typing and documentation fixes. 
Also the implicit seed dependency was pulled out of the utils and made explicit in airl.py --- src/imitation_cli/airl.py | 12 +-- src/imitation_cli/config/airl_optuna.yaml | 2 +- src/imitation_cli/config/airl_run.yaml | 2 +- .../config/airl_sweep_env_and_rewardnet.yaml | 2 +- src/imitation_cli/utils/__init__.py | 1 + .../utils/activation_function_class.py | 11 ++- src/imitation_cli/utils/environment.py | 38 +++++--- .../utils/feature_extractor_class.py | 11 ++- src/imitation_cli/utils/optimizer_class.py | 7 ++ src/imitation_cli/utils/policy.py | 97 +++++++++++++------ src/imitation_cli/utils/policy_evaluation.py | 31 ++++-- src/imitation_cli/utils/randomness.py | 10 +- src/imitation_cli/utils/reward_network.py | 80 ++++++++++----- src/imitation_cli/utils/rl_algorithm.py | 62 +++++++++--- src/imitation_cli/utils/schedule.py | 7 ++ src/imitation_cli/utils/trajectories.py | 35 +++++-- 16 files changed, 305 insertions(+), 103 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 3ac2b4235..103065e49 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -14,8 +14,8 @@ from imitation_cli.algorithm_configurations import airl as airl_cfg from imitation_cli.utils import environment as environment_cfg from imitation_cli.utils import ( - policy, policy_evaluation, + randomness, reward_network, rl_algorithm, trajectories, @@ -26,7 +26,7 @@ class RunConfig: """Config for running AIRL.""" - seed: int = 0 + rng: randomness.Config = randomness.Config(seed=0) total_timesteps: int = int(1e6) checkpoint_interval: int = 0 @@ -39,11 +39,11 @@ class RunConfig: cs = ConfigStore.instance() -environment_cfg.register_configs("environment") -trajectories.register_configs("airl/demonstrations", "${environment}") -rl_algorithm.register_configs("airl/gen_algo", "${environment}") +environment_cfg.register_configs("environment", "${rng}") +trajectories.register_configs("airl/demonstrations", "${environment}", "${rng}") +rl_algorithm.register_configs("airl/gen_algo", "${environment}", "${rng.seed}") reward_network.register_configs("airl/reward_net", "${environment}") -policy_evaluation.register_configs("evaluation", "${environment}") +policy_evaluation.register_configs("evaluation", "${environment}", "${rng}") cs.store( name="airl_run_base", diff --git a/src/imitation_cli/config/airl_optuna.yaml b/src/imitation_cli/config/airl_optuna.yaml index cf4aceb97..9eadb47e9 100644 --- a/src/imitation_cli/config/airl_optuna.yaml +++ b/src/imitation_cli/config/airl_optuna.yaml @@ -22,4 +22,4 @@ hydra: sweeper: params: environment: cartpole,pendulum - airl/reward_net: basic,shaped,ensemble + airl/reward_net: basic,shaped,small_ensemble diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml index d1bdcd0fd..7cbece85e 100644 --- a/src/imitation_cli/config/airl_run.yaml +++ b/src/imitation_cli/config/airl_run.yaml @@ -15,4 +15,4 @@ checkpoint_interval: 1 airl: demo_batch_size: 128 demonstrations: - total_timesteps: 10 \ No newline at end of file + total_timesteps: 10 diff --git a/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml index fb947ee61..141b03c59 100644 --- a/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml +++ b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml @@ -21,4 +21,4 @@ hydra: sweeper: params: environment: cartpole,pendulum - airl/reward_net: basic,shaped,ensemble + airl/reward_net: basic,shaped,small_ensemble diff --git 
a/src/imitation_cli/utils/__init__.py b/src/imitation_cli/utils/__init__.py index e69de29bb..f3dc34d7c 100644 --- a/src/imitation_cli/utils/__init__.py +++ b/src/imitation_cli/utils/__init__.py @@ -0,0 +1 @@ +"""Configurations to be used as ingredient to algorithm configurations.""" diff --git a/src/imitation_cli/utils/activation_function_class.py b/src/imitation_cli/utils/activation_function_class.py index 1ba225061..2b5852b63 100644 --- a/src/imitation_cli/utils/activation_function_class.py +++ b/src/imitation_cli/utils/activation_function_class.py @@ -1,3 +1,4 @@ +"""Classes for configuring activation functions.""" import dataclasses from hydra.core.config_store import ConfigStore @@ -5,6 +6,8 @@ @dataclasses.dataclass class Config: + """Base class for activation function configs.""" + # Note: we don't define _target_ here so in the subclasses it can be defined last. # This is the same pattern we use as in schedule.py. pass @@ -12,6 +15,8 @@ class Config: @dataclasses.dataclass class TanH(Config): + """Config for TanH activation function.""" + _target_: str = "imitation_cli.utils.activation_function_class.TanH.make" @staticmethod @@ -23,6 +28,8 @@ def make() -> type: @dataclasses.dataclass class ReLU(Config): + """Config for ReLU activation function.""" + _target_: str = "imitation_cli.utils.activation_function_class.ReLU.make" @staticmethod @@ -34,10 +41,12 @@ def make() -> type: @dataclasses.dataclass class LeakyReLU(Config): + """Config for LeakyReLU activation function.""" + _target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make" @staticmethod - def make() -> type: + def make() -> type: import torch return torch.nn.LeakyReLU diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py index 1d035b3d1..733fdb0d6 100644 --- a/src/imitation_cli/utils/environment.py +++ b/src/imitation_cli/utils/environment.py @@ -1,7 +1,9 @@ +"""Configuration for Gym environments.""" from __future__ import annotations + import dataclasses import typing -from typing import Optional +from typing import Optional, Union, cast if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -15,18 +17,21 @@ @dataclasses.dataclass class Config: + """Configuration for Gym environments.""" + _target_: str = "imitation_cli.utils.environment.Config.make" env_name: str = MISSING # The environment to train on n_envs: int = 8 # number of environments in VecEnv - parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv TODO: when setting this to true this is really slow for some reason + # TODO: when setting this to true this is really slow for some reason + parallel: bool = False # Use SubprocVecEnv rather than DummyVecEnv max_episode_steps: int = MISSING # Set to positive int to limit episode horizons env_make_kwargs: dict = dataclasses.field( - default_factory=dict + default_factory=dict, ) # The kwargs passed to `spec.make`. 
- rng: randomness.Config = randomness.Config() + rng: randomness.Config = MISSING @staticmethod - def make(log_dir: Optional[str]=None, **kwargs) -> VecEnv: + def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv: from imitation.util import util return util.make_vec_env(log_dir=log_dir, **kwargs) @@ -38,13 +43,24 @@ def make_rollout_venv(environment_config: Config) -> VecEnv: return call( environment_config, log_dir=None, - post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)] + post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], ) -def register_configs(group: str): +def register_configs( + group: str, + default_rng: Union[randomness.Config, str] = MISSING, +): + default_rng = cast(randomness.Config, default_rng) cs = ConfigStore.instance() - cs.store(group=group, name="gym_env", node=Config) - cs.store(group=group, name="cartpole", node=Config(env_name="CartPole-v0", max_episode_steps=500)) - cs.store(group=group, name="pendulum", node=Config(env_name="Pendulum-v1", max_episode_steps=500)) - + cs.store(group=group, name="gym_env", node=Config(rng=default_rng)) + cs.store( + group=group, + name="cartpole", + node=Config(env_name="CartPole-v0", max_episode_steps=500, rng=default_rng), + ) + cs.store( + group=group, + name="pendulum", + node=Config(env_name="Pendulum-v1", max_episode_steps=500, rng=default_rng), + ) diff --git a/src/imitation_cli/utils/feature_extractor_class.py b/src/imitation_cli/utils/feature_extractor_class.py index 33a5348d2..27b33cbb5 100644 --- a/src/imitation_cli/utils/feature_extractor_class.py +++ b/src/imitation_cli/utils/feature_extractor_class.py @@ -1,3 +1,4 @@ +"""Register Hydra configs for stable_baselines3 feature extractors.""" import dataclasses from hydra.core.config_store import ConfigStore @@ -6,12 +7,18 @@ @dataclasses.dataclass class Config: + """Base config for stable_baselines3 feature extractors.""" + _target_: str = MISSING @dataclasses.dataclass class FlattenExtractorConfig(Config): - _target_: str = "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" + """Config for FlattenExtractor.""" + + _target_: str = ( + "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" + ) @staticmethod def make() -> type: @@ -22,6 +29,8 @@ def make() -> type: @dataclasses.dataclass class NatureCNNConfig(Config): + """Config for NatureCNN.""" + _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make" @staticmethod diff --git a/src/imitation_cli/utils/optimizer_class.py b/src/imitation_cli/utils/optimizer_class.py index 9e8ae5b9c..0fd25da95 100644 --- a/src/imitation_cli/utils/optimizer_class.py +++ b/src/imitation_cli/utils/optimizer_class.py @@ -1,3 +1,4 @@ +"""Register optimizer classes with Hydra.""" import dataclasses from hydra.core.config_store import ConfigStore @@ -6,11 +7,15 @@ @dataclasses.dataclass class Config: + """Base config for optimizer classes.""" + _target_: str = MISSING @dataclasses.dataclass class Adam(Config): + """Config for Adam optimizer class.""" + _target_: str = "imitation_cli.utils.optimizer_class.Adam.make" @staticmethod @@ -22,6 +27,8 @@ def make() -> type: @dataclasses.dataclass class SGD(Config): + """Config for SGD optimizer class.""" + _target_: str = "imitation_cli.utils.optimizer_class.SGD.make" @staticmethod diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index c7ec558c7..3cc460b54 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -1,9 +1,10 @@ +"""Configurable 
policies for SB3 Base Policies.""" "" from __future__ import annotations import dataclasses import pathlib import typing -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -13,32 +14,41 @@ from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import \ - activation_function_class as activation_function_class_cfg, \ - environment as environment_cfg,\ - feature_extractor_class as feature_extractor_class_cfg,\ - optimizer_class as optimizer_class_cfg, \ - schedule +from imitation_cli.utils import activation_function_class as act_fun_class_cfg +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import feature_extractor_class as feature_extractor_class_cfg +from imitation_cli.utils import optimizer_class as optimizer_class_cfg +from imitation_cli.utils import schedule @dataclasses.dataclass class Config: + """Base configuration for policies.""" + _target_: str = MISSING environment: environment_cfg.Config = MISSING @dataclasses.dataclass class Random(Config): + """Configuration for a random policy.""" + _target_: str = "imitation_cli.utils.policy.Random.make" @staticmethod def make(environment: VecEnv) -> BasePolicy: from imitation.policies import base - return base.RandomPolicy(environment.observation_space, environment.action_space) + + return base.RandomPolicy( + environment.observation_space, + environment.action_space, + ) @dataclasses.dataclass class ZeroPolicy(Config): + """Configuration for a zero policy.""" + _target_: str = "imitation_cli.utils.policy.ZeroPolicy.make" @staticmethod @@ -50,10 +60,12 @@ def make(environment: VecEnv) -> BasePolicy: @dataclasses.dataclass class ActorCriticPolicy(Config): + """Configuration for a stable-baselines3 ActorCriticPolicy.""" + _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" - lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) # TODO: make sure this is copied from the rl_algorithm instead + lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) net_arch: Optional[Dict[str, List[int]]] = None - activation_fn: activation_function_class_cfg.Config = activation_function_class_cfg.TanH() + activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH() ortho_init: bool = True use_sde: bool = False log_std_init: float = 0.0 @@ -71,7 +83,7 @@ class ActorCriticPolicy(Config): @staticmethod def make_args( - activation_fn: activation_function_class_cfg.Config, + activation_fn: act_fun_class_cfg.Config, features_extractor_class: feature_extractor_class_cfg.Config, optimizer_class: optimizer_class_cfg.Config, **kwargs, @@ -103,47 +115,58 @@ def make( @dataclasses.dataclass class Loaded(Config): - type: str = "PPO" # The SB3 policy class. Only SAC and PPO supported as of now + """Base configuration for a policy that is loaded from somewhere.""" + + policy_type: str = ( + "PPO" # The SB3 policy class. 
Only SAC and PPO supported as of now + ) @staticmethod - def type_to_class(type: str): + def type_to_class(policy_type: str): import stable_baselines3 as sb3 - type = type.lower() - if type == "ppo": + policy_type = policy_type.lower() + if policy_type == "ppo": return sb3.PPO - if type == "ppo": + if policy_type == "ppo": return sb3.SAC - raise ValueError(f"Unknown policy type {type}") + raise ValueError(f"Unknown policy type {policy_type}") @dataclasses.dataclass class PolicyOnDisk(Loaded): + """Configuration for a policy that is loaded from a path on disk.""" + _target_: str = "imitation_cli.utils.policy.PolicyOnDisk.make" path: pathlib.Path = MISSING @staticmethod def make( environment: VecEnv, - type: str, + policy_type: str, path: pathlib.Path, ) -> BasePolicy: from imitation.policies import serialize return serialize.load_stable_baselines_model( - Loaded.type_to_class(type), str(path), environment + Loaded.type_to_class(policy_type), + str(path), + environment, ).policy @dataclasses.dataclass class PolicyFromHuggingface(Loaded): + """Configuration for a policy that is loaded from a HuggingFace model.""" + _target_: str = "imitation_cli.utils.policy.PolicyFromHuggingface.make" + _recursive_: bool = False organization: str = "HumanCompatibleAI" @staticmethod def make( - environment: VecEnv, - type: str, + environment: environment_cfg.Config, + policy_type: str, organization: str, ) -> BasePolicy: import huggingface_sb3 as hfsb3 @@ -151,20 +174,40 @@ def make( from imitation.policies import serialize model_name = hfsb3.ModelName( - type.lower(), hfsb3.EnvironmentName(environment.gym_id) + policy_type.lower(), + hfsb3.EnvironmentName(environment.env_name), ) repo_id = hfsb3.ModelRepoId(organization, model_name) filename = hfsb3.load_from_hub(repo_id, model_name.filename) model = serialize.load_stable_baselines_model( - Loaded.type_to_class(type), filename, environment + Loaded.type_to_class(policy_type), + filename, + call(environment), ) return model.policy -def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) cs = ConfigStore.instance() cs.store(group=group, name="random", node=Random(environment=default_environment)) cs.store(group=group, name="zero", node=ZeroPolicy(environment=default_environment)) - cs.store(group=group, name="on_disk", node=PolicyOnDisk(environment=default_environment)) - cs.store(group=group, name="from_huggingface", node=PolicyFromHuggingface(environment=default_environment)) - cs.store(group=group, name="actor_critic", node=ActorCriticPolicy(environment=default_environment)) + cs.store( + group=group, + name="on_disk", + node=PolicyOnDisk(environment=default_environment), + ) + cs.store( + group=group, + name="from_huggingface", + node=PolicyFromHuggingface(environment=default_environment), + ) + cs.store( + group=group, + name="actor_critic", + node=ActorCriticPolicy(environment=default_environment), + ) + schedule.register_configs(group=group + "/lr_schedule") diff --git a/src/imitation_cli/utils/policy_evaluation.py b/src/imitation_cli/utils/policy_evaluation.py index 8678f27e8..192750049 100644 --- a/src/imitation_cli/utils/policy_evaluation.py +++ b/src/imitation_cli/utils/policy_evaluation.py @@ -3,7 +3,7 @@ import dataclasses import typing -from typing import Optional, Union +from typing import Optional, 
Union, cast from hydra.utils import call from omegaconf import MISSING @@ -12,7 +12,7 @@ from imitation_cli.utils import randomness if typing.TYPE_CHECKING: - from stable_baselines3.common import base_class, policies, vec_env + from stable_baselines3.common import base_class, policies @dataclasses.dataclass @@ -21,15 +21,34 @@ class Config: environment: environment_cfg.Config = MISSING n_episodes_eval: int = 50 - rng: randomness.Config = randomness.Config() + rng: randomness.Config = MISSING -def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING) -> None: +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_rng: Optional[Union[randomness.Config, str]] = MISSING, +) -> None: from hydra.core.config_store import ConfigStore + default_environment = cast(environment_cfg.Config, default_environment) + default_rng = cast(randomness.Config, default_rng) + cs = ConfigStore.instance() - cs.store(name="default_evaluation", group=group, node=Config(environment=default_environment)) - cs.store(name="fast_evaluation", group=group, node=Config(environment=default_environment, n_episodes_eval=2)) + cs.store( + name="default_evaluation", + group=group, + node=Config(environment=default_environment, rng=default_rng), + ) + cs.store( + name="fast_evaluation", + group=group, + node=Config( + environment=default_environment, + rng=default_rng, + n_episodes_eval=2, + ), + ) def eval_policy( diff --git a/src/imitation_cli/utils/randomness.py b/src/imitation_cli/utils/randomness.py index 1b83104a1..68c4feeb5 100644 --- a/src/imitation_cli/utils/randomness.py +++ b/src/imitation_cli/utils/randomness.py @@ -1,15 +1,21 @@ +"""Utilities for seeding random number generators.""" from __future__ import annotations + import dataclasses import typing +from omegaconf import MISSING + if typing.TYPE_CHECKING: import numpy as np @dataclasses.dataclass class Config: + """Configuration for seeding random number generators.""" + _target_: str = "imitation_cli.utils.randomness.Config.make" - seed: int = "${seed}" # type: ignore + seed: int = MISSING @staticmethod def make(seed: int) -> np.random.Generator: @@ -19,4 +25,4 @@ def make(seed: int) -> np.random.Generator: np.random.seed(seed) torch.manual_seed(seed) - return np.random.default_rng(seed) \ No newline at end of file + return np.random.default_rng(seed) diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index ede4d6b9a..a2f6ac25d 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -1,8 +1,9 @@ +"""Reward network configuration.""" from __future__ import annotations import dataclasses import typing -from typing import Optional, Union +from typing import Optional, Union, cast if typing.TYPE_CHECKING: from stable_baselines3.common.vec_env import VecEnv @@ -12,17 +13,21 @@ from hydra.utils import call from omegaconf import MISSING -import imitation_cli.utils.environment as environment_cg +import imitation_cli.utils.environment as environment_cfg @dataclasses.dataclass class Config: + """Base configuration for reward networks.""" + _target_: str = MISSING - environment: environment_cg.Config = MISSING + environment: environment_cfg.Config = MISSING @dataclasses.dataclass class BasicRewardNet(Config): + """Configuration for a basic reward network.""" + _target_: str = "imitation_cli.utils.reward_network.BasicRewardNet.make" use_state: bool = True 
use_action: bool = True @@ -31,11 +36,7 @@ class BasicRewardNet(Config): normalize_input_layer: bool = True @staticmethod - def make( - environment: VecEnv, - normalize_input_layer: bool, - **kwargs - ) -> RewardNet: + def make(environment: VecEnv, normalize_input_layer: bool, **kwargs) -> RewardNet: from imitation.rewards import reward_nets from imitation.util import networks @@ -45,24 +46,23 @@ def make( **kwargs, ) if normalize_input_layer: - reward_net = reward_nets.NormalizedRewardNet( + return reward_nets.NormalizedRewardNet( reward_net, networks.RunningNorm, ) - return reward_net + else: + return reward_net @dataclasses.dataclass class BasicShapedRewardNet(BasicRewardNet): + """Configuration for a basic shaped reward network.""" + _target_: str = "imitation_cli.utils.reward_network.BasicShapedRewardNet.make" discount_factor: float = 0.99 @staticmethod - def make( - environment: VecEnv, - normalize_input_layer: bool, - **kwargs - ) -> RewardNet: + def make(environment: VecEnv, normalize_input_layer: bool, **kwargs) -> RewardNet: from imitation.rewards import reward_nets from imitation.util import networks @@ -72,42 +72,70 @@ def make( **kwargs, ) if normalize_input_layer: - reward_net = reward_nets.NormalizedRewardNet( + return reward_nets.NormalizedRewardNet( reward_net, networks.RunningNorm, ) - return reward_net + else: + return reward_net @dataclasses.dataclass class RewardEnsemble(Config): + """Configuration for a reward ensemble.""" + _target_: str = "imitation_cli.utils.reward_network.RewardEnsemble.make" + _recursive_: bool = False ensemble_size: int = MISSING ensemble_member_config: BasicRewardNet = MISSING add_std_alpha: Optional[float] = None @staticmethod def make( - environment: VecEnv, + environment: environment_cfg.Config, ensemble_member_config: BasicRewardNet, add_std_alpha: Optional[float], + ensemble_size: int, ) -> RewardNet: from imitation.rewards import reward_nets - members = [call(ensemble_member_config)] + venv = call(environment) reward_net = reward_nets.RewardEnsemble( - environment.observation_space, environment.action_space, members + venv.observation_space, + venv.action_space, + [call(ensemble_member_config) for _ in range(ensemble_size)], ) if add_std_alpha is not None: - reward_net = reward_nets.AddSTDRewardWrapper( + return reward_nets.AddSTDRewardWrapper( reward_net, default_alpha=add_std_alpha, ) - return reward_net + else: + return reward_net -def register_configs(group: str, default_environment: Optional[Union[environment_cg.Config, str]] = MISSING): +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) cs = ConfigStore.instance() - cs.store(group=group, name="basic", node=BasicRewardNet(environment=default_environment)) - cs.store(group=group, name="shaped", node=BasicShapedRewardNet(environment=default_environment)) - cs.store(group=group, name="ensemble", node=RewardEnsemble(environment=default_environment)) + cs.store( + group=group, + name="basic", + node=BasicRewardNet(environment=default_environment), + ) + cs.store( + group=group, + name="shaped", + node=BasicShapedRewardNet(environment=default_environment), + ) + cs.store( + group=group, + name="small_ensemble", + node=RewardEnsemble( + environment=default_environment, + ensemble_size=5, + ensemble_member_config=BasicRewardNet(environment=default_environment), + ), + ) diff --git a/src/imitation_cli/utils/rl_algorithm.py 
b/src/imitation_cli/utils/rl_algorithm.py index 0dee4e83a..8186eb64c 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -1,14 +1,14 @@ +"""Configurable RL algorithms.""" from __future__ import annotations import dataclasses import pathlib import typing -from typing import Optional, Union +from typing import Optional, Union, cast if typing.TYPE_CHECKING: + import stable_baselines3 as sb3 from stable_baselines3.common.vec_env import VecEnv - from stable_baselines3.common.policies import BasePolicy - from stable_baselines3 import PPO from hydra.utils import call, to_absolute_path from omegaconf import MISSING @@ -20,14 +20,19 @@ @dataclasses.dataclass class Config: + """Base configuration for RL algorithms.""" + _target_: str = MISSING environment: environment_cfg.Config = MISSING @dataclasses.dataclass class PPO(Config): + """Configuration for a stable-baselines3 PPO algorithm.""" + _target_: str = "imitation_cli.utils.rl_algorithm.PPO.make" - # We disable recursive instantiation, so we can just make the arguments of the policy but not the policy itself + # We disable recursive instantiation, so we can just make the + # arguments of the policy but not the policy itself _recursive_: bool = False policy: policy_cfg.ActorCriticPolicy = policy_cfg.ActorCriticPolicy() learning_rate: schedule.Config = schedule.FixedSchedule(3e-4) @@ -47,7 +52,7 @@ class PPO(Config): target_kl: Optional[float] = None tensorboard_log: Optional[str] = None verbose: int = 0 - seed: int = "${seed}" # type: ignore + seed: int = MISSING device: str = "auto" @staticmethod @@ -57,10 +62,12 @@ def make( learning_rate: schedule.Config, clip_range: schedule.Config, **kwargs, - ) -> PPO: + ) -> sb3.PPO: import stable_baselines3 as sb3 - policy_kwargs = policy_cfg.ActorCriticPolicy.make_args(**typing.cast(dict, policy)) + policy_kwargs = policy_cfg.ActorCriticPolicy.make_args( + **typing.cast(dict, policy), + ) del policy_kwargs["use_sde"] del policy_kwargs["lr_schedule"] return sb3.PPO( @@ -75,20 +82,49 @@ def make( @dataclasses.dataclass class PPOOnDisk(Config): + """Configuration for a stable-baselines3 PPO algorithm loaded from disk.""" + _target_: str = "imitation_cli.utils.rl_algorithm.PPOOnDisk.make" path: pathlib.Path = MISSING @staticmethod - def make(environment: VecEnv, path: pathlib.Path) -> PPO: - from imitation.policies import serialize + def make(environment: VecEnv, path: pathlib.Path) -> sb3.PPO: import stable_baselines3 as sb3 - return serialize.load_stable_baselines_model(sb3.PPO, str(to_absolute_path(path)), environment) + from imitation.policies import serialize + + return serialize.load_stable_baselines_model( + sb3.PPO, + str(to_absolute_path(str(path))), + environment, + ) -def register_configs(group: str = "rl_algorithm", default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): +def register_configs( + group: str = "rl_algorithm", + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_seed: Optional[Union[int, str]] = MISSING, +): from hydra.core.config_store import ConfigStore + default_environment = cast(environment_cfg.Config, default_environment) + default_seed = cast(int, default_seed) + cs = ConfigStore.instance() - cs.store(name="ppo", group=group, node=PPO(environment=default_environment, policy=policy_cfg.ActorCriticPolicy(environment=default_environment))) - cs.store(name="ppo_on_disk", group=group, node=PPOOnDisk(environment=default_environment)) + cs.store( + name="ppo", + group=group, + 
node=PPO( + environment=default_environment, + policy=policy_cfg.ActorCriticPolicy(environment=default_environment), + seed=default_seed, + ), + ) + cs.store( + name="ppo_on_disk", + group=group, + node=PPOOnDisk(environment=default_environment), + ) + + schedule.register_configs(group=group + "/learning_rate") + schedule.register_configs(group=group + "/clip_range") diff --git a/src/imitation_cli/utils/schedule.py b/src/imitation_cli/utils/schedule.py index 0ab279f91..08091c6a0 100644 --- a/src/imitation_cli/utils/schedule.py +++ b/src/imitation_cli/utils/schedule.py @@ -1,3 +1,4 @@ +"""Configurations for stable_baselines3 schedules.""" import dataclasses from hydra.core.config_store import ConfigStore @@ -6,6 +7,8 @@ @dataclasses.dataclass class Config: + """Base configuration for schedules.""" + # Note: we don't define _target_ here so in the subclasses it can be defined last. # This way we can instantiate a fixed schedule with `FixedSchedule(0.1)`. # If we defined _target_ here, then we would have to instantiate a fixed schedule @@ -15,12 +18,16 @@ class Config: @dataclasses.dataclass class FixedSchedule(Config): + """Configuration for a fixed schedule.""" + val: float = MISSING _target_: str = "stable_baselines3.common.utils.constant_fn" @dataclasses.dataclass class LinearSchedule(Config): + """Configuration for a linear schedule.""" + start: float = MISSING end: float = MISSING end_fraction: float = MISSING diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index f4f0b1cf8..26d018dbd 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -1,8 +1,10 @@ +"""Configurable trajectory sources.""" from __future__ import annotations + import dataclasses import pathlib import typing -from typing import Sequence, Optional, Union +from typing import Optional, Sequence, Union, cast if typing.TYPE_CHECKING: from stable_baselines3.common.policies import BasePolicy @@ -13,16 +15,21 @@ from hydra.utils import call from omegaconf import MISSING -from imitation_cli.utils import policy, randomness, environment as environment_cfg +from imitation_cli.utils import environment as environment_cfg +from imitation_cli.utils import policy, randomness @dataclasses.dataclass class Config: + """Base configuration for trajectory sources.""" + _target_: str = MISSING @dataclasses.dataclass class OnDisk(Config): + """Configuration for loading trajectories from disk.""" + _target_: str = "imitation_cli.utils.trajectories.OnDisk.make" path: pathlib.Path = MISSING @@ -35,11 +42,15 @@ def make(path: pathlib.Path) -> Sequence[Trajectory]: @dataclasses.dataclass class Generated(Config): + """Configuration for generating trajectories from an expert policy.""" + _target_: str = "imitation_cli.utils.trajectories.Generated.make" - _recursive_: bool = False # We disable the recursive flag, so we can extract the environment from the expert policy + # Note: We disable the recursive flag, so we can extract + # the environment from the expert policy + _recursive_: bool = False total_timesteps: int = MISSING expert_policy: policy.Config = MISSING - rng: randomness.Config = randomness.Config() + rng: randomness.Config = MISSING @staticmethod def make( @@ -61,8 +72,18 @@ def make( ) -def register_configs(group: str, default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING): +def register_configs( + group: str, + default_environment: Optional[Union[environment_cfg.Config, str]] = MISSING, + default_rng: 
Optional[Union[randomness.Config, str]] = MISSING, +): + default_environment = cast(environment_cfg.Config, default_environment) + default_rng = cast(randomness.Config, default_rng) + cs = ConfigStore.instance() cs.store(group=group, name="on_disk", node=OnDisk) - cs.store(group=group, name="generated", node=Generated) - policy.register_configs(group=group + "/expert_policy", default_environment=default_environment) + cs.store(group=group, name="generated", node=Generated(rng=default_rng)) + policy.register_configs( + group=group + "/expert_policy", + default_environment=default_environment, + ) From 17388224d6036577bf0090fa674632c4453f37a8 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 26 Apr 2023 20:39:11 +0200 Subject: [PATCH 21/25] Swtich from Hydra call to hydra instantiate. --- src/imitation_cli/airl.py | 4 ++-- src/imitation_cli/utils/environment.py | 4 ++-- src/imitation_cli/utils/policy.py | 10 +++++----- src/imitation_cli/utils/reward_network.py | 6 +++--- src/imitation_cli/utils/rl_algorithm.py | 8 ++++---- src/imitation_cli/utils/trajectories.py | 8 ++++---- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 103065e49..374a4ba17 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -7,7 +7,7 @@ import hydra import torch as th from hydra.core.config_store import ConfigStore -from hydra.utils import call +from hydra.utils import instantiate from omegaconf import MISSING from imitation.policies import serialize @@ -65,7 +65,7 @@ def run_airl(cfg: RunConfig) -> Dict[str, Any]: from imitation.data import rollout from imitation.data.types import TrajectoryWithRew - trainer: airl.AIRL = call(cfg.airl) + trainer: airl.AIRL = instantiate(cfg.airl) checkpoints_path = pathlib.Path("checkpoints") diff --git a/src/imitation_cli/utils/environment.py b/src/imitation_cli/utils/environment.py index 733fdb0d6..d824965c9 100644 --- a/src/imitation_cli/utils/environment.py +++ b/src/imitation_cli/utils/environment.py @@ -9,7 +9,7 @@ from stable_baselines3.common.vec_env import VecEnv from hydra.core.config_store import ConfigStore -from hydra.utils import call +from hydra.utils import instantiate from omegaconf import MISSING from imitation_cli.utils import randomness @@ -40,7 +40,7 @@ def make(log_dir: Optional[str] = None, **kwargs) -> VecEnv: def make_rollout_venv(environment_config: Config) -> VecEnv: from imitation.data import wrappers - return call( + return instantiate( environment_config, log_dir=None, post_wrappers=[lambda env, i: wrappers.RolloutInfoWrapper(env)], diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index 3cc460b54..d749c1197 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -11,7 +11,7 @@ from stable_baselines3.common.policies import BasePolicy from hydra.core.config_store import ConfigStore -from hydra.utils import call +from hydra.utils import instantiate from omegaconf import MISSING from imitation_cli.utils import activation_function_class as act_fun_class_cfg @@ -91,9 +91,9 @@ def make_args( del kwargs["_target_"] del kwargs["environment"] - kwargs["activation_fn"] = call(activation_fn) - kwargs["features_extractor_class"] = call(features_extractor_class) - kwargs["optimizer_class"] = call(optimizer_class) + kwargs["activation_fn"] = instantiate(activation_fn) + kwargs["features_extractor_class"] = instantiate(features_extractor_class) + kwargs["optimizer_class"] = 
instantiate(optimizer_class) return dict( **kwargs, @@ -182,7 +182,7 @@ def make( model = serialize.load_stable_baselines_model( Loaded.type_to_class(policy_type), filename, - call(environment), + instantiate(environment), ) return model.policy diff --git a/src/imitation_cli/utils/reward_network.py b/src/imitation_cli/utils/reward_network.py index a2f6ac25d..43dbfcef4 100644 --- a/src/imitation_cli/utils/reward_network.py +++ b/src/imitation_cli/utils/reward_network.py @@ -10,7 +10,7 @@ from imitation.rewards.reward_nets import RewardNet from hydra.core.config_store import ConfigStore -from hydra.utils import call +from hydra.utils import instantiate from omegaconf import MISSING import imitation_cli.utils.environment as environment_cfg @@ -99,11 +99,11 @@ def make( ) -> RewardNet: from imitation.rewards import reward_nets - venv = call(environment) + venv = instantiate(environment) reward_net = reward_nets.RewardEnsemble( venv.observation_space, venv.action_space, - [call(ensemble_member_config) for _ in range(ensemble_size)], + [instantiate(ensemble_member_config) for _ in range(ensemble_size)], ) if add_std_alpha is not None: return reward_nets.AddSTDRewardWrapper( diff --git a/src/imitation_cli/utils/rl_algorithm.py b/src/imitation_cli/utils/rl_algorithm.py index 8186eb64c..739f0b9a7 100644 --- a/src/imitation_cli/utils/rl_algorithm.py +++ b/src/imitation_cli/utils/rl_algorithm.py @@ -10,7 +10,7 @@ import stable_baselines3 as sb3 from stable_baselines3.common.vec_env import VecEnv -from hydra.utils import call, to_absolute_path +from hydra.utils import instantiate, to_absolute_path from omegaconf import MISSING from imitation_cli.utils import environment as environment_cfg @@ -73,9 +73,9 @@ def make( return sb3.PPO( policy=sb3.common.policies.ActorCriticPolicy, policy_kwargs=policy_kwargs, - env=call(environment), - learning_rate=call(learning_rate), - clip_range=call(clip_range), + env=instantiate(environment), + learning_rate=instantiate(learning_rate), + clip_range=instantiate(clip_range), **kwargs, ) diff --git a/src/imitation_cli/utils/trajectories.py b/src/imitation_cli/utils/trajectories.py index 26d018dbd..80c9322e5 100644 --- a/src/imitation_cli/utils/trajectories.py +++ b/src/imitation_cli/utils/trajectories.py @@ -12,7 +12,7 @@ import numpy as np from hydra.core.config_store import ConfigStore -from hydra.utils import call +from hydra.utils import instantiate from omegaconf import MISSING from imitation_cli.utils import environment as environment_cfg @@ -60,9 +60,9 @@ def make( ) -> Sequence[Trajectory]: from imitation.data import rollout - expert = call(expert_policy) - env = call(expert_policy.environment) - rng = call(rng) + expert = instantiate(expert_policy) + env = instantiate(expert_policy.environment) + rng = instantiate(rng) return rollout.generate_trajectories( expert, env, From 5638222327bed1a500f96a136e7b01ab1a2e2fb9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 26 Apr 2023 22:11:28 +0200 Subject: [PATCH 22/25] Add type ignore reason. 
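The narrower form silences only mypy's `assignment` error code, which is what the `${environment}` interpolation triggers: it is a plain string assigned to a field typed as airl_cfg.Config. A minimal sketch of the pattern, with illustrative names that are not from this repo:

    import dataclasses

    @dataclasses.dataclass
    class Inner:
        x: int = 0

    @dataclasses.dataclass
    class Outer:
        # Hydra resolves the interpolation at run time, so mypy only sees a str
        # assigned to an Inner-typed field; [assignment] silences exactly that check.
        inner: Inner = "${inner_group}"  # type: ignore[assignment]
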
--- src/imitation_cli/airl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 374a4ba17..4e5d9c512 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -49,7 +49,7 @@ class RunConfig: name="airl_run_base", node=RunConfig( airl=airl_cfg.Config( - venv="${environment}", # type: ignore + venv="${environment}", # type: ignore[assignment] ), ), ) From efe3b9007966ef7ca82a60ef9861c188964c7e84 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 27 Apr 2023 11:05:03 +0200 Subject: [PATCH 23/25] Dont' allow variable horizon for AIRL by default. --- src/imitation_cli/algorithm_configurations/airl.py | 2 +- src/imitation_cli/config/airl_optuna.yaml | 1 + src/imitation_cli/config/airl_run.yaml | 1 + src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml | 1 + 4 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/imitation_cli/algorithm_configurations/airl.py b/src/imitation_cli/algorithm_configurations/airl.py index 16dfbdfef..5511a0d35 100644 --- a/src/imitation_cli/algorithm_configurations/airl.py +++ b/src/imitation_cli/algorithm_configurations/airl.py @@ -30,4 +30,4 @@ class Config: init_tensorboard: bool = False init_tensorboard_graph: bool = False debug_use_ground_truth: bool = False - allow_variable_horizon: bool = True # TODO: true just for debugging + allow_variable_horizon: bool = False diff --git a/src/imitation_cli/config/airl_optuna.yaml b/src/imitation_cli/config/airl_optuna.yaml index 9eadb47e9..93224346b 100644 --- a/src/imitation_cli/config/airl_optuna.yaml +++ b/src/imitation_cli/config/airl_optuna.yaml @@ -16,6 +16,7 @@ airl: demo_batch_size: 128 demonstrations: total_timesteps: 10 + allow_variable_horizon: true hydra: mode: MULTIRUN diff --git a/src/imitation_cli/config/airl_run.yaml b/src/imitation_cli/config/airl_run.yaml index 7cbece85e..2bb5364f0 100644 --- a/src/imitation_cli/config/airl_run.yaml +++ b/src/imitation_cli/config/airl_run.yaml @@ -16,3 +16,4 @@ airl: demo_batch_size: 128 demonstrations: total_timesteps: 10 + allow_variable_horizon: true diff --git a/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml index 141b03c59..54ffc74cf 100644 --- a/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml +++ b/src/imitation_cli/config/airl_sweep_env_and_rewardnet.yaml @@ -15,6 +15,7 @@ airl: demo_batch_size: 128 demonstrations: total_timesteps: 10 + allow_variable_horizon: true hydra: mode: MULTIRUN From 5915a8050c024ef5ab3af79d6412fba9d1e2662d Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 27 Apr 2023 13:34:47 +0200 Subject: [PATCH 24/25] Simplify the class configurations using enums. 
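Each class family now collapses into a single Config dataclass plus an Enum whose members hold the concrete classes as values, and module-level instances (TanH, Adam, FlattenExtractor, ...) replace the old per-class dataclasses; register_configs then becomes one loop over the enum. A rough usage sketch (not part of the patch), assuming Hydra instantiates the raw dataclass node the same way it does when the node is pulled from the ConfigStore:

    import torch
    from hydra.utils import instantiate

    from imitation_cli.utils import activation_function_class as act_fun_class_cfg

    # Config.make simply returns the enum member's value, i.e. the torch class.
    activation_cls = instantiate(act_fun_class_cfg.TanH)
    assert activation_cls is torch.nn.Tanh
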
--- .../algorithm_configurations/airl.py | 2 +- .../utils/activation_function_class.py | 58 ++++++------------- .../utils/feature_extractor_class.py | 43 +++++--------- src/imitation_cli/utils/optimizer_class.py | 41 +++++-------- src/imitation_cli/utils/policy.py | 6 +- 5 files changed, 54 insertions(+), 96 deletions(-) diff --git a/src/imitation_cli/algorithm_configurations/airl.py b/src/imitation_cli/algorithm_configurations/airl.py index 5511a0d35..30e5309df 100644 --- a/src/imitation_cli/algorithm_configurations/airl.py +++ b/src/imitation_cli/algorithm_configurations/airl.py @@ -24,7 +24,7 @@ class Config: reward_net: reward_network.Config = MISSING demo_batch_size: int = 64 n_disc_updates_per_round: int = 2 - disc_opt_cls: optimizer_class.Config = optimizer_class.Adam() + disc_opt_cls: optimizer_class.Config = optimizer_class.Adam gen_train_timesteps: Optional[int] = None gen_replay_buffer_capacity: Optional[int] = None init_tensorboard: bool = False diff --git a/src/imitation_cli/utils/activation_function_class.py b/src/imitation_cli/utils/activation_function_class.py index 2b5852b63..da51c070c 100644 --- a/src/imitation_cli/utils/activation_function_class.py +++ b/src/imitation_cli/utils/activation_function_class.py @@ -1,59 +1,37 @@ """Classes for configuring activation functions.""" import dataclasses +from enum import Enum +import torch from hydra.core.config_store import ConfigStore -@dataclasses.dataclass -class Config: - """Base class for activation function configs.""" - - # Note: we don't define _target_ here so in the subclasses it can be defined last. - # This is the same pattern we use as in schedule.py. - pass - - -@dataclasses.dataclass -class TanH(Config): - """Config for TanH activation function.""" - - _target_: str = "imitation_cli.utils.activation_function_class.TanH.make" - - @staticmethod - def make() -> type: - import torch +class ActivationFunctionClass(Enum): + """Enum of activation function classes.""" - return torch.nn.Tanh + TanH = torch.nn.Tanh + ReLU = torch.nn.ReLU + LeakyReLU = torch.nn.LeakyReLU @dataclasses.dataclass -class ReLU(Config): - """Config for ReLU activation function.""" +class Config: + """Base class for activation function configs.""" - _target_: str = "imitation_cli.utils.activation_function_class.ReLU.make" + activation_function_class: ActivationFunctionClass + _target_: str = "imitation_cli.utils.activation_function_class.Config.make" @staticmethod - def make() -> type: - import torch - - return torch.nn.ReLU - + def make(activation_function_class: ActivationFunctionClass) -> type: + return activation_function_class.value -@dataclasses.dataclass -class LeakyReLU(Config): - """Config for LeakyReLU activation function.""" - - _target_: str = "imitation_cli.utils.activation_function_class.LeakyReLU.make" - - @staticmethod - def make() -> type: - import torch - return torch.nn.LeakyReLU +TanH = Config(ActivationFunctionClass.TanH) +ReLU = Config(ActivationFunctionClass.ReLU) +LeakyReLU = Config(ActivationFunctionClass.LeakyReLU) def register_configs(group: str): cs = ConfigStore.instance() - cs.store(group=group, name="tanh", node=TanH) - cs.store(group=group, name="relu", node=ReLU) - cs.store(group=group, name="leaky_relu", node=LeakyReLU) + for cls in ActivationFunctionClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/feature_extractor_class.py b/src/imitation_cli/utils/feature_extractor_class.py index 27b33cbb5..2beb0e1e3 100644 --- 
a/src/imitation_cli/utils/feature_extractor_class.py +++ b/src/imitation_cli/utils/feature_extractor_class.py @@ -1,46 +1,35 @@ """Register Hydra configs for stable_baselines3 feature extractors.""" import dataclasses +from enum import Enum +import stable_baselines3.common.torch_layers as torch_layers from hydra.core.config_store import ConfigStore -from omegaconf import MISSING -@dataclasses.dataclass -class Config: - """Base config for stable_baselines3 feature extractors.""" +class FeatureExtractorClass(Enum): + """Enum of feature extractor classes.""" - _target_: str = MISSING + FlattenExtractor = torch_layers.FlattenExtractor + NatureCNN = torch_layers.NatureCNN @dataclasses.dataclass -class FlattenExtractorConfig(Config): - """Config for FlattenExtractor.""" +class Config: + """Base config for stable_baselines3 feature extractors.""" - _target_: str = ( - "imitation_cli.utils.feature_extractor_class.FlattenExtractorConfig.make" - ) + feature_extractor_class: FeatureExtractorClass + _target_: str = "imitation_cli.utils.feature_extractor_class.Config.make" @staticmethod - def make() -> type: - import stable_baselines3 - - return stable_baselines3.common.torch_layers.FlattenExtractor + def make(feature_extractor_class: FeatureExtractorClass) -> type: + return feature_extractor_class.value -@dataclasses.dataclass -class NatureCNNConfig(Config): - """Config for NatureCNN.""" - - _target_: str = "imitation_cli.utils.feature_extractor_class.NatureCNNConfig.make" - - @staticmethod - def make() -> type: - import stable_baselines3 - - return stable_baselines3.common.torch_layers.NatureCNN +FlattenExtractor = Config(FeatureExtractorClass.FlattenExtractor) +NatureCNN = Config(FeatureExtractorClass.NatureCNN) def register_configs(group: str): cs = ConfigStore.instance() - cs.store(group=group, name="flatten", node=FlattenExtractorConfig) - cs.store(group=group, name="nature_cnn", node=NatureCNNConfig) + for cls in FeatureExtractorClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/optimizer_class.py b/src/imitation_cli/utils/optimizer_class.py index 0fd25da95..17a3d01ca 100644 --- a/src/imitation_cli/utils/optimizer_class.py +++ b/src/imitation_cli/utils/optimizer_class.py @@ -1,44 +1,35 @@ """Register optimizer classes with Hydra.""" import dataclasses +from enum import Enum +import torch from hydra.core.config_store import ConfigStore -from omegaconf import MISSING -@dataclasses.dataclass -class Config: - """Base config for optimizer classes.""" +class OptimizerClass(Enum): + """Enum of optimizer classes.""" - _target_: str = MISSING + Adam = torch.optim.Adam + SGD = torch.optim.SGD @dataclasses.dataclass -class Adam(Config): - """Config for Adam optimizer class.""" +class Config: + """Base config for optimizer classes.""" - _target_: str = "imitation_cli.utils.optimizer_class.Adam.make" + optimizer_class: OptimizerClass + _target_: str = "imitation_cli.utils.optimizer_class.Config.make" @staticmethod - def make() -> type: - import torch - - return torch.optim.Adam + def make(optimizer_class: OptimizerClass) -> type: + return optimizer_class.value -@dataclasses.dataclass -class SGD(Config): - """Config for SGD optimizer class.""" - - _target_: str = "imitation_cli.utils.optimizer_class.SGD.make" - - @staticmethod - def make() -> type: - import torch - - return torch.optim.SGD +Adam = Config(OptimizerClass.Adam) +SGD = Config(OptimizerClass.SGD) def register_configs(group: str): cs = ConfigStore.instance() - cs.store(group=group, 
name="adam", node=Adam) - cs.store(group=group, name="sgd", node=SGD) + for cls in OptimizerClass: + cs.store(group=group, name=cls.name.lower(), node=Config(cls)) diff --git a/src/imitation_cli/utils/policy.py b/src/imitation_cli/utils/policy.py index d749c1197..26fe18a3f 100644 --- a/src/imitation_cli/utils/policy.py +++ b/src/imitation_cli/utils/policy.py @@ -65,7 +65,7 @@ class ActorCriticPolicy(Config): _target_: str = "imitation_cli.utils.policy.ActorCriticPolicy.make" lr_schedule: schedule.Config = schedule.FixedSchedule(3e-4) net_arch: Optional[Dict[str, List[int]]] = None - activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH() + activation_fn: act_fun_class_cfg.Config = act_fun_class_cfg.TanH ortho_init: bool = True use_sde: bool = False log_std_init: float = 0.0 @@ -73,12 +73,12 @@ class ActorCriticPolicy(Config): use_expln: bool = False squash_output: bool = False features_extractor_class: feature_extractor_class_cfg.Config = ( - feature_extractor_class_cfg.FlattenExtractorConfig() + feature_extractor_class_cfg.FlattenExtractor ) features_extractor_kwargs: Optional[Dict[str, Any]] = None share_features_extractor: bool = True normalize_images: bool = True - optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam() + optimizer_class: optimizer_class_cfg.Config = optimizer_class_cfg.Adam optimizer_kwargs: Optional[Dict[str, Any]] = None @staticmethod From 67601044226341f31d9c55bf4782176bcf5c8b30 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 27 Apr 2023 13:35:02 +0200 Subject: [PATCH 25/25] Fix bug in type ignore reason. --- src/imitation_cli/airl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/imitation_cli/airl.py b/src/imitation_cli/airl.py index 4e5d9c512..d74f4bd5b 100644 --- a/src/imitation_cli/airl.py +++ b/src/imitation_cli/airl.py @@ -49,7 +49,7 @@ class RunConfig: name="airl_run_base", node=RunConfig( airl=airl_cfg.Config( - venv="${environment}", # type: ignore[assignment] + venv="${environment}", # type: ignore[arg-type] ), ), )