diff --git a/pyproject.toml b/pyproject.toml index 184b131..7eafb9b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ cv = [ ] cv_classification = [ "timm", + "scikit-plot", "grad-cam>=1.4.5" ] cv_semantic = [ @@ -85,7 +86,6 @@ tabular = [ "pandarallel>=1.6.3", "numpy>=1.23.4", "scikit-learn>=1.0.0", - "scikit-plot", "scipy>=1.7.0", "optuna>=3.0.5", "psycopg2-binary>=2.9.5", @@ -96,7 +96,8 @@ tabular_classification = [ "xgboost>=1.7.1", "catboost", "shap>=0.41.0", - "lime>=0.2.0.1" + "lime>=0.2.0.1", + "scikit-plot", ] all = [ "theseus[cv,cv_classification,cv_semantic,cv_detection,nlp,nlp_retrieval,tabular,tabular_classification]", diff --git a/theseus/base/metrics/precision_recall.py b/theseus/base/metrics/precision_recall.py index 05e327a..2852e2f 100644 --- a/theseus/base/metrics/precision_recall.py +++ b/theseus/base/metrics/precision_recall.py @@ -1,6 +1,5 @@ from typing import Any, Dict -from scikitplot.metrics import plot_precision_recall_curve from sklearn.metrics import precision_score, recall_score from theseus.base.metrics.metric_template import Metric diff --git a/theseus/base/metrics/roc_auc_score.py b/theseus/base/metrics/roc_auc_score.py index 836b1a4..724f71a 100644 --- a/theseus/base/metrics/roc_auc_score.py +++ b/theseus/base/metrics/roc_auc_score.py @@ -1,7 +1,13 @@ from typing import Any, Dict import torch -from scikitplot.metrics import plot_precision_recall_curve, plot_roc_curve + +try: + from scikitplot.metrics import plot_precision_recall_curve, plot_roc_curve + + has_scikitplot = True +except: + has_scikitplot = False from sklearn.metrics import roc_auc_score from theseus.base.metrics.metric_template import Metric @@ -41,14 +47,21 @@ def value(self): roc_auc_scr = roc_auc_score( self.targets, self.preds, average=self.average, multi_class=self.label_type ) - roc_curve_fig = plot_roc_curve(self.targets, self.preds).get_figure() - pr_fig = plot_precision_recall_curve(self.targets, self.preds).get_figure() - return { + results = { f"{self.average}-roc_auc_score": roc_auc_scr, - "roc_curve": roc_curve_fig, - "precision_recall_curve": pr_fig, } + if has_scikitplot: + roc_curve_fig = plot_roc_curve(self.targets, self.preds).get_figure() + pr_fig = plot_precision_recall_curve(self.targets, self.preds).get_figure() + results.update( + { + "roc_curve": roc_curve_fig, + "precision_recall_curve": pr_fig, + } + ) + + return results def reset(self): self.targets = []