
RewardNetwork predict_processed doesn't work without next_state and done #836

Open
gustavodemari opened this issue Jan 9, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@gustavodemari
Bug description

RewardNet's predict_processed method only works when given state, action, next_state, and done, even though the network may have been trained using only state and action.

For example, BasicRewardNet by default trains a network using only state and action, i.e. $R(s, a)$.
However, predict_processed still requires the state, action, next_state, and done arguments.

Thus, perhaps predict_processed should make next_state and done optional (see below), and the method should check whether next_state and done are None and adjust its behavior accordingly.

def predict_processed(
        self,
        state: np.ndarray,
        action: np.ndarray,
        next_state: Optional[np.ndarray] = None,
        done: Optional[np.ndarray] = None,
        **kwargs,
    ) -> np.ndarray:
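        # Hypothetical sketch, not the library's current implementation: when
        # next_state / done are omitted, substitute placeholder arrays. This is
        # only meaningful when the network ignores those inputs, e.g. a
        # BasicRewardNet built with use_next_state=False and use_done=False.
        if next_state is None:
            next_state = np.zeros_like(state)
        if done is None:
            done = np.zeros(len(state), dtype=bool)
        ...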

Steps to reproduce

#!/usr/bin/env python
# coding: utf-8

import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)
expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)

from imitation.data import rollout

rollouts = rollout.rollout(
    expert,
    env,
    rollout.make_sample_until(min_timesteps=None, min_episodes=60),
    rng=np.random.default_rng(SEED),
)

from imitation.algorithms.adversarial.gail import GAIL
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from stable_baselines3 import PPO
from stable_baselines3.ppo import MlpPolicy
from stable_baselines3.common.evaluation import evaluate_policy

learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
)
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)
gail_trainer = GAIL(
    demonstrations=rollouts,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
)

env.seed(SEED)
learner_rewards_before_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

gail_trainer.train(200_000)

env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner, env, 100, return_episode_rewards=True
)

print(
    "Rewards before training:",
    np.mean(learner_rewards_before_training),
    "+/-",
    np.std(learner_rewards_before_training),
)
print(
    "Rewards after training:",
    np.mean(learner_rewards_after_training),
    "+/-",
    np.std(learner_rewards_after_training),
)

n_samples = 10

print(f"Generating {n_samples} samples")

obs = np.vstack([env.observation_space.sample() for i in range(n_samples)])
action = np.vstack([env.action_space.sample() for i in range(n_samples)])
next_obs = np.vstack([env.observation_space.sample() for i in range(n_samples)])
done = np.array([False] * len(obs))

print(f"Predicting rewards using {n_samples} samples")
rewards_predict_processed = reward_net.predict_processed(state=obs, action=action, next_state=next_obs, done=done)
print(f"Rewards: {rewards_predict_processed}")

print(f"Predicting rewards using {n_samples} samples, without next_state and done")
reward_net.predict_processed(state=obs, action=action)
reward_net.predict_processed(state=obs, action=action, next_state=None, done=None)
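# Expected: the first call above succeeds, while both calls without real
# next_state / done arrays fail (presumably a TypeError for the omitted
# arguments, and a preprocessing error for the explicit None values), since
# predict_processed currently requires both.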

Environment

  • Operating system and version: Ubuntu 23.10
  • Python version: 3.8.10
  • Output of pip freeze --all:

absl-py==2.0.0
aiohttp==3.9.1
aiosignal==1.3.1
alembic==1.13.1
anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.1.0
cachetools==5.3.2
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.0
comm==0.2.1
contourpy==1.1.1
cycler==0.12.1
Cython==3.0.7
dataclasses==0.6
datasets==2.16.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
dfa==2.1.2
dill==0.3.7
docopt==0.6.2
exceptiongroup==1.2.0
execnet==2.0.2
executing==2.0.1
Farama-Notifications==0.0.4
fastjsonschema==2.19.1
filelock==3.13.1
fonttools==4.47.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2023.10.0
funcy==1.18
gitdb==4.0.11
GitPython==3.1.40
google-auth==2.26.1
google-auth-oauthlib==1.0.0
GPy==1.10.0
GPyOpt==1.2.6
greenlet==3.0.3
grpcio==1.60.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
h5py==3.10.0
huggingface-hub==0.20.1
huggingface-sb3==3.0
idna==3.6
imitation==1.0.0
importlib-metadata==7.0.1
importlib-resources==6.1.1
iniconfig==2.0.0
ipykernel==6.28.0
ipython==8.12.3
isoduration==20.11.0
istype==0.2.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
json5==0.9.14
jsonpickle==3.0.2
jsonpointer==2.4
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.1
jupyter_client==8.6.0
jupyter_core==5.7.0
jupyter_server==2.12.2
jupyter_server_terminals==0.5.1
jupyterlab==4.0.10
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
kiwisolver==1.4.5
lazytree==0.3.2
lenses==0.5.0
Mako==1.3.0
Markdown==3.5.1
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.7.4
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
munch==4.0.0
mypy-extensions==1.0.0
nbclient==0.9.0
nbconvert==7.14.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.1
notebook_shim==0.2.3
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
optuna==3.5.0
orderedset==2.0.3
overrides==7.4.0
packaging==23.2
pandas==2.0.3
pandocfilters==1.5.0
paramz==0.9.5
parso==0.8.3
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.2.0
pip==23.3.1
pkgutil_resolve_name==1.3.10
platformdirs==4.1.0
pluggy==1.3.0
probabilistic-automata==0.4.2
prometheus-client==0.19.0
prompt-toolkit==3.0.43
protobuf==4.25.1
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pygame==2.5.2
Pygments==2.17.2
pyparsing==3.1.1
pyrsistent==0.20.0
pytest==7.4.4
pytest-forked==1.6.0
pytest-xdist==2.5.0
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.2
referencing==0.32.1
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rpds-py==0.16.2
rsa==4.9
sacred==0.8.5
scikit-learn==1.3.2
scipy==1.10.1
seals==0.2.1
Send2Trash==1.8.2
setuptools==68.2.2
singledispatch==4.1.0
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
SQLAlchemy==2.0.25
stable-baselines3==2.2.1
stack-data==0.6.3
structlog==23.3.0
sympy==1.12
tensorboard==2.14.0
tensorboard-data-server==0.7.2
terminado==0.18.0
threadpoolctl==3.2.0
tinycss2==1.2.1
tomli==2.0.1
torch==2.1.2
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
triton==2.1.0
types-python-dateutil==2.8.19.20240106
typing-inspect==0.5.0
typing_extensions==4.9.0
tzdata==2023.4
uri-template==1.3.0
urllib3==2.1.0
wasabi==1.1.2
wcwidth==0.2.12
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.1
wheel==0.41.2
wrapt==1.16.0
xeus-python==0.15.12
xeus-python-shell==0.5.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0

@gustavodemari gustavodemari added the bug Something isn't working label Jan 9, 2024

CAI23sbP commented Apr 24, 2024

How are you, @gustavodemari?
In my opinion, this is not a bug.
See this link: flatten_trajectories creates next_obs and dones automatically.
In this code, which GAIL uses for training, you can see a member of the flatten_trajectories family called flatten_trajectories_with_rew.
So when initializing BasicRewardNet, you simply choose whether or not to use dones and next_obs.
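
If I understand the suggestion correctly, a minimal sketch of that choice (assuming BasicRewardNet's use_next_state / use_done constructor flags, with env as built in the reproduction script above) would be:

from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm

# Reward net conditioned only on (state, action); next_state and done are
# ignored by the network, even though predict_processed still requires them.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
    use_next_state=False,  # exclude s' from the reward input (default)
    use_done=False,        # exclude the done flag (default)
)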
