diff --git a/drevalpy/models/DIPK/dipk.py b/drevalpy/models/DIPK/dipk.py index 5bc7713..8f89fae 100644 --- a/drevalpy/models/DIPK/dipk.py +++ b/drevalpy/models/DIPK/dipk.py @@ -8,6 +8,7 @@ """ import os +import secrets from typing import Any import numpy as np @@ -153,7 +154,11 @@ def train( # Ensure the checkpoint directory exists os.makedirs(model_checkpoint_dir, exist_ok=True) - checkpoint_path = os.path.join(model_checkpoint_dir, "best_model.pth") + version = "version-" + "".join( + [secrets.choice("0123456789abcdef") for _ in range(20)] + ) # preventing conflicts of filenames + + checkpoint_path = os.path.join(model_checkpoint_dir, f"{version}_best_DIPK_model.pth") # Train model print("Training DIPK model") @@ -238,7 +243,9 @@ def train( # Reload the best model after training print("DIPK: Reloading the best model") - self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.DEVICE)) # noqa S614 + self.model.load_state_dict( + torch.load(checkpoint_path, map_location=self.DEVICE, weights_only=True) # noqa S614 + ) self.model.to(self.DEVICE) # Ensure model is on the correct device def predict( diff --git a/drevalpy/models/SuperFELTR/utils.py b/drevalpy/models/SuperFELTR/utils.py index 569e876..4816302 100644 --- a/drevalpy/models/SuperFELTR/utils.py +++ b/drevalpy/models/SuperFELTR/utils.py @@ -324,11 +324,10 @@ def train_superfeltr_model( [secrets.choice("0123456789abcdef") for _ in range(20)] ) # preventing conflicts of filenames checkpoint_callback = pl.callbacks.ModelCheckpoint( - dirpath=model_checkpoint_dir, + dirpath=os.path.join(model_checkpoint_dir, name), monitor=monitor, mode="min", save_top_k=1, - filename=name, ) # Initialize the Lightning trainer trainer = pl.Trainer( @@ -336,9 +335,8 @@ def train_superfeltr_model( callbacks=[ early_stop_callback, checkpoint_callback, - TQDMProgressBar(), + TQDMProgressBar(refresh_rate=0), ], - default_root_dir=os.path.join(model_checkpoint_dir, "superfeltr_checkpoints/lightning_logs/" + name), ) if val_loader is None: trainer.fit(model, train_loader)