Commit

Merge pull request #96 from daisybio/ray_optional
Ray optional
JudithBernett authored Dec 12, 2024
2 parents 1c8e101 + fbad259 commit dd06454
Showing 4 changed files with 550 additions and 675 deletions.
16 changes: 10 additions & 6 deletions drevalpy/experiment.py
@@ -1,5 +1,6 @@
"""Main module for running the drug response prediction experiment."""

import importlib
import json
import os
import shutil
@@ -8,9 +9,7 @@

import numpy as np
import pandas as pd
import ray
import torch
from ray import tune
from sklearn.base import TransformerMixin

from .datasets.dataset import DrugResponseDataset, FeatureDataset
@@ -19,6 +18,11 @@
from .models.drp_model import DRPModel
from .pipeline_function import pipeline_function

if importlib.util.find_spec("ray"):
import ray
else:
ray = None


def drug_response_experiment(
models: list[type[DRPModel]],
@@ -47,7 +51,7 @@ def drug_response_experiment(
:param response_transformation: normalizer to use for the response data
:param metric: metric to use for hyperparameter optimization
:param n_cv_splits: number of cross-validation splits
:param multiprocessing: whether to use multiprocessing
:param multiprocessing: whether to use multiprocessing. This requires Ray to be installed.
:param randomization_mode: list of randomization modes to do. Modes: SVCC, SVRC, SVCD, SVRD Can be a list of
randomization tests e.g. 'SVCC SVCD'. Default is None, which means no randomization tests are run.
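The guard added at the top of experiment.py only binds the name ray when importlib.util.find_spec can locate the package, so Ray-dependent paths such as multiprocessing have to check for None before use. Below is a minimal, self-contained sketch of that optional-dependency pattern; the _require_ray helper and its error message are illustrative assumptions, not part of this commit.

import importlib.util

# Bind `ray` only when the optional dependency is installed.
if importlib.util.find_spec("ray"):
    import ray
else:
    ray = None


def _require_ray(feature: str) -> None:
    """Fail fast with a clear message when a Ray-dependent feature is requested."""
    if ray is None:
        raise ImportError(
            f"{feature} requires the optional dependency 'ray'; "
            "install it, for example with `pip install ray[tune]`."
        )


# Hypothetical call site: guard the multiprocessing path before touching Ray.
# _require_ray("multiprocessing=True")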
@@ -1027,7 +1031,7 @@ def hpam_tune_raytune(
path_data: str = "data",
) -> dict:
"""
Tune the hyperparameters for the given model using raytune.
Tune the hyperparameters for the given model using raytune. This requires ray to be installed.
:param model: model to use
:param train_dataset: training dataset
@@ -1047,7 +1051,7 @@
resources_per_trial = {"gpu": 1} # TODO make this user defined
else:
resources_per_trial = {"cpu": 1} # TODO make this user defined
analysis = tune.run(
analysis = ray.tune.run(
lambda hpams: train_and_evaluate(
model=model,
hpams=hpams,
@@ -1058,7 +1062,7 @@
metric=metric,
response_transformation=response_transformation,
),
config=tune.grid_search(hpam_set),
config=ray.tune.grid_search(hpam_set),
mode="min",
num_samples=5,
resources_per_trial=resources_per_trial,
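With the from ray import tune import gone, hpam_tune_raytune now reaches Ray Tune through the guarded ray name (ray.tune.run, ray.tune.grid_search). The following is a small, self-contained sketch of that grid-search call pattern under assumed inputs: the toy trainable, the example hpam_set values, and the from ray import tune import are illustrative and do not reproduce drevalpy's train_and_evaluate objective.

from ray import tune


def trainable(hpams: dict) -> dict:
    # Toy objective standing in for train_and_evaluate: return a final "loss" metric.
    loss = hpams["lr"] * 10 + hpams["dropout"]
    return {"loss": loss}


# Each list element is one complete hyperparameter configuration.
hpam_set = [
    {"lr": 0.01, "dropout": 0.1},
    {"lr": 0.001, "dropout": 0.3},
]

analysis = tune.run(
    trainable,
    config=tune.grid_search(hpam_set),  # one trial per configuration in hpam_set
    metric="loss",
    mode="min",
    resources_per_trial={"cpu": 1},  # the diff selects {"gpu": 1} or {"cpu": 1} here
)
print(analysis.get_best_config(metric="loss", mode="min"))

Because the whole configuration list is wrapped in tune.grid_search, Tune expands it into one trial per dictionary, which is why the diff can pass hpam_set directly instead of per-key search spaces.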
