diff --git a/mlpp_lib/train.py b/mlpp_lib/train.py index d3846c4..9a28530 100644 --- a/mlpp_lib/train.py +++ b/mlpp_lib/train.py @@ -90,7 +90,7 @@ def train( res = model.fit( x=datamodule.train.x, y=datamodule.train.y, - sample_weight=datamodule.train.w, + sample_weight=(datamodule.train.w,) if datamodule.train.w is not None else None, epochs=cfg.get("epochs", 1), validation_data=(datamodule.val.x, datamodule.val.y), callbacks=callbacks,