Commit
mce_irl_train
ZiyueWang25 committed Oct 6, 2023
1 parent 95110dc commit 10ec8a2
Showing 2 changed files with 66 additions and 4 deletions.
64 changes: 60 additions & 4 deletions src/imitation/scripts/train_imitation.py
@@ -1,21 +1,30 @@
"""Trains DAgger on synthetic demonstrations generated from an expert policy."""

from functools import partial
import logging
import os.path as osp
import pathlib
from typing import Any, Dict, Mapping, Optional, Sequence, Type, cast


import numpy as np
import torch as th
from sacred.observers import FileStorageObserver

from imitation.algorithms import dagger as dagger_algorithm
from imitation.algorithms import sqil as sqil_algorithm
from seals import base_envs
from seals.diagnostics.cliff_world import CliffWorldEnv
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from imitation.algorithms import (
    dagger as dagger_algorithm,
    sqil as sqil_algorithm,
    mce_irl as mceirl_algorithm,
)
from imitation.data import rollout, types
from imitation.scripts.config.train_imitation import train_imitation_ex
from imitation.scripts.ingredients import bc as bc_ingredient
from imitation.scripts.ingredients import demonstrations, environment, expert
from imitation.scripts.ingredients import logging as logging_ingredient
from imitation.scripts.ingredients import policy_evaluation
from imitation.scripts.ingredients import policy_evaluation, reward
from imitation.util import util

logger = logging.getLogger(__name__)
@@ -185,6 +194,53 @@ def sqil(
    return stats


@train_imitation_ex.command
def mceirl(
    mceirl: Mapping[str, Any],
    optimizer_cls: Type[th.optim.Optimizer],
    optimizer_kwargs: Mapping[str, Any],
    env_kwargs: Mapping[str, Any],
    num_vec: int,
    parallel: bool,
    _run,
    _rnd: np.random.Generator,
) -> Mapping[str, Mapping[str, float]]:
    """Trains MCE IRL on expert demonstrations in the CliffWorld environment."""
    custom_logger, log_dir = logging_ingredient.setup_logging()
    expert_trajs = demonstrations.get_expert_trajectories()
    env_creator = partial(CliffWorldEnv, **env_kwargs)
    env = env_creator()

    env_fns = [lambda: base_envs.ExposePOMDPStateWrapper(env_creator())] * num_vec
    # This is just a vectorized environment because `generate_trajectories`
    # expects one.
    if parallel:
        # See GH hill-a/stable-baselines issue #217
        state_venv = SubprocVecEnv(env_fns, start_method="forkserver")
    else:
        state_venv = DummyVecEnv(env_fns)

    reward_net = reward.make_reward_net(state_venv)
    mceirl_trainer = mceirl_algorithm.MCEIRL(
        env=env,
        demonstrations=expert_trajs,
        reward_net=reward_net,
        rng=_rnd,
        optimizer_cls=optimizer_cls,
        optimizer_kwargs=optimizer_kwargs,
        discount=mceirl["discount"],
        linf_eps=mceirl["linf_eps"],
        grad_l2_eps=mceirl["grad_l2_eps"],
        log_interval=mceirl["log_interval"],
        custom_logger=custom_logger,
    )
    mceirl_trainer.train(
        max_iter=int(mceirl["max_iter"]),
    )
    util.save_policy(mceirl_trainer.policy, policy_path=osp.join(log_dir, "final.th"))
    imit_stats = policy_evaluation.eval_policy(mceirl_trainer.policy, state_venv)
    stats = _collect_stats(imit_stats, expert_trajs)
    return stats


def main_console():
    observer_path = pathlib.Path.cwd() / "output" / "sacred" / "train_imitation"
    observer = FileStorageObserver(observer_path)
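For orientation, here is a minimal sketch of how the new mceirl command could be driven programmatically through Sacred, mirroring the pattern the test below uses. The override under config_updates is an assumption for illustration; the actual defaults live in imitation.scripts.config.train_imitation and are not part of this commit.

# Hypothetical usage sketch (not part of this commit): run the new `mceirl`
# Sacred command programmatically, the same way the test suite drives it.
from imitation.scripts import train_imitation

run = train_imitation.train_imitation_ex.run(
    command_name="mceirl",
    config_updates={
        # Assumed override purely for illustration.
        "mceirl": {"max_iter": 100},
    },
)
assert run.status == "COMPLETED"
print(run.result)  # mapping of imitation vs. expert rollout statistics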
6 changes: 6 additions & 0 deletions tests/scripts/test_scripts.py
@@ -425,6 +425,12 @@ def test_train_bc_warmstart(tmpdir):
    assert isinstance(run_warmstart.result, dict)


def test_train_mceirl_main(mceirl_config):
    run = train_imitation.train_imitation_ex.run(**mceirl_config)
    assert run.status == "COMPLETED"
    assert isinstance(run.result, dict)


def test_train_sqil_main(sqil_config):
    # NOTE: Having four different expert types as in bc might be overkill for sqil
    run = train_imitation.train_imitation_ex.run(**sqil_config)
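The mceirl_config fixture consumed by the new test is not shown in this diff. Since it is unpacked directly into Sacred's Experiment.run, it is presumably a dict of run() keyword arguments roughly shaped like the hypothetical sketch below; the keys and values here are assumptions, not the actual fixture.

# Hypothetical sketch of a fixture shaped like the one the new test consumes;
# the actual `mceirl_config` fixture is defined elsewhere in the test suite.
import pytest


@pytest.fixture
def mceirl_config():
    return {
        "command_name": "mceirl",
        # Assumed override purely to keep a test run fast; illustrative only.
        "config_updates": {"mceirl": {"max_iter": 5}},
    }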
