From c8e509cd97f5c5809760183de3600cebc326b9ca Mon Sep 17 00:00:00 2001 From: Aidis Date: Mon, 18 Apr 2022 18:04:35 +0300 Subject: [PATCH 1/3] Support sklearn.compose.TransformedTargetRegressor --- tune_sklearn/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tune_sklearn/utils.py b/tune_sklearn/utils.py index 4f74c1e..dbb439b 100644 --- a/tune_sklearn/utils.py +++ b/tune_sklearn/utils.py @@ -1,6 +1,7 @@ from collections import defaultdict from typing import Dict +from sklearn.compose import TransformedTargetRegressor from sklearn.metrics import check_scoring from sklearn.pipeline import Pipeline from tune_sklearn._detect_booster import ( @@ -94,6 +95,10 @@ def check_error_warm_start(early_stop_type, estimator_config, estimator): def get_early_stop_type(estimator, early_stopping): + # If estimator is TransformedTargetRegressor we should get the wrapped regressor. + if isinstance(estimator, TransformedTargetRegressor): + estimator = estimator.regressor + if not early_stopping: return EarlyStopping.NO_EARLY_STOP can_partial_fit = check_partial_fit(estimator) From d5b8ae4cfc2bff61990a5dd37e1d1854a6f5dde7 Mon Sep 17 00:00:00 2001 From: Aidis Date: Mon, 18 Apr 2022 23:25:55 +0300 Subject: [PATCH 2/3] Add tests --- tests/test_trainable.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/test_trainable.py b/tests/test_trainable.py index b76f3f7..45622de 100644 --- a/tests/test_trainable.py +++ b/tests/test_trainable.py @@ -5,12 +5,15 @@ from tune_sklearn._detect_booster import ( has_xgboost, has_required_lightgbm_version, has_catboost) +from lightgbm import LGBMRegressor +from ray.tune.schedulers import AsyncHyperBandScheduler +from sklearn.compose import TransformedTargetRegressor from sklearn.datasets import make_classification from sklearn.linear_model import LogisticRegression, SGDClassifier from sklearn.model_selection import check_cv from sklearn.svm import SVC -from tune_sklearn.utils import _check_multimetric_scoring, get_early_stop_type +from tune_sklearn.utils import _check_multimetric_scoring, get_early_stop_type, EarlyStopping def create_xgboost(): @@ -203,3 +206,14 @@ def testWarmStart(self): trainable.train() trainable.train() trainable.stop() + + +class TestGetEarlyStopType(unittest.TestCase): + def testLGBMRegressor(self): + lgbm_regressor = LGBMRegressor() + transformed_target_regressor = TransformedTargetRegressor(lgbm_regressor) + early_stopping_type = get_early_stop_type( + estimator=transformed_target_regressor, + early_stopping=AsyncHyperBandScheduler(), + ) + self.assertEqual(early_stopping_type, EarlyStopping.LGBM) From 980ac200c4f609164ee728890014fd2ba919a3db Mon Sep 17 00:00:00 2001 From: Aidis Date: Wed, 20 Apr 2022 13:52:46 +0300 Subject: [PATCH 3/3] Support sklearn Pipeline --- tune_sklearn/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tune_sklearn/utils.py b/tune_sklearn/utils.py index dbb439b..6cfa5d0 100644 --- a/tune_sklearn/utils.py +++ b/tune_sklearn/utils.py @@ -1,9 +1,11 @@ from collections import defaultdict from typing import Dict +from sklearn.base import RegressorMixin from sklearn.compose import TransformedTargetRegressor from sklearn.metrics import check_scoring from sklearn.pipeline import Pipeline + from tune_sklearn._detect_booster import ( is_xgboost_model, is_lightgbm_model_of_required_version, is_catboost_model) import numpy as np @@ -101,6 +103,14 @@ def get_early_stop_type(estimator, early_stopping): if not early_stopping: return EarlyStopping.NO_EARLY_STOP + + if check_is_pipeline(estimator): + for step_name, step in estimator.steps: + is_regressor = isinstance(step, RegressorMixin) + if is_regressor: + estimator = step + break + can_partial_fit = check_partial_fit(estimator) can_warm_start_iter = check_warm_start_iter(estimator) can_warm_start_ensemble = check_warm_start_ensemble(estimator)