-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #87 from SimonBlanke/feature/sklearn-integration
add prototype for sklearn integration
- Loading branch information
Showing
15 changed files
with
387 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from .sklearn import HyperactiveSearchCV |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from .hyperactive_search_cv import HyperactiveSearchCV |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from sklearn.utils.metaestimators import available_if | ||
from sklearn.utils.deprecation import _deprecate_Xt_in_inverse_transform | ||
from sklearn.exceptions import NotFittedError | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
from .utils import _estimator_has | ||
|
||
|
||
# NOTE: The implementations of the following methods are from: | ||
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/model_selection/_search.py | ||
# Tag: 1.5.1 | ||
class BestEstimator:
    """Mixin forwarding the estimator API to ``self.best_estimator_``.

    Every forwarding method is guarded by ``available_if`` with a check
    produced by ``_estimator_has``, and verifies via ``check_is_fitted`` that
    the host object has been fitted before delegating.
    """

    @available_if(_estimator_has("score_samples"))
    def score_samples(self, X):
        """Return ``best_estimator_.score_samples(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.score_samples(X)

    @available_if(_estimator_has("predict"))
    def predict(self, X):
        """Return ``best_estimator_.predict(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.predict(X)

    @available_if(_estimator_has("predict_proba"))
    def predict_proba(self, X):
        """Return ``best_estimator_.predict_proba(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.predict_proba(X)

    @available_if(_estimator_has("predict_log_proba"))
    def predict_log_proba(self, X):
        """Return ``best_estimator_.predict_log_proba(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.predict_log_proba(X)

    @available_if(_estimator_has("decision_function"))
    def decision_function(self, X):
        """Return ``best_estimator_.decision_function(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.decision_function(X)

    @available_if(_estimator_has("transform"))
    def transform(self, X):
        """Return ``best_estimator_.transform(X)``."""
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.transform(X)

    @available_if(_estimator_has("inverse_transform"))
    def inverse_transform(self, X=None, Xt=None):
        """Return ``best_estimator_.inverse_transform`` of the given data.

        ``X`` and ``Xt`` are first reconciled into a single argument by
        ``_deprecate_Xt_in_inverse_transform``; this happens before the
        fitted check.
        """
        resolved = _deprecate_Xt_in_inverse_transform(X, Xt)
        check_is_fitted(self)
        delegate = self.best_estimator_
        return delegate.inverse_transform(resolved)

    @property
    def classes_(self):
        """Class labels exposed by ``best_estimator_``."""
        # The check is run for its side effect only: its return value is
        # discarded, so it matters solely when it raises.
        availability_check = _estimator_has("classes_")
        availability_check(self)
        return self.best_estimator_.classes_
83 changes: 83 additions & 0 deletions
83
src/hyperactive/integrations/sklearn/hyperactive_search_cv.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from sklearn.base import BaseEstimator, clone | ||
from sklearn.metrics import check_scoring | ||
from sklearn.utils.validation import indexable, _check_method_params | ||
|
||
|
||
from hyperactive import Hyperactive | ||
|
||
from .objective_function_adapter import ObjectiveFunctionAdapter | ||
from .best_estimator import BestEstimator | ||
|
||
|
||
class HyperactiveSearchCV(BaseEstimator, BestEstimator):
    """Cross-validated hyperparameter search driven by a Hyperactive optimizer.

    Parameters
    ----------
    estimator : estimator object
        Estimator whose parameters are searched. It is cloned before the
        final refit, so the passed instance is not fitted by ``_refit``.
    optimizer : object
        Forwarded as ``optimizer`` to ``Hyperactive.add_search``.
    params_config : dict
        Forwarded as ``search_space`` to ``Hyperactive.add_search``.
    n_iter : int, default=100
        Forwarded as ``n_iter`` to ``Hyperactive.add_search``.
    scoring : default=None
        Passed to ``sklearn.metrics.check_scoring`` to build ``scorer_``.
    n_jobs : int, default=1
        Forwarded as ``n_jobs`` to ``Hyperactive.add_search``.
    random_state : default=None
        Forwarded as ``random_state`` to ``Hyperactive.add_search``.
    refit : bool, default=True
        If truthy, ``fit`` refits a clone of ``estimator`` with the best
        parameters found and stores it as ``best_estimator_``.
    cv : default=None
        Handed to the objective-function adapter as its validation scheme.

    Attributes
    ----------
    scorer_ : callable
        Scorer returned by ``check_scoring``; set by ``fit``.
    best_params_ : dict
        Best parameters reported by Hyperactive; set by ``fit``.
    best_score_ : float
        Best score reported by Hyperactive; set by ``fit``.
    best_estimator_ : estimator object
        Refitted estimator; only set when ``refit`` is truthy.
    """

    _required_parameters = ["estimator", "optimizer", "params_config"]

    def __init__(
        self,
        estimator,
        optimizer,
        params_config,
        n_iter=100,
        *,
        scoring=None,
        n_jobs=1,
        random_state=None,
        refit=True,
        cv=None,
    ):
        self.estimator = estimator
        self.optimizer = optimizer
        self.params_config = params_config
        self.n_iter = n_iter
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state
        self.refit = refit
        self.cv = cv

    def _refit(
        self,
        X,
        y=None,
        **fit_params,
    ):
        """Fit a clone of ``estimator`` configured with ``best_params_``.

        The previous implementation refitted an unconfigured clone, so the
        search result never reached ``best_estimator_``.
        """
        # `best_params_` is set by `fit`; tolerate a direct call before any
        # search has run by falling back to the estimator's own parameters.
        best_params = getattr(self, "best_params_", None) or {}
        # Clone the parameter values as well (safe=False deep-copies
        # non-estimators) so the refitted model shares no state with them.
        self.best_estimator_ = clone(self.estimator).set_params(
            **clone(best_params, safe=False)
        )
        self.best_estimator_.fit(X, y, **fit_params)
        return self

    def fit(self, X, y, **params):
        """Run the hyperparameter search and optionally refit the best model.

        Parameters
        ----------
        X, y : array-like
            Training data; validated with ``indexable`` and
            ``_validate_data``.
        **params : dict
            Fit parameters, checked with ``_check_method_params`` and passed
            to the final refit.

        Returns
        -------
        self
        """
        X, y = indexable(X, y)
        X, y = self._validate_data(X, y)

        params = _check_method_params(X, params=params)
        self.scorer_ = check_scoring(self.estimator, scoring=self.scoring)

        objective_function_adapter = ObjectiveFunctionAdapter(
            self.estimator,
        )
        objective_function_adapter.add_dataset(X, y)
        objective_function_adapter.add_validation(self.scorer_, self.cv)
        # Keep one reference: it is both the search target and the key used
        # to look the results up afterwards.
        objective_function = objective_function_adapter.objective_function

        hyper = Hyperactive(verbosity=False)
        hyper.add_search(
            objective_function,
            search_space=self.params_config,
            optimizer=self.optimizer,
            n_iter=self.n_iter,
            n_jobs=self.n_jobs,
            random_state=self.random_state,
        )
        hyper.run()

        # Persist the optimum so it can drive the refit and stay available
        # to users who chose `refit=False`.
        self.best_params_ = hyper.best_para(objective_function)
        self.best_score_ = hyper.best_score(objective_function)

        if self.refit:
            self._refit(X, y, **params)

        return self

    def score(self, X, y=None, **params):
        """Score ``best_estimator_`` on ``X``/``y`` using ``scorer_``.

        Reads ``best_estimator_``, which ``fit`` only sets when ``refit`` is
        truthy.
        """
        return self.scorer_(self.best_estimator_, X, y, **params)
36 changes: 36 additions & 0 deletions
36
src/hyperactive/integrations/sklearn/objective_function_adapter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from sklearn.model_selection import cross_validate | ||
from sklearn.utils.validation import _num_samples | ||
|
||
|
||
class ObjectiveFunctionAdapter:
    """Expose cross-validation of an estimator as an objective function.

    The dataset and the validation settings are attached after construction
    via ``add_dataset`` and ``add_validation``; ``objective_function`` then
    evaluates one parameter proposal.
    """

    def __init__(self, estimator) -> None:
        # Template estimator; it is cloned for every evaluation.
        self.estimator = estimator

    def add_dataset(self, X, y):
        """Store the data that every evaluation is cross-validated on."""
        self.X = X
        self.y = y

    def add_validation(self, scoring, cv):
        """Store the scorer and the ``cv`` argument used by every evaluation."""
        self.scoring = scoring
        self.cv = cv

    def objective_function(self, params):
        """Cross-validate the estimator configured with ``params``.

        Parameters
        ----------
        params : mapping-like
            Proposed parameter values. Only ``keys()`` and item access are
            used, so a plain ``dict`` works as well.
            NOTE(review): assumes the object Hyperactive passes in supports
            both -- confirm against the Hyperactive version in use.

        Returns
        -------
        tuple
            ``(mean test score, info dict)`` where the info dict holds the
            ``score_time`` and ``fit_time`` arrays from ``cross_validate``
            and ``n_test_samples`` (the sample count of the stored ``X``).
        """
        # Local import: `clone` is not among this module's top-level imports.
        from sklearn.base import clone

        # Apply the proposal to a fresh clone. Previously `params` was
        # ignored and every iteration scored the same unmodified estimator.
        candidate = clone(self.estimator)
        candidate.set_params(**{name: params[name] for name in params.keys()})

        cv_results = cross_validate(
            candidate,
            self.X,
            self.y,
            cv=self.cv,
            # Use the scorer registered through `add_validation`; it was
            # previously stored but never used.
            scoring=self.scoring,
        )

        add_info_d = {
            "score_time": cv_results["score_time"],
            "fit_time": cv_results["fit_time"],
            "n_test_samples": _num_samples(self.X),
        }

        return cv_results["test_score"].mean(), add_info_d
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Author: Simon Blanke | ||
# Email: simon.blanke@yahoo.com | ||
# License: MIT License | ||
|
||
|
||
from sklearn.utils.validation import ( | ||
indexable, | ||
_check_method_params, | ||
check_is_fitted, | ||
) | ||
|
||
# NOTE: The implementations of the following methods are from: | ||
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/model_selection/_search.py | ||
# Tag: 1.5.1 | ||
|
||
|
||
def _check_refit(search_cv, attr): | ||
if not search_cv.refit: | ||
raise AttributeError( | ||
f"This {type(search_cv).__name__} instance was initialized with " | ||
f"`refit=False`. {attr} is available only after refitting on the best " | ||
"parameters. You can refit an estimator manually using the " | ||
"`best_params_` attribute" | ||
) | ||
|
||
|
||
def _estimator_has(attr):
    """Build a check that ``attr`` exists on the relevant estimator.

    The returned callable takes the search object, first runs
    ``_check_refit`` on it, then looks ``attr`` up on ``best_estimator_``
    when that attribute exists and on ``estimator`` otherwise. It returns
    ``True`` on success; a missing ``attr`` surfaces as the
    ``AttributeError`` raised by ``getattr``.
    """

    def check(self):
        _check_refit(self, attr)
        if hasattr(self, "best_estimator_"):
            target = self.best_estimator_
        else:
            target = self.estimator
        # Looked up purely for its side effect: raises AttributeError if
        # `attr` does not exist on the target.
        getattr(target, attr)
        return True

    return check
Empty file.
Empty file.
16 changes: 16 additions & 0 deletions
16
tests/integrations/sklearn/test_parametrize_with_checks.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from sklearn import svm | ||
|
||
from hyperactive.integrations import HyperactiveSearchCV | ||
from hyperactive.optimizers import RandomSearchOptimizer | ||
|
||
from sklearn.utils.estimator_checks import parametrize_with_checks | ||
|
||
|
||
# Module-level fixtures used to build the estimator under test.
svc = svm.SVC()
parameters = {"kernel": ["linear", "rbf"], "C": [1, 10]}
opt = RandomSearchOptimizer()


@parametrize_with_checks(
    [
        HyperactiveSearchCV(
            estimator=svc,
            optimizer=opt,
            params_config=parameters,
        ),
    ]
)
def test_estimators(estimator, check):
    """Run one parametrized scikit-learn estimator check on the wrapper."""
    check(estimator)
Oops, something went wrong.