
Commit 79705e1
fixing model checkpointing. also improved loading and saving efficiency
PascalIversen committed Dec 17, 2024
1 parent 17aece9 commit 79705e1
Showing 2 changed files with 11 additions and 6 deletions.
11 changes: 9 additions & 2 deletions drevalpy/models/DIPK/dipk.py
@@ -8,6 +8,7 @@
 """
 
 import os
+import secrets
 from typing import Any
 
 import numpy as np
@@ -153,7 +154,11 @@ def train(
 
         # Ensure the checkpoint directory exists
         os.makedirs(model_checkpoint_dir, exist_ok=True)
-        checkpoint_path = os.path.join(model_checkpoint_dir, "best_model.pth")
+        version = "version-" + "".join(
+            [secrets.choice("0123456789abcdef") for _ in range(20)]
+        )  # preventing conflicts of filenames
+
+        checkpoint_path = os.path.join(model_checkpoint_dir, f"{version}_best_DIPK_model.pth")
 
         # Train model
         print("Training DIPK model")
@@ -238,7 +243,9 @@ def train(
 
         # Reload the best model after training
         print("DIPK: Reloading the best model")
-        self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.DEVICE))  # noqa S614
+        self.model.load_state_dict(
+            torch.load(checkpoint_path, map_location=self.DEVICE, weights_only=True)  # noqa S614
+        )
        self.model.to(self.DEVICE)  # Ensure model is on the correct device
 
     def predict(
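The two changes in dipk.py work together: a random hex suffix drawn from `secrets` gives every run its own checkpoint file, and `weights_only=True` makes `torch.load` unpickle only tensors and primitive containers (supported in recent PyTorch releases). A minimal self-contained sketch of this pattern follows; the `Linear` model, directory name, and file name are illustrative stand-ins, not the actual DIPK code:

```python
import os
import secrets

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(4, 1)  # stand-in for the DIPK model
model_checkpoint_dir = "checkpoints"  # illustrative directory

os.makedirs(model_checkpoint_dir, exist_ok=True)

# 20 random hex characters make filename collisions between concurrent
# runs sharing the same checkpoint directory vanishingly unlikely.
version = "version-" + "".join(secrets.choice("0123456789abcdef") for _ in range(20))
checkpoint_path = os.path.join(model_checkpoint_dir, f"{version}_best_model.pth")

# Saving only the state dict is smaller than pickling the whole module.
torch.save(model.state_dict(), checkpoint_path)

# weights_only=True restricts unpickling to tensors and primitive types,
# so loading a checkpoint cannot execute arbitrary pickled code.
state = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(state)
model.to(DEVICE)
```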
6 changes: 2 additions & 4 deletions drevalpy/models/SuperFELTR/utils.py
@@ -324,21 +324,19 @@ def train_superfeltr_model(
         [secrets.choice("0123456789abcdef") for _ in range(20)]
     )  # preventing conflicts of filenames
     checkpoint_callback = pl.callbacks.ModelCheckpoint(
-        dirpath=model_checkpoint_dir,
+        dirpath=os.path.join(model_checkpoint_dir, name),
         monitor=monitor,
         mode="min",
         save_top_k=1,
-        filename=name,
     )
     # Initialize the Lightning trainer
     trainer = pl.Trainer(
         max_epochs=hpams["epochs"],
         callbacks=[
             early_stop_callback,
             checkpoint_callback,
-            TQDMProgressBar(),
+            TQDMProgressBar(refresh_rate=0),
         ],
-        default_root_dir=os.path.join(model_checkpoint_dir, "superfeltr_checkpoints/lightning_logs/" + name),
     )
     if val_loader is None:
         trainer.fit(model, train_loader)
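On the SuperFELTR side, the fix moves the per-run random `name` from the checkpoint filename into `dirpath`, so each run checkpoints into its own subdirectory, and passes `refresh_rate=0` to `TQDMProgressBar` to silence per-batch output. A hedged sketch of the same wiring with PyTorch Lightning; the `monitor` metric, epoch count, and omitted early-stopping callback are placeholders, not the SuperFELTR code:

```python
import os
import secrets

import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

model_checkpoint_dir = "checkpoints"  # illustrative directory

# Per-run random suffix, as in the diff above.
name = "version-" + "".join(secrets.choice("0123456789abcdef") for _ in range(20))

# Each run writes into its own subdirectory; with no `filename` argument,
# Lightning falls back to a default name based on epoch/step.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=os.path.join(model_checkpoint_dir, name),
    monitor="val_loss",  # placeholder for the `monitor` variable in utils.py
    mode="min",
    save_top_k=1,
)

trainer = pl.Trainer(
    max_epochs=5,  # placeholder for hpams["epochs"]
    callbacks=[
        checkpoint_callback,
        TQDMProgressBar(refresh_rate=0),  # refresh_rate=0 disables the bar
    ],
)
# trainer.fit(model, train_loader)  # requires a LightningModule and DataLoader
```

Scoping `dirpath` per run also makes the removed `default_root_dir` plumbing unnecessary, since the checkpoint callback no longer shares a directory with other runs' Lightning logs.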
