Skip to content

Commit

Permalink
🐛 Fix xgboost error
Browse files Browse the repository at this point in the history
  • Loading branch information
kaylode committed Nov 4, 2023
1 parent 4fe8224 commit 1b1ea74
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ ml = [
"psycopg2-binary>=2.9.5",
"gunicorn>=20.1.0",
"lightgbm>=3.3.3",
"xgboost>=1.7.1",
"xgboost<=1.7.1",
"catboost",
"shap>=0.41.0",
"lime>=0.2.0.1",
Expand Down
2 changes: 1 addition & 1 deletion tests/classification/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def inference(self):

for idx, batch in enumerate(tqdm(self.dataloader)):
img_names = batch["img_names"]
outputs = self.model.get_prediction(batch)
outputs = self.model.predict_step(batch)
preds = outputs["names"]
probs = outputs["confidences"]

Expand Down
2 changes: 1 addition & 1 deletion tests/semantic/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def inference(self):
img_names = batch["img_names"]
ori_sizes = batch["ori_sizes"]

outputs = self.model.get_prediction(batch)
outputs = self.model.predict_step(batch)
preds = outputs["masks"]

for (inpt, pred, filename, ori_size) in zip(
Expand Down
10 changes: 5 additions & 5 deletions tests/tabular/test_tablr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def test_train_tblr(override_config):
train_pipeline.fit()


# @pytest.mark.order(2)
# def test_eval_tblr(override_config):
# override_config["global"]["pretrained"] = "runs/pytest_tablr/checkpoints/last"
# val_pipeline = MLPipeline(override_config)
# val_pipeline.evaluate()
@pytest.mark.order(2)
def test_eval_tblr(override_config):
override_config["global"]["pretrained"] = "runs/pytest_tablr/checkpoints/last"
val_pipeline = MLPipeline(override_config)
val_pipeline.evaluate()


# @pytest.mark.order(2)
Expand Down
2 changes: 1 addition & 1 deletion theseus/base/models/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def validation_step(self, batch, batch_idx):
self.log_dict(outputs["loss_dict"], prog_bar=True, on_step=True, on_epoch=False)
return outputs

def predict_step(self, batch, batch_idx):
def predict_step(self, batch, batch_idx=None):
pred = self.model.get_prediction(batch)
return pred

Expand Down

3 comments on commit 1b1ea74

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tabular Classification Report

Metrics

SHAP train

SHAP

SHAP val

SHAP

Hyperparameters Tuning

Leaderboard

number value datetime_start datetime_complete duration params_gamma params_learning_rate params_max_depth params_n_estimators params_reg_alpha params_reg_lambda best_key model_name state
0 0.7755790944 1699094300961 1699094301099 138 0.7945736611 0.8777502647 4 160 0.1489015551 0.7730128265 bl_acc xgboost COMPLETE
1 0.8319171226 1699094301100 1699094301251 150 0.432788513 0.8145506214 5 182 0.9873357615 0.4220072194 bl_acc xgboost COMPLETE
2 0.8037481085 1699094301252 1699094301423 170 0.9391836415 0.3691424918 5 409 0.2298818258 0.9618219432 bl_acc xgboost COMPLETE

Figures
History
Contour plot
Parallel
Importance
Slice

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Semantic Report

Metrics

precision recall dice iters
0.40056 1 0.57195 0

Prediction

Prediction
Prediction
Prediction

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Classification Report

Metrics

acc bl_acc weighted-f1 iters
0.25 0.25 0.2 0

Confusion Matrix

Confusion Matrix

Errorcases

Confusion Matrix

Hyperparameters Tuning

Leaderboard

number value datetime_start datetime_complete duration params_optimizer.args.lr best_key model_name state
0 1 1699094301030 1699094352745 51715 0.0008913358 bl_acc efficientnet_b0 COMPLETE
1 0.75 1699094352746 1699094407028 54281 0.0003423306 bl_acc efficientnet_b0 COMPLETE

Please sign in to comment.