diff --git a/drevalpy/models/simple_neural_network/utils.py b/drevalpy/models/simple_neural_network/utils.py index 3868462..da4d3e5 100644 --- a/drevalpy/models/simple_neural_network/utils.py +++ b/drevalpy/models/simple_neural_network/utils.py @@ -155,6 +155,7 @@ def fit( shuffle=True, num_workers=num_workers, persistent_workers=True, + drop_last=True, ) val_loader = None