Skip to content

Commit

Permalink
device needs to be one
Browse files Browse the repository at this point in the history
  • Loading branch information
PascalIversen committed Dec 16, 2024
1 parent b507f67 commit 603ff39
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 18 deletions.
28 changes: 15 additions & 13 deletions drevalpy/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,26 +919,28 @@ def train_and_predict(
early_stopping_dataset.transform(response_transformation)
prediction_dataset.transform(response_transformation)

train_inputs = {
"output": train_dataset,
"cell_line_input": cl_features,
"drug_input": drug_features,
"output_earlystopping": early_stopping_dataset,
}

print("Training model ...")
if model_checkpoint_dir == "TEMPORARY":
with tempfile.TemporaryDirectory() as temp_dir:
print(f"Using temporary directory: {temp_dir} for model checkpoints")
train_inputs["model_checkpoint_dir"] = temp_dir
print("Training model ...")
model.train(**train_inputs)
model.train(
output=train_dataset,
output_earlystopping=early_stopping_dataset,
cell_line_input=cl_features,
drug_input=drug_features,
model_checkpoint_dir=model_checkpoint_dir,
)
else:
if not os.path.exists(model_checkpoint_dir):
os.makedirs(model_checkpoint_dir, exist_ok=True)
print(f"Using directory: {model_checkpoint_dir} for model checkpoints")
train_inputs["model_checkpoint_dir"] = model_checkpoint_dir
print("Training model ...")
model.train(**train_inputs)
model.train(
output=train_dataset,
output_earlystopping=early_stopping_dataset,
cell_line_input=cl_features,
drug_input=drug_features,
model_checkpoint_dir=model_checkpoint_dir,
)

if len(prediction_dataset) > 0:
prediction_dataset._predictions = model.predict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def train(
output_earlystopping=output_earlystopping,
batch_size=16,
patience=5,
num_workers=1,
num_workers=8,
model_checkpoint_dir=model_checkpoint_dir,
)

Expand Down
9 changes: 5 additions & 4 deletions drevalpy/models/SimpleNeuralNetwork/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Utility functions for the simple neural network models."""

import os
import secrets
from typing import Any

Expand Down Expand Up @@ -187,7 +186,7 @@ def fit(

if trainer_params is None:
trainer_params = {
"progress_bar_refresh_rate": 300,
"progress_bar_refresh_rate": 500,
"max_epochs": 70,
}

Expand Down Expand Up @@ -252,8 +251,8 @@ def fit(
self.checkpoint_callback,
progress_bar,
],
default_root_dir=os.path.join(model_checkpoint_dir, "nn_baseline_checkpoints/lightning_logs/" + name),
strategy="ddp_find_unused_parameters_true",
default_root_dir=model_checkpoint_dir,
devices=1,
**trainer_params_copy,
)
if val_loader is None:
Expand All @@ -265,6 +264,8 @@ def fit(
if self.checkpoint_callback.best_model_path is not None:
checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614
self.load_state_dict(checkpoint["state_dict"])
else:
print("checkpoint_callback: No best model found, using the last model.")

def forward(self, x) -> torch.Tensor:
"""
Expand Down

0 comments on commit 603ff39

Please sign in to comment.