Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement standardization of static features #96

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#66](https://github.com/mllam/neural-lam/pull/66)
@leifdenby @sadamov

- Implement standardization of static features when loaded in ARModel [\#96](https://github.com/mllam/neural-lam/pull/96) @joeloskarsson

### Fixed

- Fix bugs introduced with datastores functionality relating to visualisation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby
Expand Down
31 changes: 30 additions & 1 deletion neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,36 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
pass

def _standardize_datarray(
    self, da: xr.DataArray, category: str
) -> xr.DataArray:
    """
    Helper function to standardize a dataarray before returning it.

    Standardization is done per-variable as ``(da - mean) / std``, using
    the mean/std statistics provided by
    ``self.get_standardization_dataarray`` for the given category.

    Parameters
    ----------
    da: xr.DataArray
        The dataarray to standardize
    category : str
        The category of the dataarray (state/forcing/static), to load
        standardization statistics for.

    Returns
    -------
    xr.DataArray
        The standardized dataarray

    """
    # NOTE(review): method name misspells "dataarray" as "datarray";
    # kept as-is since subclasses/callers already reference this name.

    # Stats dataset is expected to contain "{category}_mean" and
    # "{category}_std" variables (see get_standardization_dataarray).
    standard_da = self.get_standardization_dataarray(category=category)

    mean = standard_da[f"{category}_mean"]
    std = standard_da[f"{category}_std"]

    # NOTE(review): assumes std contains no zeros — a zero-variance
    # feature would produce inf/NaN here; confirm upstream stats.
    return (da - mean) / std

@abc.abstractmethod
def get_dataarray(
self, category: str, split: str
self, category: str, split: str, standardize: bool = False
) -> Union[xr.DataArray, None]:
"""
Return the processed data (as a single `xr.DataArray`) for the given
Expand Down Expand Up @@ -219,6 +246,8 @@ def get_dataarray(
The category of the dataset (state/forcing/static).
split : str
The time split to filter the dataset (train/val/test).
standardize: bool
If the dataarray should be returned standardized

Returns
-------
Expand Down
13 changes: 11 additions & 2 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def get_num_data_vars(self, category: str) -> int:
"""
return len(self.get_vars_names(category))

def get_dataarray(self, category: str, split: str) -> xr.DataArray:
def get_dataarray(
self, category: str, split: str, standardize: bool = False
) -> xr.DataArray:
"""
Return the processed data (as a single `xr.DataArray`) for the given
category of data and test/train/val-split that covers all the data (in
Expand Down Expand Up @@ -246,6 +248,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
The category of the dataset (state/forcing/static).
split : str
The time split to filter the dataset (train/val/test).
standardize: bool
If the dataarray should be returned standardized

Returns
-------
Expand Down Expand Up @@ -283,7 +287,12 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
da_category = da_category.sel(time=slice(t_start, t_end))

dim_order = self.expected_dim_order(category=category)
return da_category.transpose(*dim_order)
da_category = da_category.transpose(*dim_order)

if standardize:
return self._standardize_datarray(da_category, category=category)

return da_category

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
Expand Down
9 changes: 8 additions & 1 deletion neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def config(self) -> NpyDatastoreConfig:
"""
return self._config

def get_dataarray(self, category: str, split: str) -> DataArray:
def get_dataarray(
self, category: str, split: str, standardize: bool = False
) -> DataArray:
"""
Get the data array for the given category and split of data. If the
category is 'state', the data array will be a concatenation of the data
Expand All @@ -214,6 +216,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
split : str
The dataset split to load the data for. One of 'train', 'val', or
'test'.
standardize: bool
If the dataarray should be returned standardized

Returns
-------
Expand Down Expand Up @@ -303,6 +307,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
dim_order = self.expected_dim_order(category=category)
da = da.transpose(*dim_order)

if standardize:
return self._standardize_datarray(da, category=category)

return da

def _get_single_timeseries_dataarray(
Expand Down
11 changes: 4 additions & 7 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def __init__(
self._datastore = datastore
num_state_vars = datastore.get_num_data_vars(category="state")
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
# Load static features standardized
da_static_features = datastore.get_dataarray(
category="static", split=None
category="static", split=None, standardize=True
)
da_state_stats = datastore.get_standardization_dataarray(
category="state"
Expand All @@ -49,14 +50,10 @@ def __init__(
num_past_forcing_steps = args.num_past_forcing_steps
num_future_forcing_steps = args.num_future_forcing_steps

# Load static features for grid/data, NB: self.predict_step assumes
# dimension order to be (grid_index, static_feature)
arr_static = da_static_features.transpose(
"grid_index", "static_feature"
).values
# Load static features for grid/data,
self.register_buffer(
"grid_static_features",
torch.tensor(arr_static, dtype=torch.float32),
torch.tensor(da_static_features.values, dtype=torch.float32),
persistent=False,
)

Expand Down
12 changes: 10 additions & 2 deletions tests/dummy_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
return ds_standardization

def get_dataarray(
self, category: str, split: str
self, category: str, split: str, standardize: bool = False
) -> Union[xr.DataArray, None]:
"""
Return the processed data (as a single `xr.DataArray`) for the given
Expand Down Expand Up @@ -332,6 +332,8 @@ def get_dataarray(
The category of the dataset (state/forcing/static).
split : str
The time split to filter the dataset (train/val/test).
standardize: bool
If the dataarray should be returned standardized

Returns
-------
Expand All @@ -340,7 +342,13 @@ def get_dataarray(

"""
dim_order = self.expected_dim_order(category=category)
return self.ds[category].transpose(*dim_order)

da_category = self.ds[category].transpose(*dim_order)

if standardize:
return self._standardize_datarray(da_category, category=category)

return da_category

@cached_property
def boundary_mask(self) -> xr.DataArray:
Expand Down
Loading