How to Retrieve mean/gen/train/loss during GAIL Training and Create Checkpoints on New Minimum #859

Open
kechirojp opened this issue Oct 9, 2024 · 0 comments
Labels
bug Something isn't working

Question

Hello,

I am trying to implement a callback during GAIL training that retrieves the mean/gen/train/loss value and creates a checkpoint whenever this loss reaches a new minimum.

However, I am unsure how to access the mean/gen/train/loss value.

Currently, I have the logger configured as follows:

from imitation.algorithms.adversarial.gail import GAIL
from imitation.util import logger as imit_logger

# Write training metrics to TensorBoard, CSV, and stdout.
custom_logger = imit_logger.configure(
    folder=optuna_log_dir,
    format_strs=["tensorboard", "csv", "stdout"],
)

gail_trainer = GAIL(
    demonstrations=flattened_transitions,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=venv,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
    log_dir=log_GAIL_dir,
    init_tensorboard=True,
    init_tensorboard_graph=True,
    custom_logger=custom_logger,
)

Could you please advise on how to retrieve mean/gen/train/loss and how to save a checkpoint whenever it reaches a new minimum?
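
To make the request concrete, here is a minimal sketch of the kind of callback I have in mind. It is not taken from the imitation docs. The class `BestLossCheckpointer` and the two callables `read_loss` and `save_fn` are names I made up for illustration. The wiring in the trailing comments is an assumption: I believe `train()` accepts a `callback` that receives the round number, and that the logger's `default_logger.name_to_value` holds the pending values before they are dumped, but I have not verified either against the installed version. `total_timesteps` is likewise a placeholder.

```python
import math
from typing import Callable, Optional


class BestLossCheckpointer:
    """Call `save_fn` whenever `read_loss` returns a new minimum.

    Hypothetical helper: it knows nothing about GAIL itself and only
    tracks the smallest loss seen so far.
    """

    def __init__(
        self,
        read_loss: Callable[[], Optional[float]],
        save_fn: Callable[[int, float], None],
    ) -> None:
        self.read_loss = read_loss
        self.save_fn = save_fn
        self.best_loss = math.inf

    def __call__(self, round_num: int) -> None:
        loss = self.read_loss()
        # Skip rounds where the value is missing or NaN.
        if loss is None or math.isnan(loss):
            return
        if loss < self.best_loss:
            self.best_loss = loss
            self.save_fn(round_num, loss)


# Assumed wiring (unverified, names below come from the snippet above):
#
# def read_loss():
#     values = custom_logger.default_logger.name_to_value
#     return values.get("mean/gen/train/loss")
#
# def save_fn(round_num, loss):
#     gail_trainer.gen_algo.save(f"{log_GAIL_dir}/best_gen_policy")
#
# gail_trainer.train(
#     total_timesteps,
#     callback=BestLossCheckpointer(read_loss, save_fn),
# )
```

If reading the value from the logger object is not supported, an alternative I could imagine is parsing the CSV file written to `optuna_log_dir` after each round, but I would prefer a supported way.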

Environment

  • Operating system and version: Google Colaboratory
  • Python version: Python 3.10.12
  • Output of pip freeze --all:
    absl-py==1.4.0
    accelerate==0.34.2
    aiohappyeyeballs==2.4.3
    aiohttp==3.10.8
    aiosignal==1.3.1
    alabaster==0.7.16
    albucore==0.0.16
    albumentations==1.4.15
    ale-py==0.8.1
    alembic==1.13.3
    altair==4.2.2
    annotated-types==0.7.0
    anyio==3.7.1
    argon2-cffi==23.1.0
    argon2-cffi-bindings==21.2.0
    array_record==0.5.1
    arviz==0.19.0
    astropy==6.1.4
    astropy-iers-data==0.2024.9.30.0.32.59
    astunparse==1.6.3
    async-timeout==4.0.3
    atpublic==4.1.0
    attrs==24.2.0
    audioread==3.0.1
    autograd==1.7.0
    AutoROM==0.6.1
    AutoROM.accept-rom-license==0.6.1
    babel==2.16.0
    backcall==0.2.0
    beautifulsoup4==4.12.3
    bigframes==1.21.0
    bigquery-magics==0.4.0
    bleach==6.1.0
    blinker==1.4
    blis==0.7.11
    blosc2==2.0.0
    bokeh==3.4.3
    Bottleneck==1.4.0
    bqplot==0.12.43
    branca==0.8.0
    build==1.2.2.post1
    CacheControl==0.14.0
    cachetools==5.5.0
    catalogue==2.0.10
    certifi==2024.8.30
    cffi==1.17.1
    chardet==5.2.0
    charset-normalizer==3.3.2
    chex==0.1.87
    clarabel==0.9.0
    click==8.1.7
    cloudpathlib==0.19.0
    cloudpickle==2.2.1
    cmake==3.30.4
    cmdstanpy==1.2.4
    colorama==0.4.6
    colorcet==3.1.0
    colorlog==6.8.2
    colorlover==0.3.0
    colour==0.1.5
    community==1.0.0b1
    confection==0.1.5
    cons==0.4.6
    contextlib2==21.6.0
    contourpy==1.3.0
    cryptography==43.0.1
    cuda-python==12.2.1
    cudf-cu12 @ https://pypi.nvidia.com/cudf-cu12/cudf_cu12-24.6.1-cp310-cp310-manylinux_2_28_x86_64.whl#sha256=6fdd6fd412117503ad23dca75aabb5b536348d3af459b59920152be1a0af6f15
    cufflinks==0.17.3
    cupy-cuda12x==12.2.0
    cvxopt==1.3.2
    cvxpy==1.5.3
    cycler==0.12.1
    cymem==2.0.8
    Cython==3.0.11
    dask==2024.8.0
    datascience==0.17.6
    datasets==3.0.1
    db-dtypes==1.3.0
    dbus-python==1.2.18
    debugpy==1.6.6
    decorator==4.4.2
    defusedxml==0.7.1
    Deprecated==1.2.14
    dill==0.3.8
    distributed==2024.8.0
    distro==1.7.0
    dlib==19.24.2
    dm-tree==0.1.8
    docopt-ng==0.9.0
    docstring_parser==0.16
    docutils==0.18.1
    dopamine_rl==4.0.9
    duckdb==1.1.1
    earthengine-api==1.0.0
    easydict==1.13
    ecos==2.0.14
    editdistance==0.8.1
    eerepr==0.0.4
    einops==0.8.0
    en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl#sha256=86cc141f63942d4b2c5fcee06630fd6f904788d2f0ab005cce45aadb8fb73889
    entrypoints==0.4
    et-xmlfile==1.1.0
    etils==1.9.4
    etuples==0.3.9
    eval_type_backport==0.2.0
    exceptiongroup==1.2.2
    Farama-Notifications==0.0.4
    fastai==2.7.17
    fastcore==1.7.10
    fastdownload==0.0.7
    fastjsonschema==2.20.0
    fastprogress==1.0.3
    fastrlock==0.8.2
    filelock==3.16.1
    firebase-admin==6.5.0
    Flask==2.2.5
    flatbuffers==24.3.25
    flax==0.8.5
    folium==0.17.0
    fonttools==4.54.1
    frozendict==2.4.4
    frozenlist==1.4.1
    fsspec==2024.6.1
    future==1.0.0
    gast==0.6.0
    gcsfs==2024.6.1
    GDAL==3.6.4
    gdown==5.2.0
    geemap==0.34.5
    gensim==4.3.3
    geocoder==1.38.1
    geographiclib==2.0
    geopandas==1.0.1
    geopy==2.4.1
    gin-config==0.5.0
    gitdb==4.0.11
    GitPython==3.1.43
    glob2==0.7
    google==2.0.3
    google-ai-generativelanguage==0.6.6
    google-api-core==2.19.2
    google-api-python-client==2.137.0
    google-auth==2.27.0
    google-auth-httplib2==0.2.0
    google-auth-oauthlib==1.2.1
    google-cloud-aiplatform==1.69.0
    google-cloud-bigquery==3.25.0
    google-cloud-bigquery-connection==1.15.5
    google-cloud-bigquery-storage==2.26.0
    google-cloud-bigtable==2.26.0
    google-cloud-core==2.4.1
    google-cloud-datastore==2.19.0
    google-cloud-firestore==2.16.1
    google-cloud-functions==1.16.5
    google-cloud-iam==2.15.2
    google-cloud-language==2.13.4
    google-cloud-pubsub==2.25.2
    google-cloud-resource-manager==1.12.5
    google-cloud-storage==2.8.0
    google-cloud-translate==3.15.5
    google-colab @ file:///colabtools/dist/google_colab-1.0.0.tar.gz#sha256=f8df4d5e53a79aac2c8af9405f533f064988497448276e4f71c56dc9a1491702
    google-crc32c==1.6.0
    google-generativeai==0.7.2
    google-pasta==0.2.0
    google-resumable-media==2.7.2
    googleapis-common-protos==1.65.0
    googledrivedownloader==0.4
    graphviz==0.20.3
    greenlet==3.1.1
    grpc-google-iam-v1==0.13.1
    grpcio==1.64.1
    grpcio-status==1.48.2
    gspread==6.0.2
    gspread-dataframe==3.3.1
    gym==0.25.2
    gym-notices==0.0.8
    gymnasium==0.29.1
    h5netcdf==1.3.0
    h5py==3.11.0
    holidays==0.57
    holoviews==1.19.1
    html5lib==1.1
    httpimport==1.4.0
    httplib2==0.22.0
    huggingface-hub==0.24.7
    huggingface-sb3==3.0
    humanize==4.10.0
    hyperopt==0.2.7
    ibis-framework==9.2.0
    idna==3.10
    imageio==2.35.1
    imageio-ffmpeg==0.5.1
    imagesize==1.4.1
    imbalanced-learn==0.12.3
    imgaug==0.4.0
    imitation @ file:///content/imitation
    immutabledict==4.2.0
    importlib_metadata==8.4.0
    importlib_resources==6.4.5
    imutils==0.5.4
    inflect==7.4.0
    iniconfig==2.0.0
    intel-cmplr-lib-ur==2024.2.1
    intel-openmp==2024.2.1
    ipyevents==2.0.2
    ipyfilechooser==0.6.0
    ipykernel==5.5.6
    ipyleaflet==0.19.2
    ipyparallel==8.8.0
    ipython==7.34.0
    ipython-genutils==0.2.0
    ipython-sql==0.5.0
    ipytree==0.2.2
    ipywidgets==7.7.1
    itsdangerous==2.2.0
    jax==0.4.33
    jax-cuda12-pjrt==0.4.33
    jax-cuda12-plugin==0.4.33
    jaxlib==0.4.33
    jeepney==0.7.1
    jellyfish==1.1.0
    jieba==0.42.1
    Jinja2==3.1.4
    joblib==1.4.2
    jsonpickle==3.3.0
    jsonschema==4.23.0
    jsonschema-specifications==2023.12.1
    jupyter-client==6.1.12
    jupyter-console==6.1.0
    jupyter-leaflet==0.19.2
    jupyter-server==1.24.0
    jupyter_core==5.7.2
    jupyterlab_pygments==0.3.0
    jupyterlab_widgets==3.0.13
    kaggle==1.6.17
    kagglehub==0.3.1
    keras==3.4.1
    keyring==23.5.0
    kiwisolver==1.4.7
    langcodes==3.4.1
    language_data==1.2.0
    launchpadlib==1.10.16
    lazr.restfulclient==0.14.4
    lazr.uri==1.0.6
    lazy_loader==0.4
    libclang==18.1.1
    librosa==0.10.2.post1
    lightgbm==4.5.0
    linkify-it-py==2.0.3
    llvmlite==0.43.0
    locket==1.0.0
    logical-unification==0.4.6
    lxml==4.9.4
    Mako==1.3.5
    marisa-trie==1.2.0
    Markdown==3.7
    markdown-it-py==3.0.0
    MarkupSafe==2.1.5
    matplotlib==3.7.1
    matplotlib-inline==0.1.7
    matplotlib-venn==1.1.1
    mdit-py-plugins==0.4.2
    mdurl==0.1.2
    miniKanren==1.0.3
    missingno==0.5.2
    mistune==0.8.4
    mizani==0.11.4
    mkl==2024.2.2
    ml-dtypes==0.4.1
    mlxtend==0.23.1
    more-itertools==10.5.0
    moviepy==1.0.3
    mpmath==1.3.0
    msgpack==1.0.8
    multidict==6.1.0
    multipledispatch==1.0.0
    multiprocess==0.70.16
    multitasking==0.0.11
    munch==4.0.0
    murmurhash==1.0.10
    music21==9.1.0
    namex==0.0.8
    natsort==8.4.0
    nbclassic==1.1.0
    nbclient==0.10.0
    nbconvert==6.5.4
    nbformat==5.10.4
    nest-asyncio==1.6.0
    networkx==3.3
    nibabel==5.2.1
    nltk==3.8.1
    notebook==6.5.5
    notebook_shim==0.2.4
    numba==0.60.0
    numexpr==2.10.1
    numpy==1.26.4
    nvidia-cublas-cu12==12.6.3.3
    nvidia-cuda-cupti-cu12==12.6.80
    nvidia-cuda-nvcc-cu12==12.6.77
    nvidia-cuda-runtime-cu12==12.6.77
    nvidia-cudnn-cu12==9.4.0.58
    nvidia-cufft-cu12==11.3.0.4
    nvidia-cusolver-cu12==11.7.1.2
    nvidia-cusparse-cu12==12.5.4.2
    nvidia-nccl-cu12==2.23.4
    nvidia-nvjitlink-cu12==12.6.77
    nvtx==0.2.10
    oauth2client==4.1.3
    oauthlib==3.2.2
    opencv-contrib-python==4.10.0.84
    opencv-python==4.10.0.84
    opencv-python-headless==4.10.0.84
    openpyxl==3.1.5
    opentelemetry-api==1.27.0
    opentelemetry-sdk==1.27.0
    opentelemetry-semantic-conventions==0.48b0
    opt_einsum==3.4.0
    optax==0.2.3
    optree==0.13.0
    optuna==4.0.0
    orbax-checkpoint==0.6.4
    osqp==0.6.7.post0
    packaging==24.1
    pandas==2.2.2
    pandas-datareader==0.10.0
    pandas-gbq==0.23.2
    pandas-stubs==2.2.2.240909
    pandocfilters==1.5.1
    panel==1.4.5
    param==2.1.1
    parso==0.8.4
    parsy==2.1
    partd==1.4.2
    pathlib==1.0.1
    patsy==0.5.6
    peewee==3.17.6
    pexpect==4.9.0
    pickleshare==0.7.5
    pillow==10.4.0
    pip==24.1.2
    pip-tools==7.4.1
    platformdirs==4.3.6
    plotly==5.24.1
    plotnine==0.13.6
    pluggy==1.5.0
    polars==1.7.1
    pooch==1.8.2
    portpicker==1.5.2
    prefetch_generator==1.0.3
    preshed==3.0.9
    prettytable==3.11.0
    proglog==0.1.10
    progressbar2==4.5.0
    prometheus_client==0.21.0
    promise==2.3
    prompt_toolkit==3.0.48
    prophet==1.1.6
    proto-plus==1.24.0
    protobuf==3.20.3
    psutil==5.9.5
    psycopg2==2.9.9
    ptyprocess==0.7.0
    py-cpuinfo==9.0.0
    py4j==0.10.9.7
    pyarrow==16.1.0
    pyarrow-hotfix==0.6
    pyasn1==0.6.1
    pyasn1_modules==0.4.1
    pycocotools==2.0.8
    pycparser==2.22
    pydantic==2.9.2
    pydantic_core==2.23.4
    pydata-google-auth==1.8.2
    pydot==3.0.2
    pydot-ng==2.0.0
    pydotplus==2.0.2
    PyDrive==1.3.1
    PyDrive2==1.20.0
    pyerfa==2.0.1.4
    pygame==2.6.1
    Pygments==2.18.0
    PyGObject==3.42.1
    PyJWT==2.9.0
    pymc==5.16.2
    pymystem3==0.2.0
    pynvjitlink-cu12==0.3.0
    pyogrio==0.10.0
    PyOpenGL==3.1.7
    pyOpenSSL==24.2.1
    pyparsing==3.1.4
    pyperclip==1.9.0
    pyproj==3.7.0
    pyproject_hooks==1.2.0
    pyshp==2.3.1
    PySocks==1.7.1
    pytensor==2.25.5
    pytest==7.4.4
    python-apt==0.0.0
    python-box==7.2.0
    python-dateutil==2.8.2
    python-louvain==0.16
    python-slugify==8.0.4
    python-utils==3.9.0
    pytz==2024.2
    pyviz_comms==3.0.3
    PyYAML==6.0.2
    pyzmq==24.0.1
    qdldl==0.1.7.post4
    ratelim==0.1.6
    referencing==0.35.1
    regex==2024.9.11
    requests==2.32.3
    requests-oauthlib==1.3.1
    requirements-parser==0.9.0
    rich==13.9.1
    rmm-cu12==24.6.0
    rpds-py==0.20.0
    rpy2==3.4.2
    rsa==4.9
    sacred==0.8.6
    safetensors==0.4.5
    scikit-image==0.24.0
    scikit-learn==1.5.2
    scipy==1.13.1
    scooby==0.10.0
    scs==3.2.7
    seaborn==0.13.1
    seals==0.2.1
    SecretStorage==3.3.1
    Send2Trash==1.8.3
    sentencepiece==0.2.0
    setuptools==71.0.4
    shapely==2.0.6
    shellingham==1.5.4
    Shimmy==1.3.0
    simple-parsing==0.1.6
    six==1.16.0
    sklearn-pandas==2.2.0
    smart-open==7.0.4
    smmap==5.0.1
    sniffio==1.3.1
    snowballstemmer==2.2.0
    sortedcontainers==2.4.0
    soundfile==0.12.1
    soupsieve==2.6
    soxr==0.5.0.post1
    spacy==3.7.5
    spacy-legacy==3.0.12
    spacy-loggers==1.0.5
    Sphinx==5.0.2
    sphinxcontrib-applehelp==2.0.0
    sphinxcontrib-devhelp==2.0.0
    sphinxcontrib-htmlhelp==2.1.0
    sphinxcontrib-jsmath==1.0.1
    sphinxcontrib-qthelp==2.0.0
    sphinxcontrib-serializinghtml==2.0.0
    SQLAlchemy==2.0.35
    sqlglot==25.1.0
    sqlparse==0.5.1
    srsly==2.4.8
    stable-baselines3==2.2.1
    stanio==0.5.1
    statsmodels==0.14.4
    StrEnum==0.4.15
    sympy==1.13.3
    tables==3.8.0
    tabulate==0.9.0
    tbb==2021.13.1
    tblib==3.0.0
    tenacity==9.0.0
    tensorboard==2.17.0
    tensorboard-data-server==0.7.2
    tensorflow==2.17.0
    tensorflow-datasets==4.9.6
    tensorflow-hub==0.16.1
    tensorflow-io-gcs-filesystem==0.37.1
    tensorflow-metadata==1.16.0
    tensorflow-probability==0.24.0
    tensorstore==0.1.66
    termcolor==2.4.0
    terminado==0.18.1
    text-unidecode==1.3
    textblob==0.17.1
    tf-slim==1.1.0
    tf_keras==2.17.0
    thinc==8.2.5
    threadpoolctl==3.5.0
    tifffile==2024.9.20
    tinycss2==1.3.0
    tokenizers==0.19.1
    toml==0.10.2
    tomli==2.0.2
    toolz==0.12.1
    torch @ https://download.pytorch.org/whl/cu121_full/torch-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=f3ed9a2b7f8671b2b32a2f036d1b81055eb3ad9b18ba43b705aa34bae4289e1a
    torchaudio @ https://download.pytorch.org/whl/cu121_full/torchaudio-2.4.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=da8c87c80a1c1376a48dc33eef30b03bbdf1df25a05bd2b1c620b8811c7b19be
    torchsummary==1.5.1
    torchvision @ https://download.pytorch.org/whl/cu121_full/torchvision-0.19.1%2Bcu121-cp310-cp310-linux_x86_64.whl#sha256=b8cc4bf381b75522995b601e07a1b433b5fd925dc3e34a7fa6cd22f449d65379
    tornado==6.3.3
    tqdm==4.66.5
    traitlets==5.7.1
    traittypes==0.2.1
    transformers==4.44.2
    tweepy==4.14.0
    typeguard==4.3.0
    typer==0.12.5
    types-pytz==2024.2.0.20241003
    types-setuptools==75.1.0.20240917
    typing_extensions==4.12.2
    tzdata==2024.2
    tzlocal==5.2
    uc-micro-py==1.0.3
    uritemplate==4.1.1
    urllib3==2.2.3
    vega-datasets==0.9.0
    wadllib==1.3.6
    wasabi==1.1.3
    wcwidth==0.2.13
    weasel==0.4.1
    webcolors==24.8.0
    webencodings==0.5.1
    websocket-client==1.8.0
    Werkzeug==3.0.4
    wheel==0.44.0
    widgetsnbextension==3.6.9
    wordcloud==1.9.3
    wrapt==1.16.0
    xarray==2024.9.0
    xarray-einstats==0.8.0
    xgboost==2.1.1
    xlrd==2.0.1
    xxhash==3.5.0
    xyzservices==2024.9.0
    yarl==1.13.1
    yellowbrick==1.5
    yfinance==0.2.44
    zict==3.0.0
    zipp==3.20.2
kechirojp added the bug label on Oct 9, 2024