Skip to content

Commit

Permalink
remove old sklearn version of _fit_and_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenKlaassen committed Jun 10, 2024
1 parent f04c959 commit 13255a6
Showing 1 changed file with 5 additions and 20 deletions.
25 changes: 5 additions & 20 deletions doubleml/tests/_utils_dml_cv_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,6 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection._validation import _fit_and_predict, _check_is_permutation

# Adapt _fit_and_predict for earlier sklearn versions
from distutils.version import LooseVersion
from sklearn import __version__ as sklearn_version

if LooseVersion(sklearn_version) < LooseVersion("1.4.0"):
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
res = _fit_and_predict(estimator, x, y, train, test,
verbose=0,
fit_params=fit_params,
method=method)
return res
else:
def _fit_and_predict_adapted(estimator, x, y, train, test, fit_params, method):
return _fit_and_predict(estimator, x, y, train, test, fit_params, method)


def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
n_jobs=None, est_params=None, method='predict'):
Expand All @@ -42,12 +27,12 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
else:
predictions = np.full(len(y), np.nan)
if est_params is None:
xx = _fit_and_predict_adapted(
xx = _fit_and_predict(
clone(estimator),
x, y, train_index, test_index, fit_params, method)
else:
assert isinstance(est_params, dict)
xx = _fit_and_predict_adapted(
xx = _fit_and_predict(
clone(estimator).set_params(**est_params),
x, y, train_index, test_index, fit_params, method)

Expand Down Expand Up @@ -77,20 +62,20 @@ def _dml_cv_predict_ut_version(estimator, x, y, smpls=None,
pre_dispatch=pre_dispatch)
# FixMe: Find a better way to handle the different combinations of parameters and smpls_is_partition
if est_params is None:
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
prediction_blocks = parallel(delayed(_fit_and_predict)(
estimator,
x, y, train_index, test_index, fit_params, method)
for idx, (train_index, test_index) in enumerate(smpls))
elif isinstance(est_params, dict):
# if no fold-specific parameters we redirect to the standard method
# warnings.warn("Using the same (hyper-)parameters for all folds")
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
prediction_blocks = parallel(delayed(_fit_and_predict)(
clone(estimator).set_params(**est_params),
x, y, train_index, test_index, fit_params, method)
for idx, (train_index, test_index) in enumerate(smpls))
else:
assert len(est_params) == len(smpls), 'provide one parameter setting per fold'
prediction_blocks = parallel(delayed(_fit_and_predict_adapted)(
prediction_blocks = parallel(delayed(_fit_and_predict)(
clone(estimator).set_params(**est_params[idx]),
x, y, train_index, test_index, fit_params, method)
for idx, (train_index, test_index) in enumerate(smpls))
Expand Down

0 comments on commit 13255a6

Please sign in to comment.