Skip to content

Commit

Permalink
Saving config to directory, not just wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
dibyaghosh committed Aug 17, 2023
1 parent 3d156f9 commit 3c43951
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions experiments/main/train.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
import datetime
import json
import os
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
import optax
import tensorflow as tf
import tqdm
import wandb
from absl import app, flags, logging
from flax.training import checkpoints
from flax.traverse_util import flatten_dict
from ml_collections import config_flags
import wandb

from orca.data.bridge_dataset import BridgeDataset, glob_to_path_list
from orca.data.text_processing import text_processors
from orca.model import create_model_def
from orca.model.tokenizers import weights_loaders
import optax

from orca.train_utils import (
Timer,
TrainState,
create_train_state,
shard_batch,
Timer,
format_name_with_config,
initialize_compilation_cache,
shard_batch,
)
import datetime

try:
from jax_smi import initialise_tracking # type: ignore
Expand Down Expand Up @@ -90,6 +90,11 @@ def main(_):
)
wandb.config.update(dict(save_dir=save_dir), allow_val_change=True)
logging.info("Saving to %s", save_dir)
tf.io.gfile.makedirs(save_dir)
with tf.io.gfile.GFile(
os.path.join(save_dir, "config.json"), "w"
) as config_file:
config_file.write(FLAGS.config.to_json_best_effort())
else:
save_dir = None
logging.info("save_dir not passed in, not saving checkpoints")
Expand Down

0 comments on commit 3c43951

Please sign in to comment.