From 1b8ea151df33c7b383886b87cc2c28f2af23e2e5 Mon Sep 17 00:00:00 2001 From: ned Date: Fri, 31 Mar 2023 09:12:15 +0200 Subject: [PATCH] Fix performance issue with sample weights See https://github.com/keras-team/keras/pull/17357 --- mlpp_lib/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlpp_lib/train.py b/mlpp_lib/train.py index d3846c4..9a28530 100644 --- a/mlpp_lib/train.py +++ b/mlpp_lib/train.py @@ -90,7 +90,7 @@ def train( res = model.fit( x=datamodule.train.x, y=datamodule.train.y, - sample_weight=datamodule.train.w, + sample_weight=(datamodule.train.w,) if datamodule.train.w is not None else None, epochs=cfg.get("epochs", 1), validation_data=(datamodule.val.x, datamodule.val.y), callbacks=callbacks,