diff --git a/drevalpy/experiment.py b/drevalpy/experiment.py index 2df1135..d64d55c 100644 --- a/drevalpy/experiment.py +++ b/drevalpy/experiment.py @@ -919,26 +919,28 @@ def train_and_predict( early_stopping_dataset.transform(response_transformation) prediction_dataset.transform(response_transformation) - train_inputs = { - "output": train_dataset, - "cell_line_input": cl_features, - "drug_input": drug_features, - "output_earlystopping": early_stopping_dataset, - } - + print("Training model ...") if model_checkpoint_dir == "TEMPORARY": with tempfile.TemporaryDirectory() as temp_dir: print(f"Using temporary directory: {temp_dir} for model checkpoints") - train_inputs["model_checkpoint_dir"] = temp_dir - print("Training model ...") - model.train(**train_inputs) + model.train( + output=train_dataset, + output_earlystopping=early_stopping_dataset, + cell_line_input=cl_features, + drug_input=drug_features, + model_checkpoint_dir=model_checkpoint_dir, + ) else: if not os.path.exists(model_checkpoint_dir): os.makedirs(model_checkpoint_dir, exist_ok=True) print(f"Using directory: {model_checkpoint_dir} for model checkpoints") - train_inputs["model_checkpoint_dir"] = model_checkpoint_dir - print("Training model ...") - model.train(**train_inputs) + model.train( + output=train_dataset, + output_earlystopping=early_stopping_dataset, + cell_line_input=cl_features, + drug_input=drug_features, + model_checkpoint_dir=model_checkpoint_dir, + ) if len(prediction_dataset) > 0: prediction_dataset._predictions = model.predict( diff --git a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py index 9f93c4e..4df6f8e 100644 --- a/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py +++ b/drevalpy/models/SimpleNeuralNetwork/simple_neural_network.py @@ -102,7 +102,7 @@ def train( output_earlystopping=output_earlystopping, batch_size=16, patience=5, - num_workers=1, + num_workers=8, model_checkpoint_dir=model_checkpoint_dir, ) diff --git a/drevalpy/models/SimpleNeuralNetwork/utils.py b/drevalpy/models/SimpleNeuralNetwork/utils.py index cb55549..9209ed7 100644 --- a/drevalpy/models/SimpleNeuralNetwork/utils.py +++ b/drevalpy/models/SimpleNeuralNetwork/utils.py @@ -1,6 +1,5 @@ """Utility functions for the simple neural network models.""" -import os import secrets from typing import Any @@ -187,7 +186,7 @@ def fit( if trainer_params is None: trainer_params = { - "progress_bar_refresh_rate": 300, + "progress_bar_refresh_rate": 500, "max_epochs": 70, } @@ -252,8 +251,8 @@ def fit( self.checkpoint_callback, progress_bar, ], - default_root_dir=os.path.join(model_checkpoint_dir, "nn_baseline_checkpoints/lightning_logs/" + name), - strategy="ddp_find_unused_parameters_true", + default_root_dir=model_checkpoint_dir, + devices=1, **trainer_params_copy, ) if val_loader is None: @@ -265,6 +264,8 @@ def fit( if self.checkpoint_callback.best_model_path is not None: checkpoint = torch.load(self.checkpoint_callback.best_model_path) # noqa: S614 self.load_state_dict(checkpoint["state_dict"]) + else: + print("checkpoint_callback: No best model found, using the last model.") def forward(self, x) -> torch.Tensor: """