Commit d392e53

switched hooks for saving config yaml file
sadamov committed Jun 7, 2024
1 parent f2a8180 commit d392e53
Showing 2 changed files with 19 additions and 4 deletions.
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -8,6 +8,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD)

### Added
+
+- Added `rank_zero_print` function to `utils.py` for printing in multi-node distributed training
+  [\#16](https://github.com/mllam/neural-lam/pull/16)
+  @sadamov
+
- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data to speed up running tests.
  [\#38](https://github.com/mllam/neural-lam/pull/38)
  @SimonKamuk
@@ -30,6 +35,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

+- Initialization of wandb is now robust for multi-node distributed training and config files are saved to wandb
+  [\#16](https://github.com/mllam/neural-lam/pull/16)
+  @sadamov
+
- Robust restoration of optimizer and scheduler using `ckpt_path`
  [\#17](https://github.com/mllam/neural-lam/pull/17)
  @sadamov
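The `rank_zero_print` helper referenced in the entry above is added to `utils.py`, which is not part of this commit's diff. A minimal sketch of such a helper, assuming it follows PyTorch Lightning's `rank_zero_only` pattern (the body below is illustrative, not the actual neural-lam implementation):

```python
# Illustrative sketch only; the real helper in neural_lam/utils.py is not
# shown in this diff and may be implemented differently.
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def rank_zero_print(*args, **kwargs):
    """Print only on global rank 0, so a multi-node run emits each
    message once rather than once per process."""
    print(*args, **kwargs)
```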
14 changes: 10 additions & 4 deletions neural_lam/models/ar_model.py
@@ -213,6 +213,11 @@ def training_step(self, batch):
        )
        return batch_loss

+    def on_train_start(self):
+        """Save data config file to wandb at start of training"""
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
+
    def all_gather_cat(self, tensor_to_gather):
        """
        Gather tensors across all ranks, and concatenate across dim. 0
@@ -521,6 +526,11 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
        wandb.log(log_dict)  # Log all
        plt.close("all")  # Close all figs

+    def on_test_start(self):
+        """Save data config file to wandb at start of test"""
+        if self.trainer.is_global_zero:
+            wandb.save("neural_lam/data_config.yaml")
+
    def on_test_epoch_end(self):
        """
        Compute test metrics and make plots at the end of test epoch.
@@ -597,7 +607,3 @@ def on_load_checkpoint(self, checkpoint):
        if not self.restore_opt:
            opt = self.configure_optimizers()
            checkpoint["optimizer_states"] = [opt.state_dict()]
-
-    def on_run_end(self):
-        if self.trainer.is_global_zero:
-            wandb.save("neural_lam/data_config.yaml")
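For context on the switch: `on_run_end` is not a `LightningModule` hook (the name belongs to Lightning's internal loop API), so the Trainer likely never called the removed method and the config file was never uploaded. `on_train_start` and `on_test_start` are documented hooks that fire after distributed setup, when `self.trainer.is_global_zero` is reliable. A minimal, self-contained sketch of the pattern, with a hypothetical module name that is not part of neural-lam:

```python
# Hypothetical, self-contained sketch of the rank-zero upload pattern
# used in the diff above; not part of neural-lam itself.
import pytorch_lightning as pl
import wandb


class ConfigLoggingModule(pl.LightningModule):
    def _save_config(self):
        # wandb.save registers the file for upload to the active run; the
        # is_global_zero guard avoids duplicate uploads from other ranks.
        if self.trainer.is_global_zero:
            wandb.save("neural_lam/data_config.yaml")

    def on_train_start(self):
        # Fires once per trainer.fit(), after distributed init.
        self._save_config()

    def on_test_start(self):
        # Fires once per trainer.test().
        self._save_config()
```

One caveat: `wandb.save` assumes an active wandb run, so `wandb.init` (or Lightning's `WandbLogger`) must have run first, which is consistent with the wandb initialization change listed in the CHANGELOG above.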
