From 3c439511f744923ef0ccd41c46a028d207ec61ec Mon Sep 17 00:00:00 2001
From: Dibya Ghosh
Date: Thu, 17 Aug 2023 20:50:10 +0000
Subject: [PATCH] Saving config to directory, not just wandb

---
 experiments/main/train.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/experiments/main/train.py b/experiments/main/train.py
index f449a601..197a89fa 100644
--- a/experiments/main/train.py
+++ b/experiments/main/train.py
@@ -1,32 +1,32 @@
+import datetime
+import json
 import os
 from functools import partial
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+import optax
 import tensorflow as tf
 import tqdm
+import wandb
 from absl import app, flags, logging
 from flax.training import checkpoints
 from flax.traverse_util import flatten_dict
 from ml_collections import config_flags
 
-import wandb
 from orca.data.bridge_dataset import BridgeDataset, glob_to_path_list
 from orca.data.text_processing import text_processors
 from orca.model import create_model_def
 from orca.model.tokenizers import weights_loaders
-import optax
-
 from orca.train_utils import (
+    Timer,
     TrainState,
     create_train_state,
-    shard_batch,
-    Timer,
     format_name_with_config,
     initialize_compilation_cache,
+    shard_batch,
 )
-import datetime
 
 try:
     from jax_smi import initialise_tracking  # type: ignore
@@ -90,6 +90,11 @@ def main(_):
         )
         wandb.config.update(dict(save_dir=save_dir), allow_val_change=True)
         logging.info("Saving to %s", save_dir)
+        tf.io.gfile.makedirs(save_dir)
+        with tf.io.gfile.GFile(
+            os.path.join(save_dir, "config.json"), "w"
+        ) as config_file:
+            config_file.write(FLAGS.config.to_json_best_effort())
     else:
         save_dir = None
         logging.info("save_dir not passed in, not saving checkpoints")