Skip to content

Commit

Permalink
Fix a missing parameter error in PyTorch Lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
taiypeo authored Feb 2, 2022
1 parent 1395f37 commit 9d715b5
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,22 @@ def main(args):
args.logger = wandb_logger

if args.use_custom_checkpoint_callback:
args.checkpoint_callback = ModelCheckpoint(
save_top_k=-1,
period=1,
verbose=True,
)
try:
args.checkpoint_callback = ModelCheckpoint(
save_top_k=-1,
period=1,
verbose=True,
)
except TypeError:
logger.warning(
"'period' parameter of ModelCheckpoint has been renamed to 'every_n_epochs'."
)
args.checkpoint_callback = ModelCheckpoint(
save_top_k=-1,
every_n_epochs=1,
verbose=True,
)

if args.custom_checkpoint_every_n:
custom_checkpoint_callback = StepCheckpointCallback(
step_interval=args.custom_checkpoint_every_n,
Expand Down

0 comments on commit 9d715b5

Please sign in to comment.