Fix Dataset for inference (#28)
* Convert predictions without mask
* Drop samples for any occurrence of NaNs
* Allow inference-only mode
dnerini authored Apr 3, 2023
1 parent 1b8ea15 commit d58035c
Showing 2 changed files with 64 additions and 11 deletions.
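To make the second bullet concrete: a sample is now flagged as soon as a NaN appears anywhere along its non-sample (event) axes, instead of only along the last axis, as in the first diff hunk below. A minimal standalone sketch of that reduction (plain dask/numpy with made-up shapes, not mlpp_lib code):

import dask.array as da
import numpy as np

# 4 samples ("s"), each with 2 lead times and 3 features (shapes are made up)
x = np.random.randn(4, 2, 3)
x[1, 0, 2] = np.nan           # a single NaN anywhere in sample 1 ...
event_axes = (1, 2)           # ... is caught by reducing over all non-sample axes
nan_mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes)
keep = (~nan_mask).compute()  # array([ True, False,  True,  True])
print(keep)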
mlpp_lib/datasets.py (23 additions, 11 deletions)

@@ -421,13 +421,14 @@ def _as_variables(self) -> tuple[xr.Variable, ...]:
     def drop_nans(self):
         """Drop incomplete samples and return a new `Dataset` with a mask."""
         if not self._is_stacked:
-            raise ValueError("Dataset shoud be stacked before dropping samples.")
+            raise ValueError("Dataset should be stacked before dropping samples.")

         x, y, w = self._get_copies()

-        mask = da.any(da.isnan(da.from_array(x, name="x")), axis=-1)
+        event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"]
+        mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes)
         if y is not None:
-            mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=-1)
+            mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=event_axes)
         mask = (~mask).compute()

         x = x[mask]
@@ -472,23 +473,34 @@ def get_multiindex(self) -> pd.MultiIndex:
         )

     def dataset_from_predictions(
-        self, preds: np.ndarray, ensemble_axis=None
+        self,
+        preds: np.ndarray,
+        ensemble_axis: Optional[int] = None,
+        targets: Optional[Sequence[Hashable]] = None,
     ) -> xr.Dataset:
         if not self._is_stacked:
             raise ValueError("Dataset should be stacked first.")
+        if self.targets is None and targets is None:
+            raise ValueError("Please specify argument 'targets'")
+        else:
+            targets = targets or self.targets
         event_shape = [
             len(c) for dim, c in self.coords.items() if dim not in self.batch_dims
         ]
-        full_shape = [self.mask.shape[0], *event_shape, len(self.targets)]
+        full_shape = [self.x.shape[0], *event_shape, len(targets)]
         dims = list(self.dims)
-        coords = self.coords | {"v": self.targets}
+        coords = self.coords | {"v": targets}
         if ensemble_axis is not None:
             full_shape.insert(ensemble_axis, preds.shape[ensemble_axis])
             dims.insert(ensemble_axis, "realization")
             coords = coords | {"realization": np.arange(preds.shape[ensemble_axis])}
-        out = np.full(full_shape, fill_value=np.nan)
-        # out[self.mask] = preds
-        out = xr.Variable(dims, out)
-        out[{"s": self.mask}] = preds
-        out = out.unstack(s={dim: len(coord) for dim, coord in self.coords.items()})
+        if self.mask is not None:
+            out = np.full(full_shape, fill_value=np.nan)
+            out = xr.Variable(dims, out)
+            out[{"s": self.mask}] = preds
+        else:
+            out = xr.Variable(dims, preds)
+        out = out.unstack(s={dim: len(self.coords[dim]) for dim in self.batch_dims})
         out = xr.DataArray(out, coords=coords)
         return out.to_dataset("v")

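As a standalone illustration of the two branches added to `dataset_from_predictions` above (made-up shapes, not mlpp_lib code): when a sample mask exists, predictions are scattered back into a NaN-filled array of the full stacked shape; when it does not (the new inference-only case), the predictions are wrapped directly.

import numpy as np
import xarray as xr

n_total, n_targets = 6, 2
mask = np.array([True, True, False, True, False, True])  # True = sample survived drop_nans
preds = np.random.randn(mask.sum(), n_targets)            # model output for kept samples only

# With a mask: scatter predictions into a NaN-filled array covering every sample.
out = xr.Variable(("s", "v"), np.full((n_total, n_targets), np.nan))
out[{"s": mask}] = preds

# Without a mask (inference only): predictions already cover every sample.
out_no_mask = xr.Variable(("s", "v"), np.random.randn(n_total, n_targets))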
tests/test_datasets.py (41 additions, 0 deletions)

@@ -260,3 +260,44 @@ def test_drop_nans_only_x(self, dataset_only_x, dims, coords, features):
         assert len(ds.mask) == n_samples
         assert ds.dims == ["s", *event_dims, "v"]
         assert ds.coords.keys() == coords.keys()
+
+    @pytest.mark.parametrize(
+        "batch_dims",
+        [
+            ("forecast_reference_time", "t", "station"),
+            ("forecast_reference_time", "station"),
+        ],
+        ids=lambda x: repr(x),
+    )
+    def test_dataset_from_predictions(self, dataset, batch_dims):
+        n_samples = 3
+        ds = dataset.stack(batch_dims)
+        ds = ds.drop_nans()
+        predictions = np.random.randn(n_samples, *ds.y.shape)
+        ds_pred = ds.dataset_from_predictions(predictions, ensemble_axis=0)
+        assert isinstance(ds_pred, xr.Dataset)
+        assert ds_pred.dims["realization"] == n_samples
+        assert all([ds_pred.dims[c] == ds.coords[c].size for c in ds.coords])
+        assert list(ds_pred.data_vars) == ds.targets
+
+    @pytest.mark.parametrize(
+        "batch_dims",
+        [
+            ("forecast_reference_time", "t", "station"),
+            ("forecast_reference_time", "station"),
+        ],
+        ids=lambda x: repr(x),
+    )
+    def test_dataset_from_predictions_only_x(self, dataset_only_x, batch_dims):
+        n_samples = 3
+        targets = ["obs:y1", "obs:y2"]
+        ds = dataset_only_x.stack(batch_dims)
+        # Note that here we do not drop nan, hence the mask is not created!
+        predictions = np.random.randn(n_samples, *ds.x.shape[:-1], len(targets))
+        ds_pred = ds.dataset_from_predictions(
+            predictions, ensemble_axis=0, targets=targets
+        )
+        assert isinstance(ds_pred, xr.Dataset)
+        assert ds_pred.dims["realization"] == n_samples
+        assert all([ds_pred.dims[c] == ds.coords[c].size for c in ds.coords])
+        assert list(ds_pred.data_vars) == targets
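Stitching the pieces together, a hypothetical inference-only call sequence mirroring the second test above. `ds` and `model` are assumptions here (an mlpp_lib Dataset built from features only, and any predictor returning one value per target and stacked sample), so this is an illustrative sketch rather than runnable code:

targets = ["obs:y1", "obs:y2"]   # ds.targets is None, so names must be supplied

ds = ds.stack(("forecast_reference_time", "t", "station"))
# No drop_nans() call: no mask is created, so predictions are used as-is.
preds = model(ds.x)              # e.g. shape (n_samples, len(targets))
ds_pred = ds.dataset_from_predictions(preds, targets=targets)
# ds_pred is an xr.Dataset with one data variable per entry in `targets`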
