diff --git a/examples/lgbm.py b/examples/lgbm.py
index f66799b3..7d6bc04e 100644
--- a/examples/lgbm.py
+++ b/examples/lgbm.py
@@ -3,8 +3,8 @@
 Example taken from
 https://mlfromscratch.com/gridsearch-keras-sklearn/#/
 """
 
-from tune_sklearn import TuneSearchCV
 import lightgbm as lgb
+from tune_sklearn import TuneSearchCV
 from sklearn.datasets import load_breast_cancer
 from sklearn.model_selection import train_test_split
diff --git a/examples/torch_nn.py b/examples/skorch_example.py
similarity index 100%
rename from examples/torch_nn.py
rename to examples/skorch_example.py
diff --git a/examples/xgbclassifier.py b/examples/xgbclassifier.py
index 22024005..6dc8f493 100644
--- a/examples/xgbclassifier.py
+++ b/examples/xgbclassifier.py
@@ -36,6 +36,7 @@
 digit_search = TuneSearchCV(
     xgb,
     param_distributions=params,
+    early_stopping="MedianStoppingRule",
     n_trials=3,
     # use_gpu=True # Commented out for testing on travis,
     # but this is how you would use gpu
diff --git a/requirements.txt b/requirements-test.txt
similarity index 100%
rename from requirements.txt
rename to requirements-test.txt
diff --git a/tune_sklearn/_detect_xgboost.py b/tune_sklearn/_detect_xgboost.py
new file mode 100644
index 00000000..c27eca07
--- /dev/null
+++ b/tune_sklearn/_detect_xgboost.py
@@ -0,0 +1,12 @@
+def has_xgboost():
+    try:
+        import xgboost  # noqa: F401
+        return True
+    except ImportError:
+        return False
+
+
+def is_xgboost_model(clf):
+    if not has_xgboost():
+        return False
+    import xgboost
+    return isinstance(clf, xgboost.sklearn.XGBModel)
diff --git a/tune_sklearn/_trainable.py b/tune_sklearn/_trainable.py
index 9a9dd435..6d6aaca8 100644
--- a/tune_sklearn/_trainable.py
+++ b/tune_sklearn/_trainable.py
@@ -12,6 +12,8 @@
 import ray.cloudpickle as cpickle
 import warnings
 
+from tune_sklearn._detect_xgboost import is_xgboost_model
+
 
 class _Trainable(Trainable):
     """Class to be passed in as the first argument of tune.run to train models.
@@ -74,6 +76,9 @@ def _setup(self, config):
                 self.estimator_config["max_iter"] = 1
             for i in range(n_splits):
                 self.estimator_list[i].set_params(**self.estimator_config)
+
+            if is_xgboost_model(self.main_estimator):
+                self.saved_models = [None for _ in range(n_splits)]
         else:
             self.main_estimator.set_params(**self.estimator_config)
 
@@ -112,8 +117,14 @@ def _train(self):
                     test,
                     train_indices=train)
                 if self._can_partial_fit():
-                    self.estimator_list[i].partial_fit(X_train, y_train,
-                                                       np.unique(self.y))
+                    if is_xgboost_model(self.main_estimator):
+                        self.estimator_list[i].fit(
+                            X_train, y_train, xgb_model=self.saved_models[i])
+                        self.saved_models[i] = self.estimator_list[
+                            i].get_booster()
+                    else:
+                        self.estimator_list[i].partial_fit(
+                            X_train, y_train, np.unique(self.y))
                 else:
                     self.estimator_list[i].fit(X_train, y_train)
 
diff --git a/tune_sklearn/tune_basesearch.py b/tune_sklearn/tune_basesearch.py
index c5d3e3e6..a27bd02a 100644
--- a/tune_sklearn/tune_basesearch.py
+++ b/tune_sklearn/tune_basesearch.py
@@ -30,6 +30,8 @@
 import multiprocessing
 import os
 
+from tune_sklearn._detect_xgboost import is_xgboost_model
+
 
 def resolve_early_stopping(early_stopping, max_iters):
     if isinstance(early_stopping, str):
@@ -442,7 +444,6 @@ def _can_early_stop(self):
 
             bool: if the estimator can early stop
         """
-
         from sklearn.tree import BaseDecisionTree
         from sklearn.ensemble import BaseEnsemble
 
@@ -458,7 +459,9 @@ def _can_early_stop(self):
                           and is_not_ensemble_subclass
                           and is_not_tree_subclass)
 
-        return can_partial_fit or can_warm_start
+        is_gbm = is_xgboost_model(self.estimator)
+
+        return can_partial_fit or can_warm_start or is_gbm
 
     def _fill_config_hyperparam(self, config):
         """Fill in the ``config`` dictionary with the hyperparameters.
diff --git a/tune_sklearn/tune_gridsearch.py b/tune_sklearn/tune_gridsearch.py
index b8d9003a..7e338a59 100644
--- a/tune_sklearn/tune_gridsearch.py
+++ b/tune_sklearn/tune_gridsearch.py
@@ -206,7 +206,6 @@ def _tune_run(self, config, resources_per_trial):
                 stop={"training_iteration": self.max_iters},
                 config=config,
                 fail_fast=True,
-                checkpoint_at_end=True,
                 resources_per_trial=resources_per_trial,
                 local_dir=os.path.expanduser(self.local_dir))
         else:
@@ -218,7 +217,6 @@ def _tune_run(self, config, resources_per_trial):
                 stop={"training_iteration": self.max_iters},
                 config=config,
                 fail_fast=True,
-                checkpoint_at_end=True,
                 resources_per_trial=resources_per_trial,
                 local_dir=os.path.expanduser(self.local_dir))
 
diff --git a/tune_sklearn/tune_search.py b/tune_sklearn/tune_search.py
index a6d99c5d..a5bf8be9 100644
--- a/tune_sklearn/tune_search.py
+++ b/tune_sklearn/tune_search.py
@@ -210,7 +210,7 @@ class TuneSearchCV(TuneBaseSearchCV):
             However computing the scores on the training set can be
            computationally expensive and is not strictly required to select
             the parameters that yield the best generalization performance.
-        local_dir (str): A string that defines where checkpoints will
+        local_dir (str): A string that defines where checkpoints and logs will
             be stored. Defaults to "~/ray_results"
         max_iters (int): Indicates the maximum number of epochs to run for
             each hyperparameter configuration sampled (specified by ``n_trials``).
@@ -232,8 +232,8 @@ class TuneSearchCV(TuneBaseSearchCV):
             All types of search aside from Randomized search require parent
             libraries to be installed.
         use_gpu (bool): Indicates whether to use gpu for fitting.
-            Defaults to False. If True, training will use 1 gpu
-            for `resources_per_trial`.
+            Defaults to False. If True, training will start processes
+            with the proper CUDA_VISIBLE_DEVICES settings set.
         **search_kwargs (Any): Additional arguments to pass to
             the SearchAlgorithms (tune.suggest) objects.
 
@@ -490,7 +490,6 @@ def _try_import_required_libraries(self, search_optimization):
                 from skopt import Optimizer  # noqa: F401
                 from ray.tune.suggest.skopt import SkOptSearch  # noqa: F401
             except ImportError:
-                logger.exception()
                 raise ImportError(
                     "It appears that scikit-optimize is not installed. "
                     "Do: pip install scikit-optimize") from None
@@ -500,7 +499,6 @@ def _try_import_required_libraries(self, search_optimization):
                 from ray.tune.schedulers import HyperBandForBOHB  # noqa: F401
                 import ConfigSpace as CS  # noqa: F401
             except ImportError:
-                logger.exception()
                 raise ImportError(
                     "It appears that either HpBandSter or ConfigSpace "
                     "is not installed. "
@@ -510,7 +508,6 @@ def _try_import_required_libraries(self, search_optimization):
                 from ray.tune.suggest.hyperopt import HyperOptSearch  # noqa: F401,E501
                 from hyperopt import hp  # noqa: F401
             except ImportError:
-                logger.exception()
                 raise ImportError("It appears that hyperopt is not installed. "
                                   "Do: pip install hyperopt") from None
         elif search_optimization == "optuna":
@@ -518,7 +515,6 @@ def _try_import_required_libraries(self, search_optimization):
                 from ray.tune.suggest.optuna import OptunaSearch, param  # noqa: F401,E501
                 import optuna  # noqa: F401
             except ImportError:
-                logger.exception()
                 raise ImportError("It appears that optuna is not installed. "
                                   "Do: pip install optuna") from None
 
@@ -561,7 +557,6 @@ def _tune_run(self, config, resources_per_trial):
                 num_samples=self.num_samples,
                 config=config,
                 fail_fast=True,
-                checkpoint_at_end=True,
                 resources_per_trial=resources_per_trial,
                 local_dir=os.path.expanduser(self.local_dir))
 
@@ -623,7 +618,6 @@ def _tune_run(self, config, resources_per_trial):
                 num_samples=self.num_samples,
                 config=config,
                 fail_fast=True,
-                checkpoint_at_end=True,
                 resources_per_trial=resources_per_trial,
                 local_dir=os.path.expanduser(self.local_dir))
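
Reviewer note: a minimal sketch of the incremental-training pattern the
_train() change above relies on. XGBoost's sklearn wrappers do not implement
partial_fit, but their fit() accepts an xgb_model argument that resumes
boosting from a previously saved Booster; saving get_booster() after each fit
gives the same stepwise behavior partial_fit provides for other estimators.
The dataset and the three-iteration loop below are illustrative only, not
part of this patch.

    # Resume-style training via xgb_model, as _train() now does per CV fold.
    from sklearn.datasets import load_breast_cancer
    from xgboost.sklearn import XGBClassifier

    X, y = load_breast_cancer(return_X_y=True)
    clf = XGBClassifier(n_estimators=10)

    saved_model = None
    for _ in range(3):  # one pass per tune "training_iteration"
        clf.fit(X, y, xgb_model=saved_model)  # resume from the saved booster
        saved_model = clf.get_booster()       # stash for the next iteration

Each pass adds n_estimators more trees to the saved booster, which is what
lets schedulers such as the MedianStoppingRule enabled in xgbclassifier.py
observe intermediate scores and stop unpromising XGBoost trials early.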