Skip to content

Commit

Permalink
🔧 fix small bug
Browse files Browse the repository at this point in the history
  • Loading branch information
kaylode committed May 4, 2023
1 parent 86c4af3 commit aabaae6
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 9 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cv = [
]
cv_classification = [
"timm",
"scikit-plot",
"grad-cam>=1.4.5"
]
cv_semantic = [
Expand Down Expand Up @@ -85,7 +86,6 @@ tabular = [
"pandarallel>=1.6.3",
"numpy>=1.23.4",
"scikit-learn>=1.0.0",
"scikit-plot",
"scipy>=1.7.0",
"optuna>=3.0.5",
"psycopg2-binary>=2.9.5",
Expand All @@ -96,7 +96,8 @@ tabular_classification = [
"xgboost>=1.7.1",
"catboost",
"shap>=0.41.0",
"lime>=0.2.0.1"
"lime>=0.2.0.1",
"scikit-plot",
]
all = [
"theseus[cv,cv_classification,cv_semantic,cv_detection,nlp,nlp_retrieval,tabular,tabular_classification]",
Expand Down
1 change: 0 additions & 1 deletion theseus/base/metrics/precision_recall.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Dict

from scikitplot.metrics import plot_precision_recall_curve
from sklearn.metrics import precision_score, recall_score

from theseus.base.metrics.metric_template import Metric
Expand Down
25 changes: 19 additions & 6 deletions theseus/base/metrics/roc_auc_score.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from typing import Any, Dict

import torch
from scikitplot.metrics import plot_precision_recall_curve, plot_roc_curve

try:
from scikitplot.metrics import plot_precision_recall_curve, plot_roc_curve

has_scikitplot = True
except:
has_scikitplot = False
from sklearn.metrics import roc_auc_score

from theseus.base.metrics.metric_template import Metric
Expand Down Expand Up @@ -41,14 +47,21 @@ def value(self):
roc_auc_scr = roc_auc_score(
self.targets, self.preds, average=self.average, multi_class=self.label_type
)
roc_curve_fig = plot_roc_curve(self.targets, self.preds).get_figure()
pr_fig = plot_precision_recall_curve(self.targets, self.preds).get_figure()

return {
results = {
f"{self.average}-roc_auc_score": roc_auc_scr,
"roc_curve": roc_curve_fig,
"precision_recall_curve": pr_fig,
}
if has_scikitplot:
roc_curve_fig = plot_roc_curve(self.targets, self.preds).get_figure()
pr_fig = plot_precision_recall_curve(self.targets, self.preds).get_figure()
results.update(
{
"roc_curve": roc_curve_fig,
"precision_recall_curve": pr_fig,
}
)

return results

def reset(self):
self.targets = []
Expand Down

0 comments on commit aabaae6

Please sign in to comment.