How are you, @gustavodemari?
In my opinion, it is not a bug.
See this link: `flatten_trajectories` creates `next_obs` and `dones` automatically.
In this code, which is used for GAIL training, you can see a member of the `flatten_trajectories` family, `flatten_trajectories_with_rew`.
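As a rough sketch of what "creates `next_obs` and `dones` automatically" means, the function below flattens one trajectory into transitions with NumPy. `flatten_single_trajectory` is a hypothetical name, and this is a simplified illustration assuming a trajectory stores one more observation than actions plus a `terminal` flag. It is not imitation's actual implementation:

```python
from typing import Dict

import numpy as np


def flatten_single_trajectory(
    obs: np.ndarray, acts: np.ndarray, terminal: bool
) -> Dict[str, np.ndarray]:
    """Hypothetical, simplified flattening of one trajectory into transitions.

    `obs` holds one more row than `acts` (it includes the final observation),
    so transition i pairs obs[i] with obs[i + 1].
    """
    assert len(obs) == len(acts) + 1, "expected len(obs) == len(acts) + 1"
    n = len(acts)
    dones = np.zeros(n, dtype=bool)
    # Only the last transition can end the episode, and only if it was terminal.
    if n > 0:
        dones[-1] = terminal
    return {
        "obs": obs[:-1],       # states s_t
        "acts": acts,          # actions a_t
        "next_obs": obs[1:],   # next states s_{t+1}, derived from obs
        "dones": dones,        # derived from the terminal flag
    }


obs = np.arange(8, dtype=np.float32).reshape(4, 2)  # 4 observations, 2 features each
acts = np.array([0, 1, 0])                          # 3 actions -> 3 transitions
flat = flatten_single_trajectory(obs, acts, terminal=True)
print(flat["next_obs"])
print(flat["dones"])
```

Each transition pairs `obs[i]` with `obs[i + 1]`, and only the final transition of a terminal trajectory is marked done, so a caller never has to supply `next_obs` or `dones` by hand.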
So you simply choose, when initializing `BasicRewardNet`, whether to use `dones` and `next_obs` or not.
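A framework-free sketch can make both positions concrete: a reward model that ignores `next_state` and `done` gives the same output whatever placeholders are passed, and a wrapper with optional arguments (the idea proposed in the issue) can fill those placeholders in. `state_action_reward` and `predict_with_optional_inputs` are hypothetical names, not part of imitation's API. In imitation itself, the analogous setup would be a `BasicRewardNet` constructed with `use_next_state=False` and `use_done=False`, where the values passed for `next_state` and `done` should not affect the predicted reward:

```python
from typing import Callable, Optional

import numpy as np

# Hypothetical signature for a reward model taking the full transition
# (state, action, next_state, done), mirroring the inputs predict_processed requires.
RewardFn = Callable[[np.ndarray, np.ndarray, np.ndarray, np.ndarray], np.ndarray]


def state_action_reward(state, action, next_state, done):
    """Toy R(s, a): ignores next_state and done, like a net with those inputs disabled."""
    return state.sum(axis=1) + action.astype(np.float64)


def predict_with_optional_inputs(
    reward_fn: RewardFn,
    state: np.ndarray,
    action: np.ndarray,
    next_state: Optional[np.ndarray] = None,
    done: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Hypothetical wrapper that fills in placeholders when next_state/done are omitted.

    The placeholders are only safe if reward_fn really ignores these inputs.
    """
    if next_state is None:
        next_state = np.zeros_like(state)
    if done is None:
        done = np.zeros(len(state), dtype=bool)
    return reward_fn(state, action, next_state, done)


state = np.array([[1.0, 2.0], [3.0, 4.0]])
action = np.array([0, 1])
rewards = predict_with_optional_inputs(state_action_reward, state, action)
print(rewards)
```

The design choice is the trade-off discussed in this thread: optional arguments are more convenient, but silently substituting placeholders would give wrong rewards for a model that was configured to use `next_state` or `done`.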
Bug description
The `RewardNet` `predict_processed` method only works using the `state`, `action`, `next_state` and `done` attributes, even when the network was trained using only `state` and `action`.
For example, `BasicRewardNet` by default trains a network using only `state, action`, i.e., $R(s, a)$. However, `predict_processed` needs the `state`, `action`, `next_state` and `done` attributes.
Thus, maybe `predict_processed` should make `next_state` and `done` optional (see below), and inside the method it should check whether `next_state` and `done` are `None` to change the behavior.
Steps to reproduce
Environment
`pip freeze --all`:
absl-py==2.0.0
aiohttp==3.9.1
aiosignal==1.3.1
alembic==1.13.1
anyio==4.2.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.2.0
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.2
bleach==6.1.0
cachetools==5.3.2
certifi==2023.11.17
cffi==1.16.0
charset-normalizer==3.3.2
cloudpickle==3.0.0
colorama==0.4.6
colorlog==6.8.0
comm==0.2.1
contourpy==1.1.1
cycler==0.12.1
Cython==3.0.7
dataclasses==0.6
datasets==2.16.1
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
dfa==2.1.2
dill==0.3.7
docopt==0.6.2
exceptiongroup==1.2.0
execnet==2.0.2
executing==2.0.1
Farama-Notifications==0.0.4
fastjsonschema==2.19.1
filelock==3.13.1
fonttools==4.47.0
fqdn==1.5.1
frozenlist==1.4.1
fsspec==2023.10.0
funcy==1.18
gitdb==4.0.11
GitPython==3.1.40
google-auth==2.26.1
google-auth-oauthlib==1.0.0
GPy==1.10.0
GPyOpt==1.2.6
greenlet==3.0.3
grpcio==1.60.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==0.29.1
h5py==3.10.0
huggingface-hub==0.20.1
huggingface-sb3==3.0
idna==3.6
imitation==1.0.0
importlib-metadata==7.0.1
importlib-resources==6.1.1
iniconfig==2.0.0
ipykernel==6.28.0
ipython==8.12.3
isoduration==20.11.0
istype==0.2.0
jedi==0.19.1
Jinja2==3.1.2
joblib==1.3.2
json5==0.9.14
jsonpickle==3.0.2
jsonpointer==2.4
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
jupyter-events==0.9.0
jupyter-lsp==2.2.1
jupyter_client==8.6.0
jupyter_core==5.7.0
jupyter_server==2.12.2
jupyter_server_terminals==0.5.1
jupyterlab==4.0.10
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.2
kiwisolver==1.4.5
lazytree==0.3.2
lenses==0.5.0
Mako==1.3.0
Markdown==3.5.1
markdown-it-py==3.0.0
MarkupSafe==2.1.3
matplotlib==3.7.4
matplotlib-inline==0.1.6
mdurl==0.1.2
mistune==3.0.2
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
munch==4.0.0
mypy-extensions==1.0.0
nbclient==0.9.0
nbconvert==7.14.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.1
notebook_shim==0.2.3
numpy==1.24.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
optuna==3.5.0
orderedset==2.0.3
overrides==7.4.0
packaging==23.2
pandas==2.0.3
pandocfilters==1.5.0
paramz==0.9.5
parso==0.8.3
pexpect==4.9.0
pickleshare==0.7.5
pillow==10.2.0
pip==23.3.1
pkgutil_resolve_name==1.3.10
platformdirs==4.1.0
pluggy==1.3.0
probabilistic-automata==0.4.2
prometheus-client==0.19.0
prompt-toolkit==3.0.43
protobuf==4.25.1
psutil==5.9.7
ptyprocess==0.7.0
pure-eval==0.2.2
py==1.11.0
py-cpuinfo==9.0.0
py-spy==0.3.14
pyarrow==14.0.2
pyarrow-hotfix==0.6
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pygame==2.5.2
Pygments==2.17.2
pyparsing==3.1.1
pyrsistent==0.20.0
pytest==7.4.4
pytest-forked==1.6.0
pytest-xdist==2.5.0
python-dateutil==2.8.2
python-json-logger==2.0.7
pytz==2023.3.post1
PyYAML==6.0.1
pyzmq==25.1.2
referencing==0.32.1
requests==2.31.0
requests-oauthlib==1.3.1
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.7.0
rpds-py==0.16.2
rsa==4.9
sacred==0.8.5
scikit-learn==1.3.2
scipy==1.10.1
seals==0.2.1
Send2Trash==1.8.2
setuptools==68.2.2
singledispatch==4.1.0
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
SQLAlchemy==2.0.25
stable-baselines3==2.2.1
stack-data==0.6.3
structlog==23.3.0
sympy==1.12
tensorboard==2.14.0
tensorboard-data-server==0.7.2
terminado==0.18.0
threadpoolctl==3.2.0
tinycss2==1.2.1
tomli==2.0.1
torch==2.1.2
tornado==6.4
tqdm==4.66.1
traitlets==5.14.1
triton==2.1.0
types-python-dateutil==2.8.19.20240106
typing-inspect==0.5.0
typing_extensions==4.9.0
tzdata==2023.4
uri-template==1.3.0
urllib3==2.1.0
wasabi==1.1.2
wcwidth==0.2.12
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
Werkzeug==3.0.1
wheel==0.41.2
wrapt==1.16.0
xeus-python==0.15.12
xeus-python-shell==0.5.0
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0