add train_mce_irl script
ZiyueWang25 committed Oct 6, 2023
1 parent 10ec8a2 commit 7436784
Showing 3 changed files with 140 additions and 60 deletions.
48 changes: 48 additions & 0 deletions src/imitation/scripts/config/train_mce_irl.py
@@ -0,0 +1,48 @@
"""Configuration for imitation.scripts.train_mce_irl."""
import sacred
from torch import nn
import torch as th

from imitation.scripts.ingredients import demonstrations, environment
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy_evaluation, reward, rl

train_mce_irl_ex = sacred.Experiment(
"train_mce_irl",
ingredients=[
logging_ingredient.logging_ingredient,
        demonstrations.demonstrations_ingredient,  # assumed: required by get_expert_trajectories() in the script
        environment.environment_ingredient,
        reward.reward_ingredient,
        rl.rl_ingredient,
        policy_evaluation.policy_evaluation_ingredient,
    ],
)


MUJOCO_SHARED_LOCALS = dict(rl=dict(rl_kwargs=dict(ent_coef=0.1)))
ANT_SHARED_LOCALS = dict(
    total_timesteps=int(3e7),
    rl=dict(batch_size=16384),
)


@train_mce_irl_ex.config
def train_defaults():
    mceirl = {
        "discount": 1,
        "linf_eps": 0.001,
        "grad_l2_eps": 0.0001,
        "log_interval": 100,
        "max_iter": 1000,  # assumed default; the training script reads mceirl["max_iter"]
    }
    optimizer_cls = th.optim.Adam
    optimizer_kwargs = dict(
        lr=4e-4,
    )
    env_kwargs = {
        "height": 4,
        "horizon": 40,
        "width": 7,
        "use_xy_obs": True,
    }
    num_vec = 8  # number of environments in VecEnv
    parallel = False
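
For orientation (not part of the commit), a minimal sketch of how these Sacred defaults could be overridden when the experiment is launched from Python rather than the command line. The dotted `config_updates` key mirrors the CLI syntax `with mceirl.discount=0.99`; any expert-demonstration or environment configuration required by the other ingredients is assumed to be supplied elsewhere and is not shown.

# Hypothetical usage sketch: run the train_mce_irl command with a few
# overridden defaults. Importing the script module (rather than only the
# config module) ensures the command itself is registered on the experiment.
from imitation.scripts.train_mce_irl import train_mce_irl_ex

run = train_mce_irl_ex.run(
    command_name="train_mce_irl",
    config_updates={
        "mceirl.discount": 0.99,  # dotted key updates one entry of the mceirl dict
        "num_vec": 4,
    },
)
print(run.result)  # {"imit_stats": ..., "expert_stats": ...}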
64 changes: 4 additions & 60 deletions src/imitation/scripts/train_imitation.py
@@ -1,24 +1,15 @@
"""Trains DAgger on synthetic demonstrations generated from an expert policy."""

from functools import partial
import logging
import os.path as osp
import pathlib
from typing import Any, Dict, Mapping, Optional, Sequence, cast

from typing import Any, Dict, Mapping, Optional, Sequence, Type, cast

import numpy as np
import torch as th
from sacred.observers import FileStorageObserver
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from imitation.algorithms import (
    dagger as dagger_algorithm,
    sqil as sqil_algorithm,
    mce_irl as mceirl_algorithm,
)

from imitation.algorithms import dagger as dagger_algorithm
from imitation.algorithms import sqil as sqil_algorithm
from imitation.data import rollout, types
from imitation.scripts.config.train_imitation import train_imitation_ex
from imitation.scripts.ingredients import bc as bc_ingredient
@@ -194,53 +185,6 @@ def sqil(
    return stats


@train_imitation_ex.command
def mceirl(
    mceirl: Mapping[str, Any],
    optimizer_cls: th.optim.Optimizer,  # not sure
    optimizer_kwargs: Mapping[str, Any],
    env_kwargs: Mapping[str, Any],
    num_vec: int,
    parallel: bool,
    _run,
    _rnd: np.random.Generator,
) -> Mapping[str, Mapping[str, float]]:
    custom_logger, log_dir = logging_ingredient.setup_logging()
    expert_trajs = demonstrations.get_expert_trajectories()
    env_creator = partial(CliffWorldEnv, **env_kwargs)
    env = env_creator()

    env_fns = [lambda: base_envs.ExposePOMDPStateWrapper(env_creator())] * num_vec
    # This is just a vectorized environment because `generate_trajectories` expects one
    if parallel:
        # See GH hill-a/stable-baselines issue #217
        state_venv = SubprocVecEnv(env_fns, start_method="forkserver")
    else:
        state_venv = DummyVecEnv(env_fns)

    reward_net = reward.make_reward_net(state_venv)
    mceirl_trainer = mceirl_algorithm.MCEIRL(
        env=env,
        demonstrations=expert_trajs,
        reward_net=reward_net,
        rng=_rnd,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
        discount=mceirl["discount"],
        linf_eps=mceirl["linf_eps"],
        grad_l2_eps=mceirl["grad_l2_eps"],
        log_interval=mceirl["log_interval"],
        custom_logger=custom_logger,
    )
    mceirl_trainer.train(
        max_iter=int(mceirl["max_iter"]),
    )
    util.save_policy(mceirl_trainer.policy, policy_path=osp.join(log_dir, "final.th"))
    imit_stats = policy_evaluation.eval_policy(mceirl_trainer.policy, state_venv)
    stats = _collect_stats(imit_stats, expert_trajs)
    return stats


def main_console():
    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_imitation"
    observer = FileStorageObserver(observer_path)
88 changes: 88 additions & 0 deletions src/imitation/scripts/train_mce_irl.py
@@ -0,0 +1,88 @@
"""Train Finite-horizon tabular Maximum Causal Entropy IRL.
Can be used as a CLI script, or the `train_mce_irl` function
can be called directly.
"""

from functools import partial
import logging
import os.path as osp
import pathlib
from typing import Any, Mapping, Type

import numpy as np
import torch as th
from sacred.observers import FileStorageObserver
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from imitation.algorithms import mce_irl as mceirl_algorithm
from imitation.data import rollout
from imitation.scripts.config.train_mce_irl import train_mce_irl_ex
from imitation.scripts.ingredients import demonstrations
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy_evaluation, reward
from imitation.util import util

logger = logging.getLogger(__name__)


@train_mce_irl_ex.command
def train_mce_irl(
    mceirl: Mapping[str, Any],
    optimizer_cls: Type[th.optim.Optimizer],
    optimizer_kwargs: Mapping[str, Any],
    env_kwargs: Mapping[str, Any],
    num_vec: int,
    parallel: bool,
    _run,
    _rnd: np.random.Generator,
) -> Mapping[str, Mapping[str, float]]:
    custom_logger, log_dir = logging_ingredient.setup_logging()
    expert_trajs = demonstrations.get_expert_trajectories()
    env_creator = partial(CliffWorldEnv, **env_kwargs)
    env = env_creator()

    env_fns = [lambda: base_envs.ExposePOMDPStateWrapper(env_creator())] * num_vec
    # This is just a vectorized environment because `generate_trajectories` expects one
    if parallel:
        # See GH hill-a/stable-baselines issue #217
        state_venv = SubprocVecEnv(env_fns, start_method="forkserver")
    else:
        state_venv = DummyVecEnv(env_fns)

    reward_net = reward.make_reward_net(state_venv)
    mceirl_trainer = mceirl_algorithm.MCEIRL(
        demonstrations=expert_trajs,
        env=env,
        reward_net=reward_net,
        rng=_rnd,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
        discount=mceirl["discount"],
        linf_eps=mceirl["linf_eps"],
        grad_l2_eps=mceirl["grad_l2_eps"],
        log_interval=mceirl["log_interval"],
        custom_logger=custom_logger,
    )
    mceirl_trainer.train(max_iter=int(mceirl["max_iter"]))
    util.save_policy(mceirl_trainer.policy, policy_path=osp.join(log_dir, "final.th"))
    th.save(reward_net, osp.join(log_dir, "reward_net.pt"))
    imit_stats = policy_evaluation.eval_policy(mceirl_trainer.policy, state_venv)
    return {
        "imit_stats": imit_stats,
        "expert_stats": rollout.rollout_stats(expert_trajs),
    }


def main_console():
    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_mce_irl"
    observer = FileStorageObserver(observer_path)
    train_mce_irl_ex.observers.append(observer)
    train_mce_irl_ex.run_commandline()


if __name__ == "__main__": # pragma: no cover
    main_console()

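After a run, the script leaves two artifacts in the run's log directory: the tabular policy (`final.th`, written by `util.save_policy`) and the learned reward network (`reward_net.pt`, written by `th.save`). A minimal loading sketch, assuming both objects were pickled whole with `th.save`; the directory below is hypothetical and should be replaced by the `log_dir` the logging ingredient reports at startup.

import os.path as osp

import torch as th

log_dir = "output/train_mce_irl/<run-timestamp>"  # hypothetical path; use the printed log_dir
policy = th.load(osp.join(log_dir, "final.th"))  # on newer torch, you may need weights_only=False
reward_net = th.load(osp.join(log_dir, "reward_net.pt"))

# The recovered policy and reward net can then be evaluated or inspected
# without re-running MCE IRL.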