From 08b5a4db0f86f2403b7ad6a2589c3f941a243939 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Louis=20Poulain--Auz=C3=A9au?= <47986600+louisPoulain@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:33:57 +0200 Subject: [PATCH] Fix NaN bug in validation (#59) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Change from checking for NaNs to checking non finite values (Inf, -Inf, NaN) * Fix - Check for non finite * Remove debugging log * Remove comment for another PR --------- Co-authored-by: Louis Poulain--Auzéau --- mlpp_lib/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py index 539a1bb..7d449d2 100644 --- a/mlpp_lib/datasets.py +++ b/mlpp_lib/datasets.py @@ -449,9 +449,9 @@ def drop_nans(self, group_size: int = 1): x, y, w = self._get_copies() event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"] - mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes) + mask = da.any(~da.isfinite(da.from_array(x, name="x")), axis=event_axes) if y is not None: - mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=event_axes) + mask = mask | da.any(~da.isfinite(da.from_array(y, name="y")), axis=event_axes) mask = (~mask).compute() # with grouped samples, nans have to be removed in blocks: