Skip to content

Commit

Permalink
if_delegate_has_method -> available_if (#511)
Browse files Browse the repository at this point in the history
  • Loading branch information
hoffmansc authored Feb 20, 2024
1 parent ab7e52a commit 4b2f8f3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions aif360/sklearn/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.model_selection import train_test_split
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if

from aif360.sklearn.postprocessing.calibrated_equalized_odds import CalibratedEqualizedOdds
from aif360.sklearn.postprocessing.reject_option_classification import RejectOptionClassifier, RejectOptionClassifierCV
Expand Down Expand Up @@ -132,7 +132,7 @@ def fit(self, X, y, sample_weight=None, **fit_params):
**fit_params)
return self

@if_delegate_has_method('postprocessor_')
@available_if(lambda self: hasattr(self.postprocessor_, "predict"))
def predict(self, X):
"""Predict class labels for the given samples.
Expand All @@ -151,7 +151,7 @@ def predict(self, X):
y_score = pd.DataFrame(y_score, index=X.index).squeeze('columns')
return self.postprocessor_.predict(y_score)

@if_delegate_has_method('postprocessor_')
@available_if(lambda self: hasattr(self.postprocessor_, "predict_proba"))
def predict_proba(self, X):
"""Probability estimates.
Expand All @@ -175,7 +175,7 @@ def predict_proba(self, X):
y_score = pd.DataFrame(y_score, index=X.index).squeeze('columns')
return self.postprocessor_.predict_proba(y_score)

@if_delegate_has_method('postprocessor_')
@available_if(lambda self: hasattr(self.postprocessor_, "predict_log_proba"))
def predict_log_proba(self, X):
"""Log of probability estimates.
Expand All @@ -199,7 +199,7 @@ def predict_log_proba(self, X):
y_score = pd.DataFrame(y_score, index=X.index).squeeze('columns')
return self.postprocessor_.predict_log_proba(y_score)

@if_delegate_has_method('postprocessor_')
@available_if(lambda self: hasattr(self.postprocessor_, "score"))
def score(self, X, y, sample_weight=None):
"""Returns the output of the post-processor's score function on the
given test data and labels.
Expand Down
10 changes: 5 additions & 5 deletions aif360/sklearn/preprocessing/reweighing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone
from sklearn.utils.metaestimators import if_delegate_has_method
from sklearn.utils.metaestimators import available_if
from sklearn.utils.validation import has_fit_parameter

from aif360.sklearn.utils import check_inputs, check_groups
Expand Down Expand Up @@ -153,7 +153,7 @@ def fit(self, X, y, sample_weight=None):
self.estimator_.fit(X, y, sample_weight=sample_weight)
return self

@if_delegate_has_method('estimator_')
@available_if(lambda self: hasattr(self.estimator_, "predict"))
def predict(self, X):
"""Predict class labels for the given samples using ``self.estimator_``.
Expand All @@ -165,7 +165,7 @@ def predict(self, X):
"""
return self.estimator_.predict(X)

@if_delegate_has_method('estimator_')
@available_if(lambda self: hasattr(self.estimator_, "predict_proba"))
def predict_proba(self, X):
"""Probability estimates from ``self.estimator_``.
Expand All @@ -181,7 +181,7 @@ def predict_proba(self, X):
"""
return self.estimator_.predict_proba(X)

@if_delegate_has_method('estimator_')
@available_if(lambda self: hasattr(self.estimator_, "predict_log_proba"))
def predict_log_proba(self, X):
"""Log of probability estimates from ``self.estimator_``.
Expand All @@ -198,7 +198,7 @@ def predict_log_proba(self, X):
"""
return self.estimator_.predict_log_proba(X)

@if_delegate_has_method('estimator_')
@available_if(lambda self: hasattr(self.estimator_, "score"))
def score(self, X, y, sample_weight=None):
"""Returns the output of the estimator's score function on the given
test data and labels.
Expand Down

0 comments on commit 4b2f8f3

Please sign in to comment.