diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01d4cac9..b2451c47 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 [\#66](https://github.com/mllam/neural-lam/pull/66) @leifdenby @sadamov
 
+- Implement standardization of static features when loaded in ARModel [\#96](https://github.com/mllam/neural-lam/pull/96) @joeloskarsson
+
 ### Fixed
 
 - Fix bugs introduced with datastores functionality relating visualation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby
 
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index b0055e39..f0291657 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -186,9 +186,36 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         """
         pass
 
+    def _standardize_datarray(
+        self, da: xr.DataArray, category: str
+    ) -> xr.DataArray:
+        """
+        Helper function to standardize a dataarray before returning it.
+
+        Parameters
+        ----------
+        da : xr.DataArray
+            The dataarray to standardize.
+        category : str
+            The category of the dataarray (state/forcing/static), used to
+            load the standardization statistics.
+
+        Returns
+        -------
+        xr.DataArray
+            The standardized dataarray.
+        """
+
+        standard_da = self.get_standardization_dataarray(category=category)
+
+        mean = standard_da[f"{category}_mean"]
+        std = standard_da[f"{category}_std"]
+
+        return (da - mean) / std
+
     @abc.abstractmethod
     def get_dataarray(
-        self, category: str, split: str
+        self, category: str, split: str, standardize: bool = False
     ) -> Union[xr.DataArray, None]:
         """
         Return the processed data (as a single `xr.DataArray`) for the given
@@ -219,6 +246,8 @@ def get_dataarray(
             The category of the dataset (state/forcing/static).
         split : str
             The time split to filter the dataset (train/val/test).
+        standardize : bool
+            Whether the dataarray should be standardized before being returned
 
         Returns
         -------
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 0d1aac7b..0b6bb5e4 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -218,7 +218,9 @@ def get_num_data_vars(self, category: str) -> int:
         """
         return len(self.get_vars_names(category))
 
-    def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+    def get_dataarray(
+        self, category: str, split: str, standardize: bool = False
+    ) -> xr.DataArray:
         """
         Return the processed data (as a single `xr.DataArray`) for the given
         category of data and test/train/val-split that covers all the data (in
@@ -246,6 +248,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
             The category of the dataset (state/forcing/static).
         split : str
             The time split to filter the dataset (train/val/test).
+        standardize : bool
+            Whether the dataarray should be standardized before being returned
 
         Returns
         -------
@@ -283,7 +287,12 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
             da_category = da_category.sel(time=slice(t_start, t_end))
 
         dim_order = self.expected_dim_order(category=category)
-        return da_category.transpose(*dim_order)
+        da_category = da_category.transpose(*dim_order)
+
+        if standardize:
+            return self._standardize_datarray(da_category, category=category)
+
+        return da_category
 
     def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         """
diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py
index 42e80706..8f926f7e 100644
--- a/neural_lam/datastore/npyfilesmeps/store.py
+++ b/neural_lam/datastore/npyfilesmeps/store.py
@@ -199,7 +199,9 @@ def config(self) -> NpyDatastoreConfig:
         """
         return self._config
 
-    def get_dataarray(self, category: str, split: str) -> DataArray:
+    def get_dataarray(
+        self, category: str, split: str, standardize: bool = False
+    ) -> DataArray:
         """
         Get the data array for the given category and split of data. If the
         category is 'state', the data array will be a concatenation of the data
@@ -214,6 +216,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
         split : str
             The dataset split to load the data for. One of 'train', 'val', or
             'test'.
+        standardize : bool
+            Whether the dataarray should be standardized before being returned
 
         Returns
         -------
@@ -303,6 +307,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
         dim_order = self.expected_dim_order(category=category)
         da = da.transpose(*dim_order)
 
+        if standardize:
+            return self._standardize_datarray(da, category=category)
+
         return da
 
     def _get_single_timeseries_dataarray(
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index 44baf9c2..754cfb3a 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -39,8 +39,9 @@ def __init__(
         self._datastore = datastore
         num_state_vars = datastore.get_num_data_vars(category="state")
         num_forcing_vars = datastore.get_num_data_vars(category="forcing")
+        # Load static features standardized by the datastore
         da_static_features = datastore.get_dataarray(
-            category="static", split=None
+            category="static", split=None, standardize=True
         )
         da_state_stats = datastore.get_standardization_dataarray(
             category="state"
@@ -49,14 +50,10 @@ def __init__(
         num_past_forcing_steps = args.num_past_forcing_steps
         num_future_forcing_steps = args.num_future_forcing_steps
 
-        # Load static features for grid/data, NB: self.predict_step assumes
-        # dimension order to be (grid_index, static_feature)
-        arr_static = da_static_features.transpose(
-            "grid_index", "static_feature"
-        ).values
+        # Load static features for grid/data
         self.register_buffer(
             "grid_static_features",
-            torch.tensor(arr_static, dtype=torch.float32),
+            torch.tensor(da_static_features.values, dtype=torch.float32),
             persistent=False,
         )
 
diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py
index 9075d404..0c76bca8 100644
--- a/tests/dummy_datastore.py
+++ b/tests/dummy_datastore.py
@@ -301,7 +301,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         return ds_standardization
 
     def get_dataarray(
-        self, category: str, split: str
+        self, category: str, split: str, standardize: bool = False
     ) -> Union[xr.DataArray, None]:
         """
         Return the processed data (as a single `xr.DataArray`) for the given
@@ -332,6 +332,8 @@ def get_dataarray(
             The category of the dataset (state/forcing/static).
         split : str
             The time split to filter the dataset (train/val/test).
+        standardize : bool
+            Whether the dataarray should be standardized before being returned
 
         Returns
         -------
@@ -340,7 +342,13 @@ def get_dataarray(
 
         """
         dim_order = self.expected_dim_order(category=category)
-        return self.ds[category].transpose(*dim_order)
+
+        da_category = self.ds[category].transpose(*dim_order)
+
+        if standardize:
+            return self._standardize_datarray(da_category, category=category)
+
+        return da_category
 
     @cached_property
     def boundary_mask(self) -> xr.DataArray:
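
Reviewer note (not part of the patch): a minimal sketch of how the new `standardize` flag can be used, assuming `datastore` is an already-constructed datastore instance (for example an `MDPDatastore` or the test `DummyDatastore`; constructor arguments are omitted here). It simply re-derives the same standardization that `_standardize_datarray` applies, using the statistics from `get_standardization_dataarray`.

```python
import numpy as np

# Illustrative only: `datastore` is assumed to be constructed elsewhere.

# Standardized static features, as ARModel now loads them
da_static_std = datastore.get_dataarray(
    category="static", split=None, standardize=True
)

# Equivalent manual computation from the standardization statistics
da_static_raw = datastore.get_dataarray(category="static", split=None)
ds_stats = datastore.get_standardization_dataarray(category="static")
da_expected = (
    da_static_raw - ds_stats["static_mean"]
) / ds_stats["static_std"]

np.testing.assert_allclose(da_static_std.values, da_expected.values)
```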