diff --git a/.github/workflows/ci-pdm-install-and-test-cpu.yml b/.github/workflows/ci-pdm-install-and-test-cpu.yml
index c5da88cc..8fb4df79 100644
--- a/.github/workflows/ci-pdm-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-cpu.yml
@@ -39,17 +39,17 @@ jobs:
- name: Load cache data
uses: actions/cache/restore@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
restore-keys: |
- ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ ${{ runner.os }}-meps-reduced-example-data-v0.2.0
- name: Run tests
run: |
- pdm run pytest
+ pdm run pytest -vv -s
- name: Save cache data
uses: actions/cache/save@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
diff --git a/.github/workflows/ci-pdm-install-and-test-gpu.yml b/.github/workflows/ci-pdm-install-and-test-gpu.yml
index 9ab4f379..43a701c2 100644
--- a/.github/workflows/ci-pdm-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pdm-install-and-test-gpu.yml
@@ -44,17 +44,17 @@ jobs:
- name: Load cache data
uses: actions/cache/restore@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
restore-keys: |
- ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ ${{ runner.os }}-meps-reduced-example-data-v0.2.0
- name: Run tests
run: |
- pdm run pytest
+ pdm run pytest -vv -s
- name: Save cache data
uses: actions/cache/save@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
diff --git a/.github/workflows/ci-pip-install-and-test-cpu.yml b/.github/workflows/ci-pip-install-and-test-cpu.yml
index 81e402c5..b131596d 100644
--- a/.github/workflows/ci-pip-install-and-test-cpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-cpu.yml
@@ -29,17 +29,17 @@ jobs:
- name: Load cache data
uses: actions/cache/restore@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
restore-keys: |
- ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ ${{ runner.os }}-meps-reduced-example-data-v0.2.0
- name: Run tests
run: |
- python -m pytest
+ python -m pytest -vv -s
- name: Save cache data
uses: actions/cache/save@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
diff --git a/.github/workflows/ci-pip-install-and-test-gpu.yml b/.github/workflows/ci-pip-install-and-test-gpu.yml
index ce68946a..3afcca5a 100644
--- a/.github/workflows/ci-pip-install-and-test-gpu.yml
+++ b/.github/workflows/ci-pip-install-and-test-gpu.yml
@@ -34,17 +34,17 @@ jobs:
- name: Load cache data
uses: actions/cache/restore@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
restore-keys: |
- ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ ${{ runner.os }}-meps-reduced-example-data-v0.2.0
- name: Run tests
run: |
- python -m pytest
+ python -m pytest -vv -s
- name: Save cache data
uses: actions/cache/save@v4
with:
- path: data
- key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0
+ path: tests/datastore_examples/npyfilesmeps/meps_example_reduced.zip
+ key: ${{ runner.os }}-meps-reduced-example-data-v0.2.0
diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml
index ad2b1a9c..71e28ad7 100644
--- a/.github/workflows/pre-commit.yml
+++ b/.github/workflows/pre-commit.yml
@@ -13,7 +13,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ["3.9", "3.10", "3.11", "3.12"]
+ python-version: ["3.9", "3.10", "3.11"]
steps:
- uses: actions/checkout@v2
- name: Set up Python
diff --git a/.gitignore b/.gitignore
index 022206f5..fdb51d3d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,14 +1,14 @@
### Project Specific ###
wandb
-slurm_log*
saved_models
lightning_logs
data
graphs
*.sif
sweeps
-test_*.sh
.vscode
+*.html
+*.zarr
*slurm*
### Python ###
@@ -75,8 +75,15 @@ tags
# Coc configuration directory
.vim
+.vscode
+
+# macos
+.DS_Store
+__MACOSX
# pdm (https://pdm-project.org/en/stable/)
.pdm-python
+.venv
+
# exclude pdm.lock file so that both cpu and gpu versions of torch will be accepted by pdm
pdm.lock
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18cf5d4d..12cf54f6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.2.0...HEAD)
+### Added
+
+- Introduce Datastores to represent input data from different sources, including zarr and numpy.
+ [\#66](https://github.com/mllam/neural-lam/pull/66)
+ @leifdenby @sadamov
+
## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0)
### Added
diff --git a/README.md b/README.md
index 416f7e8c..e21b7c24 100644
--- a/README.md
+++ b/README.md
@@ -63,18 +63,7 @@ Still, some restrictions are inevitable:
-## A note on the limited area setting
-Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling)).
-There are still some parts of the code that is quite specific for the MEPS area use case.
-This is in particular true for the mesh graph creation (`python -m neural_lam.create_mesh`) and some of the constants set in a `data_config.yaml` file (path specified in `python -m neural_lam.train_model --data_config ` ).
-There is ongoing efforts to refactor the code to be fully area-agnostic.
-See issues [4](https://github.com/mllam/neural-lam/issues/4) and [24](https://github.com/mllam/neural-lam/issues/24) for more about this.
-See also the [weather-model-graphs](https://github.com/mllam/weather-model-graphs) package for constructing graphs for arbitrary areas.
-
-# Using Neural-LAM
-Below follows instructions on how to use Neural-LAM to train and evaluate models.
-
-## Installation
+# Installing Neural-LAM
When installing `neural-lam` you have a choice of either installing directly
with `pip` or using the `pdm` package manager.
@@ -91,7 +80,7 @@ expects the most recent version of CUDA on your system.
We cover all the installation options in our [github actions ci/cd
setup](.github/workflows/) which you can use as a reference.
-### Using `pdm`
+## Using `pdm`
1. Clone this repository and navigate to the root directory.
2. Install `pdm` if you don't have it installed on your system (either with `pip install pdm` or [following the install instructions](https://pdm-project.org/latest/#installation)).
@@ -100,7 +89,7 @@ setup](.github/workflows/) which you can use as a reference.
4. Install a specific version of `torch` with `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cpu` for a CPU-only version or `pdm run python -m pip install torch --index-url https://download.pytorch.org/whl/cu111` for CUDA 11.1 support (you can find the correct URL for the variant you want on [PyTorch webpage](https://pytorch.org/get-started/locally/)).
5. Install the dependencies with `pdm install` (by default this will include only the default dependencies). If you will be developing `neural-lam` we recommend installing the development dependencies with `pdm install --group dev`. By default `pdm` installs the `neural-lam` package in editable mode, so you can make changes to the code and see the effects immediately.
-### Using `pip`
+## Using `pip`
1. Clone this repository and navigate to the root directory.
> If you are happy using the latest version of `torch` with GPU support (expecting the latest version of CUDA is installed on your system) you can skip to step 3.
@@ -108,41 +97,291 @@ setup](.github/workflows/) which you can use as a reference.
3. Install the dependencies with `python -m pip install .`. If you will be developing `neural-lam` we recommend installing in editable mode together with the development dependencies, using `python -m pip install -e ".[dev]"`, so you can make changes to the code and see the effects immediately.
-## Data
-Datasets should be stored in a directory called `data`.
-See the [repository format section](#format-of-data-directory) for details on the directory structure.
+# Using Neural-LAM
+
+Once `neural-lam` is installed you will be able to train/evaluate models. For this you will in general need two things:
+
+1. **Data to train/evaluate the model**. To represent this data we use a concept of
+ *datastores* in Neural-LAM (see the [Data](#data-the-datastore-and-weatherdataset-classes) section for more details).
+ In brief, a datastore takes care of loading data from disk in a
+ specific format (for example zarr or numpy files) and exposes it through an
+ interface that provides the data in a data structure that can be used within
+ neural-lam. A datastore is used to create a `pytorch.Dataset`-derived
+ class that samples the data in time to create individual samples for
+ training, validation and testing.
+
+2. **The graph structure** is used to define message-passing GNN layers
+ that are trained to emulate fluid flow in the atmosphere over time. The
+ graph structure is created for a specific datastore.
+
+Any command you run in neural-lam will include the path to a configuration file
+to be used (usually called `config.yaml`). This configuration file defines the
+path to the datastore configuration you wish to use and allows you to configure
+different aspects of the training and evaluation of the model.
+
+The path you provide to the neural-lam config (`config.yaml`) also sets the
+root directory relative to which all other paths are resolved, i.e. the parent
+directory of the config becomes the root directory. Both the datastore and
+graphs you generate are then stored in subdirectories of this root directory.
+Exactly how and where a specific datastore expects its source data to be stored
+and where it stores its derived data is up to the implementation of the
+datastore.
+
+In general the folder structure assumed in Neural-LAM is as follows (we will
+assume you placed `config.yaml` in a folder called `data`):
+
+```
+data/
+├── config.yaml - Configuration file for neural-lam
+├── danra.datastore.yaml - Configuration file for the datastore, referred to from config.yaml
+└── graphs/ - Directory containing graphs for training
+```
+
+And the content of `config.yaml` could in this case look like:
+```yaml
+datastore:
+ kind: mdp
+ config_path: danra.datastore.yaml
+training:
+ state_feature_weighting:
+ __config_class__: ManualStateFeatureWeighting
+ weights:
+ u100m: 1.0
+ v100m: 1.0
+```
+
+For now the neural-lam config only defines two things: 1) the kind of
+datastore and the path to its config, and 2) the weighting of different features in
+the loss function. If you don't define the state feature weighting it will default
+to weighting all features equally.
+
+(This example is taken from the `tests/datastore_examples/mdp` directory.)
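+
+If you want to work with the configuration and datastore from Python rather
+than through the command line, the same config file can be loaded with the
+`load_config_and_datastore` helper (a minimal sketch; the path below assumes
+the `data/config.yaml` layout from the example above):
+
+```python
+from neural_lam.config import load_config_and_datastore
+
+# Parse config.yaml and initialise the datastore it refers to
+config, datastore = load_config_and_datastore(config_path="data/config.yaml")
+
+# The datastore then exposes the data interface used by neural-lam
+print(datastore.num_grid_points)
+```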
+
+
+Below are instructions on how to use Neural-LAM to train and evaluate
+models, with details given first for each implemented kind of datastore
+and then for the graph generation. Once `neural-lam` has been installed the
+general process is:
+
+1. Run any pre-processing scripts to generate the necessary derived data that your chosen datastore requires
+2. Run graph-creation step
+3. Train the model
+
+## Data (the `DataStore` and `WeatherDataset` classes)
+
+To enable flexibility in what input-data sources can be used with neural-lam,
+the input-data representation is split into two parts:
+
+1. A "datastore" (represented by instances of
+ [neural_lam.datastore.BaseDatastore](neural_lam/datastore/base.py)) which
+ takes care of loading a given category (state, forcing or static) and split
+ (train/val/test) of data from disk and returning it as an `xarray.DataArray`.
+ The returned data-array is expected to have the spatial coordinates
+ flattened into a single `grid_index` dimension and all variables and vertical
+ levels stacked into a feature dimension (named as `{category}_feature`). The
+ datastore also provides information about the number, names and units of
+ variables in the data, the boundary mask, normalisation values and grid
+ information.
+
+2. A `pytorch.Dataset`-derived class (called
+ `neural_lam.weather_dataset.WeatherDataset`) which takes care of sampling in
+ time to create individual samples for training, validation and testing. The
+ `WeatherDataset` class is also responsible for normalising the values and
+ returning `torch.Tensor`-objects.
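+
+For example, the split of responsibilities can be seen by inspecting what a
+datastore returns (a minimal sketch, assuming `datastore` has been loaded as
+shown earlier; the exact dimension names depend on whether the datastore holds
+forecast or analysis data):
+
+```python
+# Load all training-split state data as a single xarray.DataArray
+da_state = datastore.get_dataarray(category="state", split="train")
+
+# Spatial dimensions are flattened into `grid_index` and all variables and
+# levels are stacked into a `state_feature` dimension
+print(da_state.dims)  # e.g. ("time", "grid_index", "state_feature")
+print(datastore.get_vars_names(category="state"))
+```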
+
+There are currently two different datastores implemented in the codebase:
+
+1. `neural_lam.datastore.MDPDatastore` which represents loading of
+ *training-ready* datasets in zarr format created with the
+ [mllam-data-prep](https://github.com/mllam/mllam-data-prep) package.
+ Training-ready refers to the fact that this data has been transformed
+ (variables have been stacked, spatial coordinates have been flattened,
+ statistics for normalisation have been calculated, etc) to be ready for
+ training. `mllam-data-prep` can combine any number of datasets that can be
+ read with [xarray](https://github.com/pydata/xarray) and the processing can
+ either be done at run-time or as a pre-processing step before calling
+ neural-lam.
+
+2. `neural_lam.datastore.NpyFilesDatastoreMEPS` which reads MEPS data from
+ `.npy`-files in the format introduced in neural-lam `v0.1.0`. Note that this
+ datastore is specific to the format of the MEPS dataset, but can act as an
+ example for how to create similar numpy-based datastores.
+
+If neither of these options fits your needs you can create your own datastore by
+subclassing the `neural_lam.datastore.BaseDatastore` class or the
+`neural_lam.datastore.BaseRegularGridDatastore` class (if your data is stored on
+a regular grid) and implementing the abstract methods.
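+
+A sketch of the shape such a subclass takes is given below (only a few of the
+abstract methods are shown and the class name is hypothetical; see
+[neural_lam/datastore/base.py](neural_lam/datastore/base.py) for the full set
+of methods and properties to implement):
+
+```python
+from pathlib import Path
+from typing import List
+
+import xarray as xr
+
+from neural_lam.datastore.base import BaseDatastore
+
+
+class MyDatastore(BaseDatastore):
+    SHORT_NAME = "mydatastore"  # identifier for this datastore kind
+
+    def __init__(self, config_path):
+        self._root_path = Path(config_path).parent
+
+    @property
+    def root_path(self) -> Path:
+        # derived files (for example graphs) are stored relative to this path
+        return self._root_path
+
+    def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+        # must return data with dims (..., grid_index, {category}_feature)
+        ...
+
+    def get_vars_names(self, category: str) -> List[str]:
+        ...
+
+    # ... remaining abstract methods and properties omitted in this sketch
+```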
+
+
+### MDP (mllam-data-prep) Datastore - `MDPDatastore`
+
+With `MDPDatastore` (the mllam-data-prep datastore) all the selection,
+transformation and pre-calculation steps that are needed to go from, for
+example, gridded weather data to a format that is optimised for training
+in neural-lam are done in a separate package called
+[mllam-data-prep](https://github.com/mllam/mllam-data-prep) rather than in
+neural-lam itself.
+Specifically, the `mllam-data-prep` datastore configuration (for example
+[danra.datastore.yaml](tests/datastore_examples/mdp/danra.datastore.yaml))
+specifies a) what source datasets to read from, b) what variables to select, c)
+what transformations of dimensions and variables to make, d) what statistics to
+calculate (for normalisation) and e) how to split the data into training,
+validation and test sets (see full details about the configuration specification
+in the [mllam-data-prep README](https://github.com/mllam/mllam-data-prep)).
+
+From a datastore configuration `mllam-data-prep` returns the transformed
+dataset as an `xr.Dataset` which is then written in zarr-format to disk by
+`neural-lam` when the datastore is first initiated (the path of the dataset is
+derived from the datastore config, so that from a config named `danra.datastore.yaml` the resulting dataset is stored in `danra.datastore.zarr`).
+You can also run `mllam-data-prep` directly to create the processed dataset by providing the path to the datastore configuration file:
+
+```bash
+python -m mllam_data_prep --config data/danra.datastore.yaml
+```
+
+If you will be working with a large dataset (on the order of 10GB or more) it
+could be beneficial to produce the processed `.zarr` dataset before using it
+in neural-lam, so that the processing can be done across multiple CPU cores in
+parallel. This is done by including the `--dask-distributed-local-core-fraction`
+argument when calling mllam-data-prep, setting the fraction of your system's
+CPU cores that should be used for processing (see the
+[mllam-data-prep README](https://github.com/mllam/mllam-data-prep?tab=readme-ov-file#creating-large-datasets-with-daskdistributed)
+for details).
+
+For example:
+
+```bash
+python -m mllam_data_prep --config data/danra.datastore.yaml --dask-distributed-local-core-fraction 0.5
+```
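+
+The resulting zarr dataset can be inspected with `xarray` if you want to check
+what was produced (a minimal sketch; the path assumes the
+`danra.datastore.yaml` example above, for which the processed dataset is
+written next to the config as `danra.datastore.zarr`):
+
+```python
+import xarray as xr
+
+ds = xr.open_zarr("data/danra.datastore.zarr")
+print(ds)
+```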
+
+### NpyFiles MEPS Datastore - `NpyFilesDatastoreMEPS`
+
+Version `v0.1.0` of Neural-LAM was built to train from numpy files from the
+MEPS weather forecast dataset.
+To enable this functionality to live on in later versions of neural-lam we have
+built a datastore called `NpyFilesDatastoreMEPS` which implements functionality
+to read from these exact same numpy files. At this stage this datastore class
+is very much tied to the MEPS dataset, but the code is written in a way that
+it could quite easily be adapted to work with numpy-based weather
+forecast/analysis files in the future.
The full MEPS dataset can be shared with other researchers on request; contact us for this.
-A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
+A tiny subset of the data (named `meps_example`) is available in
+`example_data.zip`, which can be downloaded from
+[here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
+
Download the file and unzip in the neural-lam directory.
-Graphs used in the initial paper are also available for download at the same link (but can as easily be re-generated using `python -m neural_lam.create_mesh`).
+Graphs used in the initial paper are also available for download at the same link (but can just as easily be re-generated using `python -m neural_lam.create_graph`).
Note that this is far too little data to train any useful models, but all pre-processing and training steps can be run with it.
It should thus be useful to make sure that your python environment is set up correctly and that all the code can be run without any issues.
-## Pre-processing
-An overview of how the different pre-processing steps, training and files depend on each other is given in this figure:
-
-
-
-In order to start training models at least three pre-processing steps have to be run:
+The following datastore configuration works with the MEPS dataset:
+
+```yaml
+# meps.datastore.yaml
+dataset:
+ name: meps_example
+ num_forcing_features: 16
+ var_longnames:
+ - pres_heightAboveGround_0_instant
+ - pres_heightAboveSea_0_instant
+ - nlwrs_heightAboveGround_0_accum
+ - nswrs_heightAboveGround_0_accum
+ - r_heightAboveGround_2_instant
+ - r_hybrid_65_instant
+ - t_heightAboveGround_2_instant
+ - t_hybrid_65_instant
+ - t_isobaricInhPa_500_instant
+ - t_isobaricInhPa_850_instant
+ - u_hybrid_65_instant
+ - u_isobaricInhPa_850_instant
+ - v_hybrid_65_instant
+ - v_isobaricInhPa_850_instant
+ - wvint_entireAtmosphere_0_instant
+ - z_isobaricInhPa_1000_instant
+ - z_isobaricInhPa_500_instant
+ var_names:
+ - pres_0g
+ - pres_0s
+ - nlwrs_0
+ - nswrs_0
+ - r_2
+ - r_65
+ - t_2
+ - t_65
+ - t_500
+ - t_850
+ - u_65
+ - u_850
+ - v_65
+ - v_850
+ - wvint_0
+ - z_1000
+ - z_500
+ var_units:
+ - Pa
+ - Pa
+ - W/m\textsuperscript{2}
+ - W/m\textsuperscript{2}
+ - "-"
+ - "-"
+ - K
+ - K
+ - K
+ - K
+ - m/s
+ - m/s
+ - m/s
+ - m/s
+ - kg/m\textsuperscript{2}
+ - m\textsuperscript{2}/s\textsuperscript{2}
+ - m\textsuperscript{2}/s\textsuperscript{2}
+ num_timesteps: 65
+ num_ensemble_members: 2
+ step_length: 3
+ remove_state_features_with_index: [15]
+grid_shape_state:
+- 268
+- 238
+projection:
+ class_name: LambertConformal
+ kwargs:
+ central_latitude: 63.3
+ central_longitude: 15.0
+ standard_parallels:
+ - 63.3
+ - 63.3
+```
+
+You can then refer to this datastore configuration from a neural-lam configuration file like this:
+
+```yaml
+# config.yaml
+datastore:
+ kind: npyfilesmeps
+ config_path: meps.datastore.yaml
+training:
+ state_feature_weighting:
+ __config_class__: ManualStateFeatureWeighting
+ weights:
+ u100m: 1.0
+ v100m: 1.0
+```
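+
+If you want to instantiate this datastore directly in Python, for example to
+inspect the data it provides, the same datastore config can be passed to
+`init_datastore` (a minimal sketch; the path is the `meps.datastore.yaml`
+example above):
+
+```python
+from neural_lam.datastore import init_datastore
+
+datastore = init_datastore(
+    datastore_kind="npyfilesmeps", config_path="meps.datastore.yaml"
+)
+print(datastore.get_vars_names(category="state"))
+```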
-* `python -m neural_lam.create_mesh`
-* `python -m neural_lam.create_grid_features`
-* `python -m neural_lam.create_parameter_weights`
+For npy-file based datastores you must separately run the command that computes the statistics used for standardization:
+
+```bash
+python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats
+```
+
+### Graph creation
-### Create graph
Run `python -m neural_lam.create_graph` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_graph --help` for a list of options).
The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as:
-* **GC-LAM**: `python -m neural_lam.create_mesh --graph multiscale`
-* **Hi-LAM**: `python -m neural_lam.create_mesh --graph hierarchical --hierarchical` (also works for Hi-LAM-Parallel)
-* **L1-LAM**: `python -m neural_lam.create_mesh --graph 1level --levels 1`
+* **GC-LAM**: `python -m neural_lam.create_graph --config_path <config-path> --name multiscale`
+* **Hi-LAM**: `python -m neural_lam.create_graph --config_path <config-path> --name hierarchical --hierarchical` (also works for Hi-LAM-Parallel)
+* **L1-LAM**: `python -m neural_lam.create_graph --config_path <config-path> --name 1level --levels 1`
The graph-related files are stored in a directory called `graphs`.
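+
+Graph creation can also be invoked from Python with
+`create_graph_from_datastore` (a minimal sketch, assuming `datastore` is a
+regular-grid datastore loaded as shown earlier; this mirrors what the
+command-line entry point does):
+
+```python
+import os
+
+from neural_lam.create_graph import create_graph_from_datastore
+
+# Store the graph under the datastore root, as the command-line tool does
+create_graph_from_datastore(
+    datastore=datastore,
+    output_root_path=os.path.join(datastore.root_path, "graph", "multiscale"),
+    hierarchical=False,
+)
+```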
-### Create remaining static features
-To create the remaining static files run `python -m neural_lam.create_grid_features` and `python -m neural_lam.create_parameter_weights`.
-
## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
@@ -160,15 +399,17 @@ wandb off
```
## Train Models
-Models can be trained using `python -m neural_lam.train_model`.
+Models can be trained using `python -m neural_lam.train_model --config_path <config-path>`.
Run `python -m neural_lam.train_model --help` for a full list of training options.
A few of the key ones are outlined below:
-* `--dataset`: Which data to train on
+* `--config_path`: Path to the configuration for neural-lam (for example in `data/myexperiment/config.yaml`).
* `--model`: Which model to train
* `--graph`: Which graph to use with the model
+* `--epochs`: Number of epochs to train for
* `--processor_layers`: Number of GNN layers to use in the processing part of the model
-* `--ar_steps`: Number of time steps to unroll for when making predictions and computing the loss
+* `--ar_steps_train`: Number of time steps to unroll for when making predictions and computing the loss
+* `--ar_steps_eval`: Number of time steps to unroll for during validation steps
Checkpoints of trained models are stored in the `saved_models` directory.
The implemented models are:
@@ -208,13 +449,14 @@ python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ..
Checkpoint files for our models trained on the MEPS data are available upon request.
## Evaluate Models
-Evaluation is also done using `python -m neural_lam.train_model`, but using the `--eval` option.
+Evaluation is also done using `python -m neural_lam.train_model --config_path <config-path>`, but with the `--eval` option added.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
-Most of the training options are also relevant for evaluation (not `ar_steps`, evaluation always unrolls full forecasts).
+Most of the training options are also relevant for evaluation.
Some options specifically important for evaluation are:
* `--load`: Path to model checkpoint file (`.ckpt`) to load parameters from
* `--n_example_pred`: Number of example predictions to plot during evaluation.
+* `--ar_steps_eval`: Number of time steps to unroll for during evaluation
**Note:** While it is technically possible to use multiple GPUs for running evaluation, this is strongly discouraged. If using multiple devices the `DistributedSampler` will replicate some samples to make sure all devices have the same batch size, meaning that evaluation metrics will be unreliable.
A possible workaround is to just use batch size 1 during evaluation.
@@ -223,47 +465,7 @@ This issue stems from PyTorch Lightning. See for example [this PR](https://githu
# Repository Structure
Except for training and pre-processing scripts all the source code can be found in the `neural_lam` directory.
Model classes, including abstract base classes, are located in `neural_lam/models`.
-
-## Format of data directory
-It is possible to store multiple datasets in the `data` directory.
-Each dataset contains a set of files with static features and a set of samples.
-The samples are split into different sub-directories for training, validation and testing.
-The directory structure is shown with examples below.
-Script names within parenthesis denote the script used to generate the file.
-```
-data
-├── dataset1
-│ ├── samples - Directory with data samples
-│ │ ├── train - Training data
-│ │ │ ├── nwp_2022040100_mbr000.npy - A time series sample
-│ │ │ ├── nwp_2022040100_mbr001.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_2022043012_mbr001.npy
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy - Solar flux forcing
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
-│ │ │ ├── ...
-│ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022043012.npy
-│ │ │ ├── wtr_2022040100.npy - Open water features for one sample
-│ │ │ ├── wtr_2022040112.npy
-│ │ │ ├── ...
-│ │ │ └── wtr_202204012.npy
-│ │ ├── val - Validation data
-│ │ └── test - Test data
-│ └── static - Directory with graph information and static features
-│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset)
-│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset)
-│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset)
-│ ├── grid_features.pt - Static features of grid nodes (neural_lam.create_grid_features)
-│ ├── parameter_mean.pt - Means of state parameters (neural_lam.create_parameter_weights)
-│ ├── parameter_std.pt - Std.-dev. of state parameters (neural_lam.create_parameter_weights)
-│ ├── diff_mean.pt - Means of one-step differences (neural_lam.create_parameter_weights)
-│ ├── diff_std.pt - Std.-dev. of one-step differences (neural_lam.create_parameter_weights)
-│ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (neural_lam.create_parameter_weights)
-│ └── parameter_weights.npy - Loss weights for different state parameters (neural_lam.create_parameter_weights)
-├── dataset2
-├── ...
-└── datasetN
-```
+Notebooks for visualization and analysis are located in `docs`.
## Format of graph directory
The `graphs` directory contains generated graph structures that can be used by different graph-based models.
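+
+The individual graph components are stored as `torch` tensor files and can be
+loaded directly if you want to inspect a generated graph (a minimal sketch;
+the directory path is hypothetical and follows the graph name chosen at
+creation time):
+
+```python
+import os
+
+import torch
+
+graph_dir_path = "data/graph/multiscale"  # hypothetical graph directory
+g2m_edge_index = torch.load(os.path.join(graph_dir_path, "g2m_edge_index.pt"))
+print(g2m_edge_index.shape)  # [2, N_g2m_edges]
+```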
diff --git a/figures/component_dependencies.png b/figures/component_dependencies.png
deleted file mode 100644
index fae77cab..00000000
Binary files a/figures/component_dependencies.png and /dev/null differ
diff --git a/neural_lam/__init__.py b/neural_lam/__init__.py
index dd565a26..da4c4d2e 100644
--- a/neural_lam/__init__.py
+++ b/neural_lam/__init__.py
@@ -1,5 +1,4 @@
# First-party
-import neural_lam.config
import neural_lam.interaction_net
import neural_lam.metrics
import neural_lam.models
diff --git a/neural_lam/config.py b/neural_lam/config.py
index 5891ea74..d3e09697 100644
--- a/neural_lam/config.py
+++ b/neural_lam/config.py
@@ -1,62 +1,171 @@
# Standard library
-import functools
+import dataclasses
from pathlib import Path
+from typing import Dict, Union
# Third-party
-import cartopy.crs as ccrs
-import yaml
-
-
-class Config:
- """
- Class for loading configuration files.
-
- This class loads a configuration file and provides a way to access its
- values as attributes.
- """
-
- def __init__(self, values):
- self.values = values
-
- @classmethod
- def from_file(cls, filepath):
- """Load a configuration file."""
- if filepath.endswith(".yaml"):
- with open(filepath, encoding="utf-8", mode="r") as file:
- return cls(values=yaml.safe_load(file))
- else:
- raise NotImplementedError(Path(filepath).suffix)
-
- def __getattr__(self, name):
- keys = name.split(".")
- value = self.values
- for key in keys:
- if key in value:
- value = value[key]
- else:
- return None
- if isinstance(value, dict):
- return Config(values=value)
- return value
-
- def __getitem__(self, key):
- value = self.values[key]
- if isinstance(value, dict):
- return Config(values=value)
- return value
-
- def __contains__(self, key):
- return key in self.values
-
- def num_data_vars(self):
- """Return the number of data variables for a given key."""
- return len(self.dataset.var_names)
-
- @functools.cached_property
- def coords_projection(self):
- """Return the projection."""
- proj_config = self.values["projection"]
- proj_class_name = proj_config["class"]
- proj_class = getattr(ccrs, proj_class_name)
- proj_params = proj_config.get("kwargs", {})
- return proj_class(**proj_params)
+import dataclass_wizard
+
+# Local
+from .datastore import (
+ DATASTORES,
+ MDPDatastore,
+ NpyFilesDatastoreMEPS,
+ init_datastore,
+)
+
+
+class DatastoreKindStr(str):
+ VALID_KINDS = DATASTORES.keys()
+
+ def __new__(cls, value):
+ if value not in cls.VALID_KINDS:
+ raise ValueError(f"Invalid datastore kind: {value}")
+ return super().__new__(cls, value)
+
+
+@dataclasses.dataclass
+class DatastoreSelection:
+ """
+ Configuration for selecting a datastore to use with neural-lam.
+
+ Attributes
+ ----------
+ kind : DatastoreKindStr
+ The kind of datastore to use, currently `mdp` or `npyfilesmeps` are
+ implemented.
+ config_path : str
+ The path to the configuration file for the selected datastore, this is
+ assumed to be relative to the configuration file for neural-lam.
+ """
+
+ kind: DatastoreKindStr
+ config_path: str
+
+
+@dataclasses.dataclass
+class ManualStateFeatureWeighting:
+ """
+ Configuration for weighting the state features in the loss function where
+ the weights are manually specified.
+
+ Attributes
+ ----------
+ weights : Dict[str, float]
+ Manual weights for the state features.
+ """
+
+ weights: Dict[str, float]
+
+
+@dataclasses.dataclass
+class UniformFeatureWeighting:
+ """
+ Configuration for weighting the state features in the loss function where
+ all state features are weighted equally.
+ """
+
+ pass
+
+
+@dataclasses.dataclass
+class TrainingConfig:
+ """
+ Configuration related to training neural-lam
+
+ Attributes
+ ----------
+ state_feature_weighting : Union[ManualStateFeatureWeighting,
+ UniformFeatureWeighting]
+ The method to use for weighting the state features in the loss
+ function. Defaults to uniform weighting (`UniformFeatureWeighting`, i.e.
+ all features are weighted equally).
+ """
+
+ state_feature_weighting: Union[
+ ManualStateFeatureWeighting, UniformFeatureWeighting
+ ] = dataclasses.field(default_factory=UniformFeatureWeighting)
+
+
+@dataclasses.dataclass
+class NeuralLAMConfig(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
+ """
+ Dataclass for Neural-LAM configuration. This class is used to load and
+ store the configuration for using Neural-LAM.
+
+ Attributes
+ ----------
+ datastore : DatastoreSelection
+ The configuration for the datastore to use.
+ training : TrainingConfig
+ The configuration for training the model.
+ """
+
+ datastore: DatastoreSelection
+ training: TrainingConfig = dataclasses.field(default_factory=TrainingConfig)
+
+ class _(dataclass_wizard.JSONWizard.Meta):
+ """
+ Define the configuration class as a JSON wizard class.
+
+ Together `tag_key` and `auto_assign_tags` make it possible, when a
+ `Union` of types is used for an attribute, to specify in the serialised
+ data which specific type to deserialize to, using the `tag_key` value. In
+ our case we call the tag key `__config_class__` to indicate to the
+ user that they should pick a dataclass describing configuration in
+ neural-lam. This Union-based selection allows us to support different
+ configuration attributes for different choices of method, for example,
+ and is used when picking between different feature weighting methods in
+ the `TrainingConfig` class. `auto_assign_tags` is set to True so that
+ the tag key (i.e. `__config_class__` in the config file) is
+ automatically set to the class name of the dataclass to deserialize to.
+ """
+
+ tag_key = "__config_class__"
+ auto_assign_tags = True
+ # ensure that all parts of the loaded configuration match the
+ # dataclasses used
+ # TODO: this should be enabled once
+ # https://github.com/rnag/dataclass-wizard/issues/137 is fixed, but
+ # currently cannot be used together with `auto_assign_tags` due to a
+ # bug it seems
+ # raise_on_unknown_json_key = True
+
+
+class InvalidConfigError(Exception):
+ pass
+
+
+def load_config_and_datastore(
+ config_path: str,
+) -> tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]:
+ """
+ Load the neural-lam configuration and the datastore specified in the
+ configuration.
+
+ Parameters
+ ----------
+ config_path : str
+ Path to the Neural-LAM configuration file.
+
+ Returns
+ -------
+ tuple[NeuralLAMConfig, Union[MDPDatastore, NpyFilesDatastoreMEPS]]
+ The Neural-LAM configuration and the loaded datastore.
+ """
+ try:
+ config = NeuralLAMConfig.from_yaml_file(config_path)
+ except dataclass_wizard.errors.UnknownJSONKey as ex:
+ raise InvalidConfigError(
+ "There was an error loading the configuration file at "
+ f"{config_path}. "
+ ) from ex
+ # datastore config is assumed to be relative to the config file
+ datastore_config_path = (
+ Path(config_path).parent / config.datastore.config_path
+ )
+ datastore = init_datastore(
+ datastore_kind=config.datastore.kind, config_path=datastore_config_path
+ )
+
+ return config, datastore
diff --git a/neural_lam/create_mesh.py b/neural_lam/create_graph.py
similarity index 74%
rename from neural_lam/create_mesh.py
rename to neural_lam/create_graph.py
index 21b8bf6e..ef979be3 100644
--- a/neural_lam/create_mesh.py
+++ b/neural_lam/create_graph.py
@@ -13,7 +13,8 @@
from torch_geometric.utils.convert import from_networkx
# Local
-from . import config
+from .config import load_config_and_datastore
+from .datastore.base import BaseRegularGridDatastore
def plot_graph(graph, title=None):
@@ -108,8 +109,8 @@ def from_networkx_with_start_index(nx_graph, start_index):
def mk_2d_graph(xy, nx, ny):
- xm, xM = np.amin(xy[0][0, :]), np.amax(xy[0][0, :])
- ym, yM = np.amin(xy[1][:, 0]), np.amax(xy[1][:, 0])
+ xm, xM = np.amin(xy[:, :, 0][:, 0]), np.amax(xy[:, :, 0][:, 0])
+ ym, yM = np.amin(xy[:, :, 1][0, :]), np.amax(xy[:, :, 1][0, :])
# avoid nodes on border
dx = (xM - xm) / nx
@@ -117,19 +118,19 @@ def mk_2d_graph(xy, nx, ny):
lx = np.linspace(xm + dx / 2, xM - dx / 2, nx)
ly = np.linspace(ym + dy / 2, yM - dy / 2, ny)
- mg = np.meshgrid(lx, ly)
- g = networkx.grid_2d_graph(len(ly), len(lx))
+ mg = np.meshgrid(lx, ly, indexing="ij") # Use 'ij' indexing for (Nx,Ny)
+ g = networkx.grid_2d_graph(len(lx), len(ly))
for node in g.nodes:
g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]])
# add diagonal edges
g.add_edges_from(
- [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+ [((x, y), (x + 1, y + 1)) for y in range(ny - 1) for x in range(nx - 1)]
+ [
((x + 1, y), (x, y + 1))
- for x in range(nx - 1)
for y in range(ny - 1)
+ for x in range(nx - 1)
]
)
@@ -153,46 +154,82 @@ def prepend_node_index(graph, new_index):
return networkx.relabel_nodes(graph, to_mapping, copy=True)
-def main(input_args=None):
- parser = ArgumentParser(description="Graph generation arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- parser.add_argument(
- "--graph",
- type=str,
- default="multiscale",
- help="Name to save graph as (default: multiscale)",
- )
- parser.add_argument(
- "--plot",
- action="store_true",
- help="If graphs should be plotted during generation "
- "(default: False)",
- )
- parser.add_argument(
- "--levels",
- type=int,
- help="Limit multi-scale mesh to given number of levels, "
- "from bottom up (default: None (no limit))",
- )
- parser.add_argument(
- "--hierarchical",
- action="store_true",
- help="Generate hierarchical mesh graph (default: False)",
- )
- args = parser.parse_args(input_args)
-
- # Load grid positions
- config_loader = config.Config.from_file(args.data_config)
- static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
- graph_dir_path = os.path.join("graphs", args.graph)
+def create_graph(
+ graph_dir_path: str,
+ xy: np.ndarray,
+ n_max_levels: int,
+ hierarchical: bool,
+ create_plot: bool,
+):
+ """
+ Create graph components from `xy` grid coordinates and store in
+ `graph_dir_path`.
+
+ Creates the following files for all graphs:
+ - g2m_edge_index.pt [2, N_g2m_edges]
+ - g2m_features.pt [N_g2m_edges, d_features]
+ - m2g_edge_index.pt [2, N_m2m_edges]
+ - m2g_features.pt [N_m2m_edges, d_features]
+ - m2m_edge_index.pt list of [2, N_m2m_edges_level], length==n_levels
+ - m2m_features.pt list of [N_m2m_edges_level, d_features],
+ length==n_levels
+ - mesh_features.pt list of [N_mesh_nodes_level, d_mesh_static],
+ length==n_levels
+
+ where
+ d_features:
+ number of features per edge (currently d_features==3, for
+ edge-length, x and y)
+ N_g2m_edges:
+ number of edges in the graph from grid-to-mesh
+ N_m2g_edges:
+ number of edges in the graph from mesh-to-grid
+ N_m2m_edges_level:
+ number of edges in the graph from mesh-to-mesh at a given level
+ (list index corresponds to the level)
+ d_mesh_static:
+ number of static features per mesh node (currently
+ d_mesh_static==2, for x and y)
+ N_mesh_nodes_level:
+ number of nodes in the mesh at a given level
+
+ And in addition for hierarchical graphs:
+ - mesh_up_edge_index.pt
+ list of [2, N_mesh_updown_edges_level], length==n_levels-1
+ - mesh_up_features.pt
+ list of [N_mesh_updown_edges_level, d_features], length==n_levels-1
+ - mesh_down_edge_index.pt
+ list of [2, N_mesh_updown_edges_level], length==n_levels-1
+ - mesh_down_features.pt
+ list of [N_mesh_updown_edges_level, d_features], length==n_levels-1
+
+ where N_mesh_updown_edges_level is the number of edges in the graph from
+ mesh-to-mesh between two consecutive levels (list index corresponds to the
+ index of the lower level)
+
+
+ Parameters
+ ----------
+ graph_dir_path : str
+ Path to store the graph components.
+ xy : np.ndarray
+ Grid coordinates, expected to be of shape (Nx, Ny, 2).
+ n_max_levels : int
+ Limit multi-scale mesh to given number of levels, from bottom up
+ (default: None (no limit)).
+ hierarchical : bool
+ Generate hierarchical mesh graph (default: False).
+ create_plot : bool
+ If graphs should be plotted during generation (default: False).
+
+ Returns
+ -------
+ None
+
+ """
os.makedirs(graph_dir_path, exist_ok=True)
- xy = np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
+ print(f"Writing graph components to {graph_dir_path}")
grid_xy = torch.tensor(xy)
pos_max = torch.max(torch.abs(grid_xy))
@@ -202,29 +239,29 @@ def main(input_args=None):
#
# graph geometry
- nx = 3 # number of children = nx**2
- nlev = int(np.log(max(xy.shape)) / np.log(nx))
+ nx = 3 # number of children = nx**2
+ nlev = int(np.log(max(xy.shape[:2])) / np.log(nx))
nleaf = nx**nlev # leaves at the bottom = nleaf**2
mesh_levels = nlev - 1
- if args.levels:
+ if n_max_levels:
# Limit the levels in mesh graph
- mesh_levels = min(mesh_levels, args.levels)
+ mesh_levels = min(mesh_levels, n_max_levels)
- print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}")
+ # print(f"nlev: {nlev}, nleaf: {nleaf}, mesh_levels: {mesh_levels}")
# multi resolution tree levels
G = []
for lev in range(1, mesh_levels + 1):
n = int(nleaf / (nx**lev))
g = mk_2d_graph(xy, n, n)
- if args.plot:
+ if create_plot:
plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}")
plt.show()
G.append(g)
- if args.hierarchical:
+ if hierarchical:
# Relabel nodes of each level with level index first
G = [
prepend_node_index(graph, level_i)
@@ -297,7 +334,7 @@ def main(input_args=None):
up_graphs.append(pyg_up)
down_graphs.append(pyg_down)
- if args.plot:
+ if create_plot:
plot_graph(
pyg_down, title=f"Down graph, {from_level} -> {to_level}"
)
@@ -363,7 +400,7 @@ def main(input_args=None):
m2m_graphs = [pyg_m2m]
mesh_pos = [pyg_m2m.pos.to(torch.float32)]
- if args.plot:
+ if create_plot:
plot_graph(pyg_m2m, title="Mesh-to-mesh")
plt.show()
@@ -395,7 +432,7 @@ def main(input_args=None):
)
# grid nodes
- Ny, Nx = xy.shape[1:]
+ Nx, Ny = xy.shape[:2]
G_grid = networkx.grid_2d_graph(Ny, Nx)
G_grid.clear_edges()
@@ -403,7 +440,9 @@ def main(input_args=None):
# vg features (only pos introduced here)
for node in G_grid.nodes:
# pos is in feature but here explicit for convenience
- G_grid.nodes[node]["pos"] = np.array([xy[0][node], xy[1][node]])
+ G_grid.nodes[node]["pos"] = xy[
+ node[1], node[0]
+ ] # xy is already (Nx,Ny,2)
# add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes
# (i,j) and impose sorting order such that vm are the first nodes
@@ -412,7 +451,9 @@ def main(input_args=None):
# build kd tree for grid point pos
# order in vg_list should be same as in vg_xy
vg_list = list(G_grid.nodes)
- vg_xy = np.array([[xy[0][node[1:]], xy[1][node[1:]]] for node in vg_list])
+ vg_xy = np.array(
+ [xy[node[2], node[1]] for node in vg_list]
+ ) # xy is already (Nx,Ny,2)
kdt_g = scipy.spatial.KDTree(vg_xy)
# now add (all) mesh nodes, include features (pos)
@@ -444,7 +485,7 @@ def main(input_args=None):
pyg_g2m = from_networkx(G_g2m)
- if args.plot:
+ if create_plot:
plot_graph(pyg_g2m, title="Grid-to-mesh")
plt.show()
@@ -483,7 +524,7 @@ def main(input_args=None):
)
pyg_m2g = from_networkx(G_m2g_int)
- if args.plot:
+ if create_plot:
plot_graph(pyg_m2g, title="Mesh-to-grid")
plt.show()
@@ -494,5 +535,76 @@ def main(input_args=None):
save_edges(pyg_m2g, "m2g", graph_dir_path)
+def create_graph_from_datastore(
+ datastore: BaseRegularGridDatastore,
+ output_root_path: str,
+ n_max_levels: int = None,
+ hierarchical: bool = False,
+ create_plot: bool = False,
+):
+ if isinstance(datastore, BaseRegularGridDatastore):
+ xy = datastore.get_xy(category="state", stacked=False)
+ else:
+ raise NotImplementedError(
+ "Only graph creation for BaseRegularGridDatastore is supported"
+ )
+
+ create_graph(
+ graph_dir_path=output_root_path,
+ xy=xy,
+ n_max_levels=n_max_levels,
+ hierarchical=hierarchical,
+ create_plot=create_plot,
+ )
+
+
+def cli(input_args=None):
+ parser = ArgumentParser(description="Graph generation arguments")
+ parser.add_argument(
+ "--config_path",
+ type=str,
+ help="Path to neural-lam configuration file",
+ )
+ parser.add_argument(
+ "--name",
+ type=str,
+ default="multiscale",
+ help="Name to save graph as (default: multiscale)",
+ )
+ parser.add_argument(
+ "--plot",
+ action="store_true",
+ help="If graphs should be plotted during generation "
+ "(default: False)",
+ )
+ parser.add_argument(
+ "--levels",
+ type=int,
+ help="Limit multi-scale mesh to given number of levels, "
+ "from bottom up (default: None (no limit))",
+ )
+ parser.add_argument(
+ "--hierarchical",
+ action="store_true",
+ help="Generate hierarchical mesh graph (default: False)",
+ )
+ args = parser.parse_args(input_args)
+
+ assert (
+ args.config_path is not None
+ ), "Specify your config with --config_path"
+
+ # Load neural-lam configuration and datastore to use
+ _, datastore = load_config_and_datastore(config_path=args.config_path)
+
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=os.path.join(datastore.root_path, "graph", args.name),
+ n_max_levels=args.levels,
+ hierarchical=args.hierarchical,
+ create_plot=args.plot,
+ )
+
+
if __name__ == "__main__":
- main()
+ cli()
diff --git a/neural_lam/create_grid_features.py b/neural_lam/create_grid_features.py
deleted file mode 100644
index adabd9dc..00000000
--- a/neural_lam/create_grid_features.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Standard library
-import os
-from argparse import ArgumentParser
-
-# Third-party
-import numpy as np
-import torch
-
-# Local
-from . import config
-
-
-def main():
- """
- Pre-compute all static features related to the grid nodes
- """
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
-
- static_dir_path = os.path.join("data", config_loader.dataset.name, "static")
-
- # -- Static grid node features --
- grid_xy = torch.tensor(
- np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
- ) # (2, N_y, N_x)
- grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
- pos_max = torch.max(torch.abs(grid_xy))
- grid_xy = grid_xy / pos_max # Divide by maximum coordinate
-
- geopotential = torch.tensor(
- np.load(os.path.join(static_dir_path, "surface_geopotential.npy"))
- ) # (N_y, N_x)
- geopotential = geopotential.flatten(0, 1).unsqueeze(1) # (N_grid,1)
- gp_min = torch.min(geopotential)
- gp_max = torch.max(geopotential)
- # Rescale geopotential to [0,1]
- geopotential = (geopotential - gp_min) / (gp_max - gp_min) # (N_grid, 1)
-
- grid_border_mask = torch.tensor(
- np.load(os.path.join(static_dir_path, "border_mask.npy")),
- dtype=torch.int64,
- ) # (N_y, N_x)
- grid_border_mask = (
- grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
- ) # (N_grid, 1)
-
- # Concatenate grid features
- grid_features = torch.cat(
- (grid_xy, geopotential, grid_border_mask), dim=1
- ) # (N_grid, 4)
-
- torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))
-
-
-if __name__ == "__main__":
- main()
diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml
deleted file mode 100644
index f1527849..00000000
--- a/neural_lam/data_config.yaml
+++ /dev/null
@@ -1,64 +0,0 @@
-dataset:
- name: meps_example
- var_names:
- - pres_0g
- - pres_0s
- - nlwrs_0
- - nswrs_0
- - r_2
- - r_65
- - t_2
- - t_65
- - t_500
- - t_850
- - u_65
- - u_850
- - v_65
- - v_850
- - wvint_0
- - z_1000
- - z_500
- var_units:
- - Pa
- - Pa
- - $\mathrm{W}/\mathrm{m}^2$
- - $\mathrm{W}/\mathrm{m}^2$
- - ""
- - ""
- - K
- - K
- - K
- - K
- - m/s
- - m/s
- - m/s
- - m/s
- - $\mathrm{kg}/\mathrm{m}^2$
- - $\mathrm{m}^2/\mathrm{s}^2$
- - $\mathrm{m}^2/\mathrm{s}^2$
- var_longnames:
- - pres_heightAboveGround_0_instant
- - pres_heightAboveSea_0_instant
- - nlwrs_heightAboveGround_0_accum
- - nswrs_heightAboveGround_0_accum
- - r_heightAboveGround_2_instant
- - r_hybrid_65_instant
- - t_heightAboveGround_2_instant
- - t_hybrid_65_instant
- - t_isobaricInhPa_500_instant
- - t_isobaricInhPa_850_instant
- - u_hybrid_65_instant
- - u_isobaricInhPa_850_instant
- - v_hybrid_65_instant
- - v_isobaricInhPa_850_instant
- - wvint_entireAtmosphere_0_instant
- - z_isobaricInhPa_1000_instant
- - z_isobaricInhPa_500_instant
- num_forcing_features: 16
-grid_shape_state: [268, 238]
-projection:
- class: LambertConformal
- kwargs:
- central_longitude: 15.0
- central_latitude: 63.3
- standard_parallels: [63.3, 63.3]
diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py
new file mode 100644
index 00000000..40e683ac
--- /dev/null
+++ b/neural_lam/datastore/__init__.py
@@ -0,0 +1,26 @@
+# Local
+from .base import BaseDatastore # noqa
+from .mdp import MDPDatastore # noqa
+from .npyfilesmeps import NpyFilesDatastoreMEPS # noqa
+
+DATASTORE_CLASSES = [
+ MDPDatastore,
+ NpyFilesDatastoreMEPS,
+]
+
+DATASTORES = {
+ datastore.SHORT_NAME: datastore for datastore in DATASTORE_CLASSES
+}
+
+
+def init_datastore(datastore_kind, config_path):
+ DatastoreClass = DATASTORES.get(datastore_kind)
+
+ if DatastoreClass is None:
+ raise NotImplementedError(
+ f"Datastore kind {datastore_kind} is not implemented"
+ )
+
+ datastore = DatastoreClass(config_path=config_path)
+
+ return datastore
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
new file mode 100644
index 00000000..0317c2e5
--- /dev/null
+++ b/neural_lam/datastore/base.py
@@ -0,0 +1,553 @@
+# Standard library
+import abc
+import collections
+import dataclasses
+import functools
+from functools import cached_property
+from pathlib import Path
+from typing import List, Union
+
+# Third-party
+import cartopy.crs as ccrs
+import numpy as np
+import xarray as xr
+from pandas.core.indexes.multi import MultiIndex
+
+
+class BaseDatastore(abc.ABC):
+ """
+ Base class for weather data used in the neural-lam package. A datastore
+ defines the interface for accessing weather data by providing methods to
+ access the data in a processed format that can be used for training and
+ evaluation of neural networks.
+
+ NOTE: All methods return either primitive types, `numpy.ndarray`,
+ `xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
+ objects. Conversion to `pytorch.Tensor` objects should be done in the
+ `weather_dataset.WeatherDataset` class (which inherits from
+ `torch.utils.data.Dataset` and uses the datastore to access the data).
+
+ # Forecast vs analysis data
+ If the datastore is used to represent forecast rather than analysis data,
+ then the `is_forecast` attribute should be set to True, and returned data
+ from `get_dataarray` is assumed to have `analysis_time` and `forecast_time`
+ dimensions (rather than just `time`).
+
+ # Ensemble vs deterministic data
+ If the datastore is used to represent ensemble data, then the `is_ensemble`
+ attribute should be set to True, and returned data from `get_dataarray` is
+ assumed to have an `ensemble_member` dimension.
+
+ # Grid index
+ All methods that return data specific to a grid point (like
+ `get_dataarray`) should have a single dimension named `grid_index` that
+ represents the spatial grid index of the data. The actual x, y coordinates
+ of the grid points should be stored in the `x` and `y` coordinates of the
+ dataarray or dataset with the `grid_index` dimension as the coordinate for
+ each of the `x` and `y` coordinates.
+ """
+
+ is_ensemble: bool = False
+ is_forecast: bool = False
+
+ @property
+ @abc.abstractmethod
+ def root_path(self) -> Path:
+ """
+ The root path to the datastore. It is relative to this that any derived
+ files (for example the graph components) are stored.
+
+ Returns
+ -------
+ pathlib.Path
+ The root path to the datastore.
+
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def config(self) -> collections.abc.Mapping:
+ """The configuration of the datastore.
+
+ Returns
+ -------
+ collections.abc.Mapping
+ The configuration of the datastore, any dict like object can be
+ returned.
+
+ """
+ pass
+
+ @property
+ @abc.abstractmethod
+ def step_length(self) -> int:
+ """The step length of the dataset in hours.
+
+ Returns:
+ int: The step length in hours.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_vars_units(self, category: str) -> List[str]:
+ """Get the units of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The units of the variables.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_vars_names(self, category: str) -> List[str]:
+ """Get the names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The names of the variables.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_vars_long_names(self, category: str) -> List[str]:
+ """Get the long names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The long names of the variables.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_num_data_vars(self, category: str) -> int:
+ """Get the number of data variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ int
+ The number of data variables.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the standardization (i.e. scaling to mean of 0.0 and standard
+ deviation of 1.0) dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also
+ contain a `state_diff_mean` and `state_diff_std` variable for the
+ one-step differences of the state variables. The returned dataarray should
+ at least have dimensions of `({category}_feature)`, but can also
+ include for example `grid_index` (if the standardization is done per
+ grid point for example).
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The standardization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_dataarray(
+ self, category: str, split: str
+ ) -> Union[xr.DataArray, None]:
+ """
+ Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). A
+ datastore must be able to return data for the "state" category, but
+ "forcing" and "static" are optional (in which case the method should
+ return `None`). For the "static" category the `split` is allowed to be
+ `None` because the static data is the same for all splits.
+
+ The returned dataarray is expected to at minimum have dimensions of
+ `(grid_index, {category}_feature)` so that any spatial dimensions have
+ been stacked into a single dimension and all variables and levels have
+ been stacked into a single feature dimension named by the `category` of
+ data being loaded.
+
+ For categories of data that have a time dimension (i.e. not static
+ data), the dataarray is expected to additionally have `(analysis_time,
+ elapsed_forecast_duration)` dimensions if `is_forecast` is True, or
+ `(time)` if `is_forecast` is False.
+
+ If the data is ensemble data, the dataarray is expected to have an
+ additional `ensemble_member` dimension.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ split : str
+ The time split to filter the dataset (train/val/test).
+
+ Returns
+ -------
+ xr.DataArray or None
+ The xarray DataArray object with processed dataset.
+
+ """
+ pass
+
+ @cached_property
+ @abc.abstractmethod
+ def boundary_mask(self) -> xr.DataArray:
+ """
+ Return the boundary mask for the dataset, with spatial dimensions
+ stacked. Where the value is 1, the grid point is a boundary point, and
+ where the value is 0, the grid point is not a boundary point.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions
+ `('grid_index',)`.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_xy(self, category: str) -> np.ndarray:
+ """
+ Return the x, y coordinates of the dataset as a numpy array for a
+ given category of data.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset with shape `[n_grid_points, 2]`.
+ """
+
+ @property
+ @abc.abstractmethod
+ def coords_projection(self) -> ccrs.Projection:
+ """Return the projection object for the coordinates.
+
+ The projection object is used to plot the coordinates on a map.
+
+ Returns
+ -------
+ cartopy.crs.Projection:
+ The projection object.
+
+ """
+ pass
+
+ @functools.lru_cache
+ def get_xy_extent(self, category: str) -> List[float]:
+ """
+ Return the extent of the x, y coordinates for a given category of data.
+ The extent should be returned as a list of 4 floats with `[xmin, xmax,
+ ymin, ymax]` which can then be used to set the extent of a plot.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[float]
+ The extent of the x, y coordinates.
+
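+ Example
+ -------
+ A sketch of the intended use when setting up a map plot (assuming a
+ concrete datastore instance `datastore` and a cartopy `GeoAxes` `ax`):
+
+ >>> extent = datastore.get_xy_extent("state")
+ >>> ax.set_extent(extent, crs=datastore.coords_projection)
+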
+ """
+ xy = self.get_xy(category, stacked=False)
+ extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+ return [float(v) for v in extent]
+
+ @property
+ @abc.abstractmethod
+ def num_grid_points(self) -> int:
+ """Return the number of grid points in the dataset.
+
+ Returns
+ -------
+ int
+ The number of grid points in the dataset.
+
+ """
+ pass
+
+ @cached_property
+ @abc.abstractmethod
+ def state_feature_weights_values(self) -> List[float]:
+ """
+ Return the weights for each state feature as a list of floats. The
+ weights are defined by the user in a config file for the datastore.
+
+ Implementations of this method must assert that there is one weight for
+ each state feature in the datastore. The weights can be used to scale
+ the loss function for each state variable (e.g. via the standard
+ deviation of the 1-step differences of the state variables).
+
+ Returns
+ -------
+ List[float]
+ The weights for each state feature.
+
+ """
+ pass
+
+ @functools.lru_cache
+ def expected_dim_order(self, category: str = None) -> tuple[str]:
+ """
+ Return the expected dimension order for the dataarray or dataset
+ returned by `get_dataarray` for the given category of data. The
+ dimension order is the order of the dimensions in the dataarray or
+ dataset, and is used to check that the data is in the expected format.
+
+ This is necessary so that when stacking and unstacking the spatial grid
+ we can ensure that the dimension order is the same as what is returned
+ from `get_dataarray`. It also ensures that downstream uses of a
+ datastore (e.g. WeatherDataset) see the data in a common structure.
+
+ If the category is None, then it is assumed that the data only
+ represents a 1D scalar field varying with grid-index.
+
+ The order is constructed to match the order in `torch.Tensor` objects
+ that will be constructed from the data so that the last two dimensions
+ are always the grid-index and feature dimensions (i.e. the order is
+ `[..., grid_index, {category}_feature]`), with any time-related and
+ ensemble-number dimension(s) coming before these two.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ tuple[str]
+ The expected dimension order for the dataarray or dataset.
+
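+ Example
+ -------
+ A sketch of typical return values (which dimensions appear depends on
+ the datastore's `is_forecast` and `is_ensemble` attributes):
+
+ >>> datastore.expected_dim_order(category="static")
+ ('grid_index', 'static_feature')
+ >>> datastore.expected_dim_order(category="state")
+ ('analysis_time', 'elapsed_forecast_duration', 'ensemble_member',
+ 'grid_index', 'state_feature')
+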
+ """
+ dim_order = []
+
+ if category is not None:
+ if category != "static":
+ # static data does not vary in time
+ if self.is_forecast:
+ dim_order.extend(
+ ["analysis_time", "elapsed_forecast_duration"]
+ )
+ elif not self.is_forecast:
+ dim_order.append("time")
+
+ if self.is_ensemble and category == "state":
+ # XXX: for now we only assume ensemble data for state variables
+ dim_order.append("ensemble_member")
+
+ dim_order.append("grid_index")
+
+ if category is not None:
+ dim_order.append(f"{category}_feature")
+
+ return tuple(dim_order)
+
+
+@dataclasses.dataclass
+class CartesianGridShape:
+ """Dataclass to store the shape of a grid."""
+
+ x: int
+ y: int
+
+
+class BaseRegularGridDatastore(BaseDatastore):
+ """
+ Base class for weather data stored on a regular grid (like a chess-board,
+ as opposed to an irregular grid where each cell cannot be indexed by
+ just two integers, see https://en.wikipedia.org/wiki/Regular_grid). In
+ addition to the methods and attributes required for weather data in
+ general (see `BaseDatastore`), for regular-gridded source data each
+ `grid_index` coordinate value is assumed to be associated with `x` and
+ `y`-values that allow the processed data-arrays to be reshaped back
+ into 2D xy-gridded arrays.
+
+ The following methods and attributes must be implemented for a
+ datastore that represents regular-gridded data:
+ - `grid_shape_state` (property): 2D shape of the grid for the state
+ variables.
+ - `get_xy` (method): Return the x, y coordinates of the dataset, with the
+ option to not stack the coordinates (so that they are returned as a 2D
+ grid).
+
+ The operation of going from (x,y)-indexed regular grid
+ to `grid_index`-indexed data-array is called "stacking" and the reverse
+ operation is called "unstacking". This class provides methods to stack and
+ unstack the spatial grid coordinates of the data-arrays (called
+ `stack_grid_coords` and `unstack_grid_coords` respectively).
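+
+ A minimal round-trip sketch (assuming a concrete regular-grid datastore
+ instance `datastore` and a stacked dataarray `da` obtained from it):
+
+ >>> da_2d = datastore.unstack_grid_coords(da) # separate x and y dims
+ >>> da_stacked = datastore.stack_grid_coords(da_2d) # back to grid_index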
+ """
+
+ CARTESIAN_COORDS = ["x", "y"]
+
+ @cached_property
+ @abc.abstractmethod
+ def grid_shape_state(self) -> CartesianGridShape:
+ """The shape of the grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape:
+ The shape of the grid for the state variables, which has `x` and
+ `y` attributes.
+
+ """
+ pass
+
+ @abc.abstractmethod
+ def get_xy(self, category: str, stacked: bool = True) -> np.ndarray:
+ """Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates. The parameter `stacked` has
+ been introduced in this class. Parent class `BaseDatastore` has the
+ same method signature but without the `stacked` parameter. Defaults
+ to `True` to match the behaviour of `BaseDatastore.get_xy()` which
+ always returns the coordinates stacked.
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(n_grid_points, 2)` where
+ n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(N_x, N_y, 2)`
+ """
+ pass
+
+ def unstack_grid_coords(
+ self, da_or_ds: Union[xr.DataArray, xr.Dataset]
+ ) -> Union[xr.DataArray, xr.Dataset]:
+ """
+ Unstack the spatial grid coordinates from `grid_index` into separate `x`
+ and `y` dimensions to create a 2D grid. Only performs unstacking if the
+ data is currently stacked (has grid_index dimension).
+
+ Parameters
+ ----------
+ da_or_ds : xr.DataArray or xr.Dataset
+ The dataarray or dataset to unstack the grid coordinates of.
+
+ Returns
+ -------
+ xr.DataArray or xr.Dataset
+ The dataarray or dataset with the grid coordinates unstacked.
+ """
+ # Return original data if already unstacked (no grid_index dimension)
+ if "grid_index" not in da_or_ds.dims:
+ return da_or_ds
+
+ # Check whether `grid_index` is a multi-index
+ if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex):
+ da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS)
+
+ da_or_ds_unstacked = da_or_ds.unstack("grid_index")
+
+ # Ensure that the x, y dimensions are in the correct order
+ dims = da_or_ds_unstacked.dims
+ xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS]
+
+ if xy_dim_order != self.CARTESIAN_COORDS:
+ da_or_ds_unstacked = da_or_ds_unstacked.transpose("x", "y")
+
+ return da_or_ds_unstacked
+
+ def stack_grid_coords(
+ self, da_or_ds: Union[xr.DataArray, xr.Dataset]
+ ) -> Union[xr.DataArray, xr.Dataset]:
+ """
+ Stack the spatial grid coordinates (x and y) into a single `grid_index`
+ dimension. Only performs stacking if the data is currently unstacked
+ (has x and y dimensions).
+
+ Parameters
+ ----------
+ da_or_ds : xr.DataArray or xr.Dataset
+ The dataarray or dataset to stack the grid coordinates of.
+
+ Returns
+ -------
+ xr.DataArray or xr.Dataset
+ The dataarray or dataset with the grid coordinates stacked.
+ """
+ # Return original data if already stacked (has grid_index dimension)
+ if "grid_index" in da_or_ds.dims:
+ return da_or_ds
+
+ da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
+
+ # infer what category of data the array represents by finding the
+ # dimension named in the format `{category}_feature`
+ category = None
+ for dim in da_or_ds_stacked.dims:
+ if dim.endswith("_feature"):
+ if category is not None:
+ raise ValueError(
+ "Multiple dimensions ending with '_feature' found in "
+ f"dataarray: {da_or_ds_stacked}. Cannot infer category."
+ )
+ category = dim.split("_")[0]
+
+ dim_order = self.expected_dim_order(category=category)
+
+ return da_or_ds_stacked.transpose(*dim_order)
+
+ @property
+ @functools.lru_cache
+ def num_grid_points(self) -> int:
+ """Return the number of grid points in the dataset.
+
+ Returns
+ -------
+ int
+ The number of grid points in the dataset.
+
+ """
+ return self.grid_shape_state.x * self.grid_shape_state.y
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
new file mode 100644
index 00000000..10593a82
--- /dev/null
+++ b/neural_lam/datastore/mdp.py
@@ -0,0 +1,464 @@
+# Standard library
+import warnings
+from functools import cached_property
+from pathlib import Path
+from typing import List
+
+# Third-party
+import cartopy.crs as ccrs
+import mllam_data_prep as mdp
+import xarray as xr
+from loguru import logger
+from numpy import ndarray
+
+# Local
+from .base import BaseRegularGridDatastore, CartesianGridShape
+
+
+class MDPDatastore(BaseRegularGridDatastore):
+ """
+ Datastore class for datasets made with the mllam_data_prep library
+ (https://github.com/mllam/mllam-data-prep). This class wraps the
+ `mllam_data_prep` library to do the necessary transforms to create the
+ different categories (state/forcing/static) of data, with the actual
+ transforms to apply being specified in the configuration file.
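+
+ A minimal construction sketch (the config filename is purely
+ illustrative):
+
+ >>> datastore = MDPDatastore(config_path="danra.datastore.yaml")
+ >>> da_state = datastore.get_dataarray(category="state", split="train")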
+ """
+
+ SHORT_NAME = "mdp"
+
+ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
+ """
+ Construct a new MDPDatastore from the configuration file at
+ `config_path`. A boundary mask is created with `n_boundary_points`
+ boundary points. If `reuse_existing` is True, the dataset is loaded
+ from a zarr file if it exists (unless the config has been modified
+ since the zarr was created), otherwise it is created from the
+ configuration file.
+
+ Parameters
+ ----------
+ config_path : str
+ The path to the configuration file, this will be fed to the
+ `mllam_data_prep.Config.from_yaml_file` method to then call
+ `mllam_data_prep.create_dataset` to create the dataset.
+ n_boundary_points : int
+ The number of boundary points to use in the boundary mask.
+ reuse_existing : bool
+ Whether to reuse an existing dataset zarr file if it exists and its
+ creation date is newer than the configuration file.
+
+ """
+ self._config_path = Path(config_path)
+ self._root_path = self._config_path.parent
+ self._config = mdp.Config.from_yaml_file(self._config_path)
+ fp_ds = self._root_path / self._config_path.name.replace(
+ ".yaml", ".zarr"
+ )
+
+ self._ds = None
+ if reuse_existing and fp_ds.exists():
+ # warn if the config file has been modified since the zarr archive
+ # was created, since the (possibly outdated) zarr will still be used
+ if fp_ds.stat().st_mtime < self._config_path.stat().st_mtime:
+ logger.warning(
+ "Config file has been modified since zarr was created. "
+ f"The old zarr archive (in {fp_ds}) will be used. "
+ "To generate a new zarr archive, move the old one first."
+ )
+ self._ds = xr.open_zarr(fp_ds, consolidated=True)
+
+ if self._ds is None:
+ self._ds = mdp.create_dataset(config=self._config)
+ self._ds.to_zarr(fp_ds)
+ self._n_boundary_points = n_boundary_points
+
+ print("The loaded datastore contains the following features:")
+ for category in ["state", "forcing", "static"]:
+ if len(self.get_vars_names(category)) > 0:
+ var_names = self.get_vars_names(category)
+ print(f" {category:<8s}: {' '.join(var_names)}")
+
+ # check that all three train/val/test splits are available
+ required_splits = ["train", "val", "test"]
+ available_splits = list(self._ds.splits.split_name.values)
+ if not all(split in available_splits for split in required_splits):
+ raise ValueError(
+ f"Missing required splits: {required_splits} in available "
+ f"splits: {available_splits}"
+ )
+
+ print("With the following splits (over time):")
+ for split in required_splits:
+ da_split = self._ds.splits.sel(split_name=split)
+ da_split_start = da_split.sel(split_part="start").load().item()
+ da_split_end = da_split.sel(split_part="end").load().item()
+ print(f" {split:<8s}: {da_split_start} to {da_split_end}")
+
+ # find out the dimension order for the stacking to grid-index
+ dim_order = None
+ for input_dataset in self._config.inputs.values():
+ dim_order_ = input_dataset.dim_mapping["grid_index"].dims
+ if dim_order is None:
+ dim_order = dim_order_
+ else:
+ assert (
+ dim_order == dim_order_
+ ), "all inputs must have the same dimension order"
+
+ self.CARTESIAN_COORDS = dim_order
+
+ @property
+ def root_path(self) -> Path:
+ """The root path of the dataset.
+
+ Returns
+ -------
+ Path
+ The root path of the dataset.
+
+ """
+ return self._root_path
+
+ @property
+ def config(self) -> mdp.Config:
+ """The configuration of the dataset.
+
+ Returns
+ -------
+ mdp.Config
+ The configuration of the dataset.
+
+ """
+ return self._config
+
+ @property
+ def step_length(self) -> int:
+ """The length of the time steps in hours.
+
+ Returns
+ -------
+ int
+ The length of the time steps in hours.
+
+ """
+ da_dt = self._ds["time"].diff("time")
+ return (da_dt.dt.seconds[0] // 3600).item()
+
+ def get_vars_units(self, category: str) -> List[str]:
+ """Return the units of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The units of the variables in the given category.
+
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
+ return self._ds[f"{category}_feature_units"].values.tolist()
+
+ def get_vars_names(self, category: str) -> List[str]:
+ """Return the names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The names of the variables in the given category.
+
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
+ return self._ds[f"{category}_feature"].values.tolist()
+
+ def get_vars_long_names(self, category: str) -> List[str]:
+ """
+ Return the long names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The long names of the variables in the given category.
+
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return []
+ return self._ds[f"{category}_feature_long_name"].values.tolist()
+
+ def get_num_data_vars(self, category: str) -> int:
+ """Return the number of variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ int
+ The number of variables in the given category.
+
+ """
+ return len(self.get_vars_names(category))
+
+ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
+ """
+ Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). "state" is
+ the only required category; for the other categories, the method will
+ return `None` if the category is not found in the datastore.
+
+ The returned dataarray will at minimum have dimensions of `(grid_index,
+ {category}_feature)` so that any spatial dimensions have been stacked
+ into a single dimension and all variables and levels have been stacked
+ into a single feature dimension named by the `category` of data being
+ loaded.
+
+ For categories of data that have a time dimension (i.e. not static
+ data), the dataarray will additionally have `(analysis_time,
+ elapsed_forecast_duration)` dimensions if `is_forecast` is True, or
+ `(time)` if `is_forecast` is False.
+
+ If the data is ensemble data, the dataarray will have an additional
+ `ensemble_member` dimension.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ split : str
+ The time split to filter the dataset (train/val/test).
+
+ Returns
+ -------
+ xr.DataArray or None
+ The xarray DataArray object with processed dataset.
+
+ """
+ if category not in self._ds and category == "forcing":
+ warnings.warn("no forcing data found in datastore")
+ return None
+
+ da_category = self._ds[category]
+
+ # set units on x y coordinates if missing
+ for coord in ["x", "y"]:
+ if "units" not in da_category[coord].attrs:
+ da_category[coord].attrs["units"] = "m"
+
+ # set multi-index for grid-index
+ da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS)
+
+ if "time" in da_category.dims:
+ t_start = (
+ self._ds.splits.sel(split_name=split)
+ .sel(split_part="start")
+ .load()
+ .item()
+ )
+ t_end = (
+ self._ds.splits.sel(split_name=split)
+ .sel(split_part="end")
+ .load()
+ .item()
+ )
+ da_category = da_category.sel(time=slice(t_start, t_end))
+
+ dim_order = self.expected_dim_order(category=category)
+ return da_category.transpose(*dim_order)
+
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the standardization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+ for the one-step differences of the state variables.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The standardization dataarray for the given category, with
+ variables for the mean and standard deviation of the variables (and
+ differences for state variables).
+
+ """
+ ops = ["mean", "std"]
+ split = "train"
+ stats_variables = {
+ f"{category}__{split}__{op}": f"{category}_{op}" for op in ops
+ }
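+ # e.g. for category="state" this maps "state__train__mean" to
+ # "state_mean" and "state__train__std" to "state_std"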
+ if category == "state":
+ stats_variables.update(
+ {f"state__{split}__diff_{op}": f"state_diff_{op}" for op in ops}
+ )
+
+ ds_stats = self._ds[stats_variables.keys()].rename(stats_variables)
+ return ds_stats
+
+ @cached_property
+ def boundary_mask(self) -> xr.DataArray:
+ """
+ Produce a 0/1 mask for the boundary points of the dataset. These will
+ sit at the edges of the domain (in x/y extent) and will be used to mask
+ out the boundary points from the loss function and to overwrite the
+ boundary points from the prediction. For now this is created when the
+ mask is requested, but in the future this could be saved to the zarr
+ file.
+
+ Returns
+ -------
+ xr.DataArray
+ A 0/1 mask for the boundary points of the dataset, where 1 is a
+ boundary point and 0 is not.
+
+ """
+ ds_unstacked = self.unstack_grid_coords(da_or_ds=self._ds)
+ da_state_variable = (
+ ds_unstacked["state"].isel(time=0).isel(state_feature=0)
+ )
+ da_domain_allzero = xr.zeros_like(da_state_variable)
+ ds_unstacked["boundary_mask"] = da_domain_allzero.isel(
+ x=slice(self._n_boundary_points, -self._n_boundary_points),
+ y=slice(self._n_boundary_points, -self._n_boundary_points),
+ )
+ ds_unstacked["boundary_mask"] = ds_unstacked.boundary_mask.fillna(
+ 1
+ ).astype(int)
+ return self.stack_grid_coords(da_or_ds=ds_unstacked.boundary_mask)
+
+ @property
+ def coords_projection(self) -> ccrs.Projection:
+ """
+ Return the projection of the coordinates.
+
+ NOTE: currently this expects the projection information to be in the
+ `extra` section of the configuration file, with a `projection` key
+ containing a `class_name` and `kwargs` for constructing the
+ `cartopy.crs.Projection` object. This is a temporary solution until
+ the projection information can be parsed in the produced dataset
+ itself. `mllam-data-prep` ignores the contents of the `extra` section
+ of the config file which is why we need to check that the necessary
+ parts are there.
+
+ Returns
+ -------
+ ccrs.Projection
+ The projection of the coordinates.
+
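+ Example
+ -------
+ A sketch of the expected structure of the `extra` section, shown as
+ the parsed Python dict (the projection class and parameter values are
+ purely illustrative; a `globe` entry inside `kwargs` is also
+ supported):
+
+ >>> config.extra["projection"]
+ {'class_name': 'LambertConformal',
+ 'kwargs': {'central_longitude': 25.0, 'central_latitude': 56.4}}
+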
+ """
+ if "projection" not in self._config.extra:
+ raise ValueError(
+ "projection information not found in the configuration file "
+ f"({self._config_path}). Please add the projection information"
+ "to the `extra` section of the config, by adding a "
+ "`projection` key with the class name and kwargs of the "
+ "projection."
+ )
+
+ projection_info = self._config.extra["projection"]
+ if "class_name" not in projection_info:
+ raise ValueError(
+ "class_name not found in the projection information. Please "
+ "add the class name of the projection to the `projection` key "
+ "in the `extra` section of the config."
+ )
+ if "kwargs" not in projection_info:
+ raise ValueError(
+ "kwargs not found in the projection information. Please add "
+ "the keyword arguments of the projection to the `projection` "
+ "key in the `extra` section of the config."
+ )
+
+ class_name = projection_info["class_name"]
+ ProjectionClass = getattr(ccrs, class_name)
+ kwargs = projection_info["kwargs"]
+
+ globe_kwargs = kwargs.pop("globe", {})
+ if len(globe_kwargs) > 0:
+ kwargs["globe"] = ccrs.Globe(**globe_kwargs)
+
+ return ProjectionClass(**kwargs)
+
+ @cached_property
+ def grid_shape_state(self):
+ """The shape of the cartesian grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape
+ The shape of the cartesian grid for the state variables.
+
+ """
+ ds_state = self.unstack_grid_coords(self._ds["state"])
+ da_x, da_y = ds_state.x, ds_state.y
+ assert da_x.ndim == da_y.ndim == 1
+ return CartesianGridShape(x=da_x.size, y=da_y.size)
+
+ def get_xy(self, category: str, stacked: bool) -> ndarray:
+ """Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(n_grid_points, 2)` where
+ n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(N_x, N_y, 2)`
+
+ """
+ # assume variables are stored in dimensions [grid_index, ...]
+ ds_category = self.unstack_grid_coords(da_or_ds=self._ds[category])
+
+ da_xs = ds_category.x
+ da_ys = ds_category.y
+
+ assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D"
+
+ da_x, da_y = xr.broadcast(da_xs, da_ys)
+ da_xy = xr.concat([da_x, da_y], dim="grid_coord")
+
+ if stacked:
+ da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose(
+ "grid_index",
+ "grid_coord",
+ )
+ else:
+ dims = [
+ "x",
+ "y",
+ "grid_coord",
+ ]
+ da_xy = da_xy.transpose(*dims)
+
+ return da_xy.values
diff --git a/neural_lam/datastore/npyfilesmeps/__init__.py b/neural_lam/datastore/npyfilesmeps/__init__.py
new file mode 100644
index 00000000..397a5075
--- /dev/null
+++ b/neural_lam/datastore/npyfilesmeps/__init__.py
@@ -0,0 +1,2 @@
+# Local
+from .store import NpyFilesDatastoreMEPS # noqa
diff --git a/neural_lam/create_parameter_weights.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
similarity index 71%
rename from neural_lam/create_parameter_weights.py
rename to neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
index 4867e609..f2c80e8a 100644
--- a/neural_lam/create_parameter_weights.py
+++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py
@@ -2,16 +2,17 @@
import os
import subprocess
from argparse import ArgumentParser
+from pathlib import Path
# Third-party
-import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
-# Local
-from . import WeatherDataset, config
+# First-party
+from neural_lam import WeatherDataset
+from neural_lam.datastore import init_datastore
class PaddedWeatherDataset(torch.utils.data.Dataset):
@@ -101,6 +102,10 @@ def save_stats(
mean = torch.mean(means, dim=0) # (d_features,)
second_moment = torch.mean(squares, dim=0) # (d_features,)
std = torch.sqrt(second_moment - mean**2) # (d_features,)
+ print(
+ f"Saving {filename_prefix} mean and std.-dev. to "
+ f"{filename_prefix}_mean.pt and {filename_prefix}_std.pt"
+ )
torch.save(
mean.cpu(), os.path.join(static_dir_path, f"{filename_prefix}_mean.pt")
)
@@ -119,100 +124,65 @@ def save_stats(
flux_mean = torch.mean(flux_means) # (,)
flux_second_moment = torch.mean(flux_squares) # (,)
flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,)
+ print("Saving flux mean and std.-dev. to flux_stats.pt")
torch.save(
torch.stack((flux_mean, flux_std)).cpu(),
os.path.join(static_dir_path, "flux_stats.pt"),
)
-def main():
+def main(
+ datastore_config_path, batch_size, step_length, n_workers, distributed
+):
"""
Pre-compute parameter weights to be used in loss function
+
+ Arguments
+ ---------
+ datastore_config_path : str
+ Path to datastore config file
+ batch_size : int
+ Batch size when iterating over the dataset
+ step_length : int
+ Step length in hours to consider single time step
+ n_workers : int
+ Number of workers in data loader
+ distributed : bool
+ Run the script in distributed mode
"""
- parser = ArgumentParser(description="Training arguments")
- parser.add_argument(
- "--data_config",
- type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
- )
- parser.add_argument(
- "--batch_size",
- type=int,
- default=32,
- help="Batch size when iterating over the dataset",
- )
- parser.add_argument(
- "--step_length",
- type=int,
- default=3,
- help="Step length in hours to consider single time step (default: 3)",
- )
- parser.add_argument(
- "--n_workers",
- type=int,
- default=4,
- help="Number of workers in data loader (default: 4)",
- )
- parser.add_argument(
- "--distributed",
- action="store_true",
- help="Run the script in distributed mode (default: False)",
- )
- args = parser.parse_args()
- distributed = bool(args.distributed)
rank = get_rank()
world_size = get_world_size()
- config_loader = config.Config.from_file(args.data_config)
+ datastore = init_datastore(
+ datastore_kind="npyfilesmeps", config_path=datastore_config_path
+ )
- if distributed:
+ static_dir_path = Path(datastore_config_path).parent / "static"
+ os.makedirs(static_dir_path, exist_ok=True)
+ if distributed:
setup(rank, world_size)
device = torch.device(
f"cuda:{rank}" if torch.cuda.is_available() else "cpu"
)
torch.cuda.set_device(device) if torch.cuda.is_available() else None
- if rank == 0:
- static_dir_path = os.path.join(
- "data", config_loader.dataset.name, "static"
- )
- # Create parameter weights based on height
- # based on fig A.1 in graph cast paper
- w_dict = {
- "2": 1.0,
- "0": 0.1,
- "65": 0.065,
- "1000": 0.1,
- "850": 0.05,
- "500": 0.03,
- }
- w_list = np.array(
- [
- w_dict[par.split("_")[-2]]
- for par in config_loader.dataset.var_longnames
- ]
- )
- print("Saving parameter weights...")
- np.save(
- os.path.join(static_dir_path, "parameter_weights.npy"),
- w_list.astype("float32"),
- )
-
- # Load dataset without any subsampling
+ # Setting this to the original value of the Oskarsson et al. paper (2023)
+ # 65 forecast steps - 2 initial steps = 63
+ ar_steps = 63
ds = WeatherDataset(
- config_loader.dataset.name,
+ datastore=datastore,
split="train",
- subsample_step=1,
- pred_length=63,
+ ar_steps=ar_steps,
standardize=False,
+ num_past_forcing_steps=0,
+ num_future_forcing_steps=0,
)
if distributed:
ds = PaddedWeatherDataset(
ds,
world_size,
- args.batch_size,
+ batch_size,
)
sampler = DistributedSampler(
ds, num_replicas=world_size, rank=rank, shuffle=False
@@ -221,9 +191,9 @@ def main():
sampler = None
loader = torch.utils.data.DataLoader(
ds,
- args.batch_size,
+ batch_size,
shuffle=False,
- num_workers=args.n_workers,
+ num_workers=n_workers,
sampler=sampler,
)
@@ -231,7 +201,7 @@ def main():
print("Computing mean and std.-dev. for parameters...")
means, squares, flux_means, flux_squares = [], [], [], []
- for init_batch, target_batch, forcing_batch in tqdm(loader):
+ for init_batch, target_batch, forcing_batch, _ in tqdm(loader):
if distributed:
init_batch, target_batch, forcing_batch = (
init_batch.to(device),
@@ -240,8 +210,8 @@ def main():
)
# (N_batch, N_t, N_grid, d_features)
batch = torch.cat((init_batch, target_batch), dim=1)
- # Flux at 1st windowed position is index 1 in forcing
- flux_batch = forcing_batch[:, :, :, 1]
+ # Flux at 1st windowed position is index 0 in forcing
+ flux_batch = forcing_batch[:, :, :, 0]
# (N_batch, d_features,)
means.append(torch.mean(batch, dim=(1, 2)).cpu())
squares.append(
@@ -254,29 +224,34 @@ def main():
means_gathered, squares_gathered = [None] * world_size, [
None
] * world_size
- flux_means_gathered, flux_squares_gathered = [None] * world_size, [
- None
- ] * world_size
+ flux_means_gathered, flux_squares_gathered = (
+ [None] * world_size,
+ [None] * world_size,
+ )
dist.all_gather_object(means_gathered, torch.cat(means, dim=0))
dist.all_gather_object(squares_gathered, torch.cat(squares, dim=0))
dist.all_gather_object(flux_means_gathered, flux_means)
dist.all_gather_object(flux_squares_gathered, flux_squares)
if rank == 0:
- means_gathered, squares_gathered = torch.cat(
- means_gathered, dim=0
- ), torch.cat(squares_gathered, dim=0)
- flux_means_gathered, flux_squares_gathered = torch.tensor(
- flux_means_gathered
- ), torch.tensor(flux_squares_gathered)
+ means_gathered, squares_gathered = (
+ torch.cat(means_gathered, dim=0),
+ torch.cat(squares_gathered, dim=0),
+ )
+ flux_means_gathered, flux_squares_gathered = (
+ torch.tensor(flux_means_gathered),
+ torch.tensor(flux_squares_gathered),
+ )
original_indices = ds.get_original_indices()
- means, squares = [means_gathered[i] for i in original_indices], [
- squares_gathered[i] for i in original_indices
- ]
- flux_means, flux_squares = [
- flux_means_gathered[i] for i in original_indices
- ], [flux_squares_gathered[i] for i in original_indices]
+ means, squares = (
+ [means_gathered[i] for i in original_indices],
+ [squares_gathered[i] for i in original_indices],
+ )
+ flux_means, flux_squares = (
+ [flux_means_gathered[i] for i in original_indices],
+ [flux_squares_gathered[i] for i in original_indices],
+ )
else:
means = [torch.cat(means, dim=0)] # (N_batch, d_features,)
squares = [torch.cat(squares, dim=0)] # (N_batch, d_features,)
@@ -299,17 +274,18 @@ def main():
if rank == 0:
print("Computing mean and std.-dev. for one-step differences...")
ds_standard = WeatherDataset(
- config_loader.dataset.name,
+ datastore=datastore,
split="train",
- subsample_step=1,
- pred_length=63,
+ ar_steps=ar_steps,
standardize=True,
+ num_past_forcing_steps=0,
+ num_future_forcing_steps=0,
) # Re-load with standardization
if distributed:
ds_standard = PaddedWeatherDataset(
ds_standard,
world_size,
- args.batch_size,
+ batch_size,
)
sampler_standard = DistributedSampler(
ds_standard, num_replicas=world_size, rank=rank, shuffle=False
@@ -318,16 +294,18 @@ def main():
sampler_standard = None
loader_standard = torch.utils.data.DataLoader(
ds_standard,
- args.batch_size,
+ batch_size,
shuffle=False,
- num_workers=args.n_workers,
+ num_workers=n_workers,
sampler=sampler_standard,
)
- used_subsample_len = (65 // args.step_length) * args.step_length
+ used_subsample_len = (65 // step_length) * step_length
diff_means, diff_squares = [], []
- for init_batch, target_batch, _ in tqdm(loader_standard, disable=rank != 0):
+ for init_batch, target_batch, _, _ in tqdm(
+ loader_standard, disable=rank != 0
+ ):
if distributed:
init_batch, target_batch = init_batch.to(device), target_batch.to(
device
@@ -337,13 +315,13 @@ def main():
# Note: batch contains only 1h-steps
stepped_batch = torch.cat(
[
- batch[:, ss_i : used_subsample_len : args.step_length]
- for ss_i in range(args.step_length)
+ batch[:, ss_i:used_subsample_len:step_length]
+ for ss_i in range(step_length)
],
dim=0,
)
# (N_batch', N_t, N_grid, d_features),
- # N_batch' = args.step_length*N_batch
+ # N_batch' = step_length*N_batch
batch_diffs = stepped_batch[:, 1:] - stepped_batch[:, :-1]
# (N_batch', N_t-1, N_grid, d_features)
diff_means.append(torch.mean(batch_diffs, dim=(1, 2)).cpu())
@@ -353,9 +331,10 @@ def main():
if distributed and world_size > 1:
dist.barrier()
- diff_means_gathered, diff_squares_gathered = [None] * world_size, [
- None
- ] * world_size
+ diff_means_gathered, diff_squares_gathered = (
+ [None] * world_size,
+ [None] * world_size,
+ )
dist.all_gather_object(
diff_means_gathered, torch.cat(diff_means, dim=0)
)
@@ -364,19 +343,21 @@ def main():
)
if rank == 0:
- diff_means_gathered, diff_squares_gathered = torch.cat(
- diff_means_gathered, dim=0
- ).view(-1, *diff_means[0].shape), torch.cat(
- diff_squares_gathered, dim=0
- ).view(
- -1, *diff_squares[0].shape
+ diff_means_gathered, diff_squares_gathered = (
+ torch.cat(diff_means_gathered, dim=0).view(
+ -1, *diff_means[0].shape
+ ),
+ torch.cat(diff_squares_gathered, dim=0).view(
+ -1, *diff_squares[0].shape
+ ),
)
original_indices = ds_standard.get_original_window_indices(
- args.step_length
+ step_length
+ )
+ diff_means, diff_squares = (
+ [diff_means_gathered[i] for i in original_indices],
+ [diff_squares_gathered[i] for i in original_indices],
)
- diff_means, diff_squares = [
- diff_means_gathered[i] for i in original_indices
- ], [diff_squares_gathered[i] for i in original_indices]
diff_means = [torch.cat(diff_means, dim=0)] # (N_batch', d_features,)
diff_squares = [torch.cat(diff_squares, dim=0)] # (N_batch', d_features,)
@@ -388,5 +369,47 @@ def main():
dist.destroy_process_group()
+def cli():
+ parser = ArgumentParser(description="Training arguments")
+ parser.add_argument(
+ "--datastore_config_path",
+ type=str,
+ help="Path to data config file",
+ )
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ default=32,
+ help="Batch size when iterating over the dataset",
+ )
+ parser.add_argument(
+ "--step_length",
+ type=int,
+ default=3,
+ help="Step length in hours to consider single time step (default: 3)",
+ )
+ parser.add_argument(
+ "--n_workers",
+ type=int,
+ default=4,
+ help="Number of workers in data loader (default: 4)",
+ )
+ parser.add_argument(
+ "--distributed",
+ action="store_true",
+ help="Run the script in distributed mode (default: False)",
+ )
+ args = parser.parse_args()
+ distributed = bool(args.distributed)
+
+ main(
+ datastore_config_path=args.datastore_config_path,
+ batch_size=args.batch_size,
+ step_length=args.step_length,
+ n_workers=args.n_workers,
+ distributed=distributed,
+ )
+
+
if __name__ == "__main__":
- main()
+ cli()
diff --git a/neural_lam/datastore/npyfilesmeps/config.py b/neural_lam/datastore/npyfilesmeps/config.py
new file mode 100644
index 00000000..1a9d7295
--- /dev/null
+++ b/neural_lam/datastore/npyfilesmeps/config.py
@@ -0,0 +1,66 @@
+# Standard library
+from dataclasses import dataclass, field
+from typing import Any, Dict, List
+
+# Third-party
+import dataclass_wizard
+
+
+@dataclass
+class Projection:
+ """Represents the projection information for a dataset, including the type
+ of projection and its parameters. Capable of creating a cartopy.crs
+ projection object.
+
+ Attributes:
+ class_name: The class name of the projection, this should be a valid
+ cartopy.crs class.
+ kwargs: A dictionary of keyword arguments specific to the projection
+ type.
+
+ """
+
+ class_name: str
+ kwargs: Dict[str, Any]
+
+
+@dataclass
+class Dataset:
+ """Contains information about the dataset, including variable names, units,
+ and descriptions.
+
+ Attributes:
+ name: The name of the dataset.
+ var_names: A list of variable names in the dataset.
+ var_units: A list of units for each variable.
+ var_longnames: A list of long, descriptive names for each variable.
+ num_forcing_features: The number of forcing features in the dataset.
+
+ """
+
+ name: str
+ var_names: List[str]
+ var_units: List[str]
+ var_longnames: List[str]
+ num_forcing_features: int
+ num_timesteps: int
+ step_length: int
+ num_ensemble_members: int
+ remove_state_features_with_index: List[int] = field(default_factory=list)
+
+
+@dataclass
+class NpyDatastoreConfig(dataclass_wizard.YAMLWizard):
+ """Configuration for loading and processing a dataset, including dataset
+ details, grid shape, and projection information.
+
+ Attributes:
+ dataset: An instance of Dataset containing details about the dataset.
+ grid_shape_state: A list representing the shape of the grid state.
+ projection: An instance of Projection containing projection details.
+
+ """
+
+ dataset: Dataset
+ grid_shape_state: List[int]
+ projection: Projection
diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py
new file mode 100644
index 00000000..42e80706
--- /dev/null
+++ b/neural_lam/datastore/npyfilesmeps/store.py
@@ -0,0 +1,788 @@
+"""
+Numpy-files based datastore to support the MEPS example dataset introduced in
+neural-lam v0.1.0.
+"""
+
+# Standard library
+import functools
+import re
+import warnings
+from functools import cached_property
+from pathlib import Path
+from typing import List
+
+# Third-party
+import cartopy.crs as ccrs
+import dask
+import dask.array
+import dask.delayed
+import numpy as np
+import parse
+import torch
+import xarray as xr
+from xarray.core.dataarray import DataArray
+
+# Local
+from ..base import BaseRegularGridDatastore, CartesianGridShape
+from .config import NpyDatastoreConfig
+
+STATE_FILENAME_FORMAT = "nwp_{analysis_time:%Y%m%d%H}_mbr{member_id:03d}.npy"
+TOA_SW_DOWN_FLUX_FILENAME_FORMAT = (
+ "nwp_toa_downwelling_shortwave_flux_{analysis_time:%Y%m%d%H}.npy"
+)
+OPEN_WATER_FILENAME_FORMAT = "wtr_{analysis_time:%Y%m%d%H}.npy"
+
+
+def _load_np(fp, add_feature_dim, feature_dim_mask=None):
+ arr = np.load(fp)
+ if add_feature_dim:
+ arr = arr[..., np.newaxis]
+ if feature_dim_mask is not None:
+ arr = arr[..., feature_dim_mask]
+ return arr
+
+
+class NpyFilesDatastoreMEPS(BaseRegularGridDatastore):
+ __doc__ = f"""
+ Represents a dataset stored as numpy files on disk. The dataset is assumed
+ to be stored in a directory structure where each sample is stored in a
+ separate file. The file-name format is assumed to be
+ '{STATE_FILENAME_FORMAT}'
+
+ The MEPS dataset is organised into three splits: train, val, and test. Each
+ split has a set of files which are:
+
+ - `{STATE_FILENAME_FORMAT}`:
+ The state variables for a forecast started at `analysis_time` with
+ member id `member_id`. The dimensions of the array are
+ `[forecast_timestep, y, x, feature]`.
+
+ - `{TOA_SW_DOWN_FLUX_FILENAME_FORMAT}`:
+ The top-of-atmosphere downwelling shortwave flux at `analysis_time`. The
+ dimensions of the array are `[forecast_timestep, y, x]`.
+
+ - `{OPEN_WATER_FILENAME_FORMAT}`:
+ The open water fraction at `analysis_time`. The dimensions of the array are
+ `[y, x]`.
+
+
+ Folder structure:
+
+ meps_example_reduced
+ ├── data_config.yaml
+ ├── samples
+ │ ├── test
+ │ │ ├── nwp_2022090100_mbr000.npy
+ │ │ ├── nwp_2022090100_mbr001.npy
+ │ │ ├── nwp_2022090112_mbr000.npy
+ │ │ ├── nwp_2022090112_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090100.npy
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022090112.npy
+ │ │ ├── ...
+ │ │ ├── wtr_2022090100.npy
+ │ │ ├── wtr_2022090112.npy
+ │ │ └── ...
+ │ ├── train
+ │ │ ├── nwp_2022040100_mbr000.npy
+ │ │ ├── nwp_2022040100_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_2022040112_mbr000.npy
+ │ │ ├── nwp_2022040112_mbr001.npy
+ │ │ ├── ...
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040100.npy
+ │ │ ├── nwp_toa_downwelling_shortwave_flux_2022040112.npy
+ │ │ ├── ...
+ │ │ ├── wtr_2022040100.npy
+ │ │ ├── wtr_2022040112.npy
+ │ │ └── ...
+ │ └── val
+ │ ├── nwp_2022060500_mbr000.npy
+ │ ├── nwp_2022060500_mbr001.npy
+ │ ├── ...
+ │ ├── nwp_2022060512_mbr000.npy
+ │ ├── nwp_2022060512_mbr001.npy
+ │ ├── ...
+ │ ├── nwp_toa_downwelling_shortwave_flux_2022060500.npy
+ │ ├── nwp_toa_downwelling_shortwave_flux_2022060512.npy
+ │ ├── ...
+ │ ├── wtr_2022060500.npy
+ │ ├── wtr_2022060512.npy
+ │ └── ...
+ └── static
+ ├── border_mask.npy
+ ├── diff_mean.pt
+ ├── diff_std.pt
+ ├── flux_stats.pt
+ ├── grid_features.pt
+ ├── nwp_xy.npy
+ ├── parameter_mean.pt
+ ├── parameter_std.pt
+ ├── parameter_weights.npy
+ └── surface_geopotential.npy
+
+ For the MEPS dataset:
+ N_t' = 65
+ N_t = 65//subsample_step (= 21 for 3h steps)
+ dim_y = 268
+ dim_x = 238
+ N_grid = 268x238 = 63784
+ d_features = 17 (d_features' = 18)
+ d_forcing = 5
+
+ For the MEPS reduced dataset:
+ N_t' = 65
+ N_t = 65//subsample_step (= 21 for 3h steps)
+ dim_y = 134
+ dim_x = 119
+ N_grid = 134x119 = 15946
+ d_features = 8
+ d_forcing = 1
+ """
+ SHORT_NAME = "npyfilesmeps"
+
+ is_ensemble = True
+ is_forecast = True
+
+ def __init__(
+ self,
+ config_path,
+ ):
+ """
+ Create a new NpyFilesDatastore using the configuration file at the
+ given path. The config file should be a YAML file and will be loaded
+ into an instance of the `NpyDatastoreConfig` dataclass.
+
+ Internally, the datastore uses dask.delayed to load the data from the
+ numpy files, so that the data isn't actually loaded until it's needed.
+
+ Parameters
+ ----------
+ config_path : str
+ The path to the configuration file for the datastore.
+
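+ Example
+ -------
+ A minimal construction sketch (the path is purely illustrative, but
+ follows the folder structure documented on the class):
+
+ >>> datastore = NpyFilesDatastoreMEPS(
+ ... config_path="meps_example_reduced/data_config.yaml"
+ ... )
+ >>> da_state = datastore.get_dataarray(category="state", split="train")
+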
+ """
+ self._config_path = Path(config_path)
+ self._root_path = self._config_path.parent
+ self._config = NpyDatastoreConfig.from_yaml_file(self._config_path)
+
+ self._num_ensemble_members = self.config.dataset.num_ensemble_members
+ self._num_timesteps = self.config.dataset.num_timesteps
+ self._step_length = self.config.dataset.step_length
+ self._remove_state_features_with_index = (
+ self.config.dataset.remove_state_features_with_index
+ )
+
+ @property
+ def root_path(self) -> Path:
+ """
+ The root path of the datastore on disk. This is the directory relative
+ to which graphs and other files can be stored.
+
+ Returns
+ -------
+ Path
+ The root path of the datastore
+
+ """
+ return self._root_path
+
+ @property
+ def config(self) -> NpyDatastoreConfig:
+ """The configuration for the datastore.
+
+ Returns
+ -------
+ NpyDatastoreConfig
+ The configuration for the datastore.
+
+ """
+ return self._config
+
+ def get_dataarray(self, category: str, split: str) -> DataArray:
+ """
+ Get the data array for the given category and split of data. If the
+ category is 'state', the data array will be a concatenation of the data
+ arrays for all ensemble members. The data will be loaded as a dask
+ array, so that the data isn't actually loaded until it's needed.
+
+ Parameters
+ ----------
+ category : str
+ The category of the data to load. One of 'state', 'forcing', or
+ 'static'.
+ split : str
+ The dataset split to load the data for. One of 'train', 'val', or
+ 'test'.
+
+ Returns
+ -------
+ xr.DataArray
+ The data array for the given category and split, with dimensions
+ per category (matching `expected_dim_order`):
+ state: `[analysis_time, elapsed_forecast_duration, ensemble_member,
+ grid_index, state_feature]`
+ forcing: `[analysis_time, elapsed_forecast_duration, grid_index,
+ forcing_feature]`
+ static: `[grid_index, static_feature]`
+
+ """
+ if category == "state":
+ das = []
+ # for the state category, we need to load all ensemble members
+ for member in range(self._num_ensemble_members):
+ da_member = self._get_single_timeseries_dataarray(
+ features=self.get_vars_names(category="state"),
+ split=split,
+ member=member,
+ )
+ das.append(da_member)
+ da = xr.concat(das, dim="ensemble_member")
+
+ elif category == "forcing":
+ # the forcing features are in separate files, so we need to load
+ # them separately
+ features = ["toa_downwelling_shortwave_flux", "open_water_fraction"]
+ das = [
+ self._get_single_timeseries_dataarray(
+ features=[feature], split=split
+ )
+ for feature in features
+ ]
+ da = xr.concat(das, dim="feature")
+
+ # add datetime forcing as a feature
+ # to do this we create a forecast time variable which has the
+ # dimensions of (analysis_time, elapsed_forecast_duration) with
+ # values that are the actual forecast time of each time step. By
+ # calling .chunk({"elapsed_forecast_duration": 1}) this time
+ # variable is turned into a dask array and so execution of the
+ # calculation is delayed until the feature values are actually
+ # used.
+ da_forecast_time = (
+ da.analysis_time + da.elapsed_forecast_duration
+ ).chunk({"elapsed_forecast_duration": 1})
+ da_datetime_forcing_features = self._calc_datetime_forcing_features(
+ da_time=da_forecast_time
+ )
+ da = xr.concat([da, da_datetime_forcing_features], dim="feature")
+
+ elif category == "static":
+ # the static features are collected in three files:
+ # - surface_geopotential
+ # - border_mask
+ # - x, y
+ das = []
+ for features in [
+ ["surface_geopotential"],
+ ["border_mask"],
+ ["x", "y"],
+ ]:
+ da = self._get_single_timeseries_dataarray(
+ features=features, split=split
+ )
+ das.append(da)
+ da = xr.concat(das, dim="feature")
+
+ else:
+ raise NotImplementedError(category)
+
+ da = da.rename(dict(feature=f"{category}_feature"))
+
+ # stack the [x, y] dimensions into a `grid_index` dimension
+ da = self.stack_grid_coords(da)
+
+ # check that we have the right features
+ actual_features = da[f"{category}_feature"].values.tolist()
+ expected_features = self.get_vars_names(category=category)
+ if actual_features != expected_features:
+ raise ValueError(
+ f"Expected features {expected_features}, got {actual_features}"
+ )
+
+ dim_order = self.expected_dim_order(category=category)
+ da = da.transpose(*dim_order)
+
+ return da
+
+ def _get_single_timeseries_dataarray(
+ self, features: List[str], split: str, member: int = None
+ ) -> DataArray:
+ """
+ Get the data array spanning the complete time series for a given set of
+ features and split of data. For state features the `member` argument
+ should be specified to select the ensemble member to load. The data
+ will be loaded using dask.delayed, so that the data isn't actually
+ loaded until it's needed.
+
+ Parameters
+ ----------
+ features : List[str]
+ The list of features to load the data for. For the 'state'
+ category, this should be the result of
+ `self.get_vars_names(category="state")`, for the 'forcing' category
+ this should be the list of forcing features to load, and for the
+ 'static' category this should be the list of static features to
+ load.
+ split : str
+ The dataset split to load the data for. One of 'train', 'val', or
+ 'test'.
+ member : int, optional
+ The ensemble member to load. Only applicable for the 'state'
+ category.
+
+ Returns
+ -------
+ xr.DataArray
+ The data array for the given set of features and split. Time-varying
+ features have a leading `analysis_time` dimension (plus an
+ `elapsed_forecast_duration` dimension where applicable), followed by
+ the per-file dimensions (e.g. `y`, `x`, `feature`); the spatial
+ dimensions are only stacked into `grid_index` later, in
+ `get_dataarray`.
+
+ """
+ if (
+ set(features).difference(self.get_vars_names(category="static"))
+ == set()
+ ):
+ assert split in (
+ "train",
+ "val",
+ "test",
+ None,
+ ), "Unknown dataset split"
+ else:
+ assert split in (
+ "train",
+ "val",
+ "test",
+ ), f"Unknown dataset split {split} for features {features}"
+
+ if member is not None and features != self.get_vars_names(
+ category="state"
+ ):
+ raise ValueError(
+ "Member can only be specified for the 'state' category"
+ )
+
+ concat_axis = 0
+
+ file_params = {}
+ add_feature_dim = False
+ features_vary_with_analysis_time = True
+ feature_dim_mask = None
+ if features == self.get_vars_names(category="state"):
+ filename_format = STATE_FILENAME_FORMAT
+ file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
+ # only select one member for now
+ file_params["member_id"] = member
+ fp_samples = self.root_path / "samples" / split
+ if self._remove_state_features_with_index:
+ n_to_drop = len(self._remove_state_features_with_index)
+ feature_dim_mask = np.ones(
+ len(features) + n_to_drop, dtype=bool
+ )
+ feature_dim_mask[self._remove_state_features_with_index] = False
+ elif features == ["toa_downwelling_shortwave_flux"]:
+ filename_format = TOA_SW_DOWN_FLUX_FILENAME_FORMAT
+ file_dims = ["elapsed_forecast_duration", "y", "x", "feature"]
+ add_feature_dim = True
+ fp_samples = self.root_path / "samples" / split
+ elif features == ["open_water_fraction"]:
+ filename_format = OPEN_WATER_FILENAME_FORMAT
+ file_dims = ["y", "x", "feature"]
+ add_feature_dim = True
+ fp_samples = self.root_path / "samples" / split
+ elif features == ["surface_geopotential"]:
+ filename_format = "surface_geopotential.npy"
+ file_dims = ["y", "x", "feature"]
+ add_feature_dim = True
+ features_vary_with_analysis_time = False
+ # XXX: surface_geopotential is the same for all splits, and so
+ # saved in static/
+ fp_samples = self.root_path / "static"
+ elif features == ["border_mask"]:
+ filename_format = "border_mask.npy"
+ file_dims = ["y", "x", "feature"]
+ add_feature_dim = True
+ features_vary_with_analysis_time = False
+ # XXX: border_mask is the same for all splits, and so saved in
+ # static/
+ fp_samples = self.root_path / "static"
+ elif features == ["x", "y"]:
+ filename_format = "nwp_xy.npy"
+ # NB: for x, y the feature dimension is the first one
+ file_dims = ["feature", "y", "x"]
+ features_vary_with_analysis_time = False
+ # XXX: x, y are the same for all splits, and so saved in static/
+ fp_samples = self.root_path / "static"
+ else:
+ raise NotImplementedError(
+ f"Reading of variables set `{features}` not supported"
+ )
+
+ if features_vary_with_analysis_time:
+ dims = ["analysis_time"] + file_dims
+ else:
+ dims = file_dims
+
+ coords = {}
+ arr_shape = []
+
+ xy = self.get_xy(category="state", stacked=False)
+ xs = xy[:, :, 0]
+ ys = xy[:, :, 1]
+ # Check if x-coordinates are constant along columns
+ assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
+ # Check if y-coordinates are constant along rows
+ assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
+ # Extract unique x and y coordinates
+ x = xs[:, 0] # Unique x-coordinates (changes along the first axis)
+ y = ys[0, :] # Unique y-coordinates (changes along the second axis)
+ for d in dims:
+ if d == "elapsed_forecast_duration":
+ coord_values = (
+ self.step_length
+ * np.arange(self._num_timesteps)
+ * np.timedelta64(1, "h")
+ )
+ elif d == "analysis_time":
+ coord_values = self._get_analysis_times(split=split)
+ elif d == "y":
+ coord_values = y
+ elif d == "x":
+ coord_values = x
+ elif d == "feature":
+ coord_values = features
+ else:
+ raise NotImplementedError(f"Dimension {d} not supported")
+
+ coords[d] = coord_values
+ if d != "analysis_time":
+ # analysis_time varies across the different files, but not
+ # within a single file
+ arr_shape.append(len(coord_values))
+
+ if features_vary_with_analysis_time:
+ filepaths = [
+ fp_samples
+ / filename_format.format(
+ analysis_time=analysis_time, **file_params
+ )
+ for analysis_time in coords["analysis_time"]
+ ]
+ else:
+ filepaths = [fp_samples / filename_format.format(**file_params)]
+
+ # use dask.delayed to load the numpy files, so that loading isn't
+ # done until the data is actually needed
+ arrays = [
+ dask.array.from_delayed(
+ dask.delayed(_load_np)(
+ fp=fp,
+ add_feature_dim=add_feature_dim,
+ feature_dim_mask=feature_dim_mask,
+ ),
+ shape=arr_shape,
+ dtype=np.float32,
+ )
+ for fp in filepaths
+ ]
+
+ # read a single timestep and check the shape
+ arr0 = arrays[0].compute()
+ if not list(arr0.shape) == arr_shape:
+ raise Exception(
+ f"Expected shape {arr_shape} for a single file, got "
+ f"{list(arr0.shape)}. Maybe the number of features given "
+ f"in the datastore config ({features}) is incorrect?"
+ )
+
+ if features_vary_with_analysis_time:
+ arr_all = dask.array.stack(arrays, axis=concat_axis)
+ else:
+ arr_all = arrays[0]
+
+ da = xr.DataArray(arr_all, dims=dims, coords=coords)
+
+ return da
+
+ def _get_analysis_times(self, split) -> List[np.datetime64]:
+ """Get the analysis times for the given split by parsing the filenames
+ of all the files found for the given split.
+
+ Parameters
+ ----------
+ split : str
+ The dataset split to get the analysis times for.
+
+ Returns
+ -------
+ List[dt.datetime]
+ The analysis times for the given split.
+
+ """
+ pattern = re.sub(r"{analysis_time:[^}]*}", "*", STATE_FILENAME_FORMAT)
+ pattern = re.sub(r"{member_id:[^}]*}", "*", pattern)
+
+ sample_dir = self.root_path / "samples" / split
+ sample_files = sample_dir.glob(pattern)
+ times = []
+ for fp in sample_files:
+ name_parts = parse.parse(STATE_FILENAME_FORMAT, fp.name)
+ times.append(name_parts["analysis_time"])
+
+ if len(times) == 0:
+ raise ValueError(
+ f"No files found in {sample_dir} with pattern {pattern}"
+ )
+
+ return times
+
+ def _calc_datetime_forcing_features(self, da_time: xr.DataArray):
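+ # Encode time-of-day and day-of-year as angles so that cyclic
+ # boundaries (e.g. hour 23 -> 0) map to nearby values: hour/12*pi
+ # spans [0, 2*pi) over 24 hours and dayofyear/365*2*pi spans a full
+ # year. The sin/cos pairs are rescaled from [-1, 1] to [0, 1] below.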
+ da_hour_angle = da_time.dt.hour / 12 * np.pi
+ da_year_angle = da_time.dt.dayofyear / 365 * 2 * np.pi
+
+ da_datetime_forcing = xr.concat(
+ (
+ np.sin(da_hour_angle),
+ np.cos(da_hour_angle),
+ np.sin(da_year_angle),
+ np.cos(da_year_angle),
+ ),
+ dim="feature",
+ )
+ da_datetime_forcing = (da_datetime_forcing + 1) / 2 # Rescale to [0,1]
+ da_datetime_forcing["feature"] = [
+ "sin_hour",
+ "cos_hour",
+ "sin_year",
+ "cos_year",
+ ]
+
+ return da_datetime_forcing
+
+ def get_vars_units(self, category: str) -> List[str]:
+ if category == "state":
+ return self.config.dataset.var_units
+ elif category == "forcing":
+ return [
+ "W/m^2",
+ "1",
+ "1",
+ "1",
+ "1",
+ "1",
+ ]
+ elif category == "static":
+ return ["m^2/s^2", "1", "m", "m"]
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ def get_vars_names(self, category: str) -> List[str]:
+ if category == "state":
+ return self.config.dataset.var_names
+ elif category == "forcing":
+ # XXX: this really shouldn't be hard-coded here, this should be in
+ # the config
+ return [
+ "toa_downwelling_shortwave_flux",
+ "open_water_fraction",
+ "sin_hour",
+ "cos_hour",
+ "sin_year",
+ "cos_year",
+ ]
+ elif category == "static":
+ return ["surface_geopotential", "border_mask", "x", "y"]
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ def get_vars_long_names(self, category: str) -> List[str]:
+ if category == "state":
+ return self.config.dataset.var_longnames
+ else:
+ # TODO: should we add these?
+ return self.get_vars_names(category=category)
+
+ def get_num_data_vars(self, category: str) -> int:
+ return len(self.get_vars_names(category=category))
+
+ def get_xy(self, category: str, stacked: bool) -> np.ndarray:
+ """Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset (with x first then y second),
+ returned differently based on the value of `stacked`:
+ - `stacked==True`: shape `(n_grid_points, 2)` where
+ n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(N_x, N_y, 2)`
+
+ """
+
+ # the array on disk has shape [2, N_y, N_x], where dimension 0
+ # contains the [x,y] coordinate pairs for each grid point
+ arr = np.load(self.root_path / "static" / "nwp_xy.npy")
+ arr_shape = arr.shape
+
+ assert arr_shape[0] == 2, "Expected leading dimension of size 2 (x, y)"
+ grid_shape = self.grid_shape_state
+ assert arr_shape[1:] == (grid_shape.y, grid_shape.x), "Unexpected shape"
+
+ arr = arr.transpose(2, 1, 0)
+
+ if stacked:
+ return arr.reshape(-1, 2)
+ else:
+ return arr
+
+ @property
+ def step_length(self) -> int:
+ """The length of each time step in hours.
+
+ Returns
+ -------
+ int
+ The length of each time step in hours.
+
+ """
+ return self._step_length
+
+ @cached_property
+ def grid_shape_state(self) -> CartesianGridShape:
+ """The shape of the cartesian grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape
+ The shape of the cartesian grid for the state variables.
+
+ """
+ ny, nx = self.config.grid_shape_state
+ return CartesianGridShape(x=nx, y=ny)
+
+ @cached_property
+ def boundary_mask(self) -> xr.DataArray:
+ """The boundary mask for the dataset. This is a binary mask that is 1
+ where the grid cell is on the boundary of the domain, and 0 otherwise.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions `[grid_index]`.
+
+ """
+ xy = self.get_xy(category="state", stacked=False)
+ xs = xy[:, :, 0]
+ ys = xy[:, :, 1]
+ # Check if x-coordinates are constant along columns
+ assert np.allclose(xs, xs[:, [0]]), "x-coordinates are not constant"
+ # Check if y-coordinates are constant along rows
+ assert np.allclose(ys, ys[[0], :]), "y-coordinates are not constant"
+ # Extract unique x and y coordinates
+ x = xs[:, 0] # Unique x-coordinates (changes along the first axis)
+ y = ys[0, :] # Unique y-coordinates (changes along the second axis)
+ values = np.load(self.root_path / "static" / "border_mask.npy")
+ da_mask = xr.DataArray(
+ values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
+ )
+ da_mask_stacked_xy = self.stack_grid_coords(da_mask).astype(int)
+ return da_mask_stacked_xy
+
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
+ """Return the standardization dataarray for the given category. This
+ should contain a `{category}_mean` and `{category}_std` variable for
+ each variable in the category. For `category=="state"`, the dataarray
+ should also contain a `state_diff_mean` and `state_diff_std` variable
+        for the one-step differences of the state variables.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The standardization dataarray for the given category, with
+ variables for the mean and standard deviation of the variables (and
+ differences for state variables).
+
+ """
+
+ def load_pickled_tensor(fn):
+ return torch.load(
+ self.root_path / "static" / fn, weights_only=True
+ ).numpy()
+
+ mean_diff_values = None
+ std_diff_values = None
+ if category == "state":
+ mean_values = load_pickled_tensor("parameter_mean.pt")
+ std_values = load_pickled_tensor("parameter_std.pt")
+ try:
+ mean_diff_values = load_pickled_tensor("diff_mean.pt")
+ std_diff_values = load_pickled_tensor("diff_std.pt")
+ except FileNotFoundError:
+ warnings.warn(f"Could not load diff mean/std for {category}")
+                # XXX: this is a hack; when running
+                # compute_standardization_stats the diff mean/std files do
+                # not exist yet (that script creates them), but it still
+                # requires the std and mean files, so use empty placeholders
+ mean_diff_values = np.empty_like(mean_values)
+ std_diff_values = np.empty_like(std_values)
+
+ elif category == "forcing":
+ flux_stats = load_pickled_tensor("flux_stats.pt") # (2,)
+ flux_mean, flux_std = flux_stats
+ # manually add hour sin/cos and day-of-year sin/cos stats for now
+ # the mean/std for open_water_fraction is hardcoded for now
+ mean_values = np.array([flux_mean, 0.0, 0.0, 0.0, 0.0, 0.0])
+ std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])
+
+ elif category == "static":
+ da_static = self.get_dataarray(category="static", split="train")
+ da_static_mean = da_static.mean(dim=["grid_index"]).compute()
+ da_static_std = da_static.std(dim=["grid_index"]).compute()
+ mean_values = da_static_mean.values
+ std_values = da_static_std.values
+ else:
+ raise NotImplementedError(f"Category {category} not supported")
+
+ feature_dim_name = f"{category}_feature"
+ variables = {
+ f"{category}_mean": (feature_dim_name, mean_values),
+ f"{category}_std": (feature_dim_name, std_values),
+ }
+
+ if mean_diff_values is not None and std_diff_values is not None:
+ variables["state_diff_mean"] = (feature_dim_name, mean_diff_values)
+ variables["state_diff_std"] = (feature_dim_name, std_diff_values)
+
+ ds_norm = xr.Dataset(
+ variables,
+ coords={feature_dim_name: self.get_vars_names(category=category)},
+ )
+
+ return ds_norm
+
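# --- Editor's illustrative sketch (not part of this diff) ---
# Minimal example of the layout returned by `get_standardization_dataarray`
# and how it is typically consumed (cf. plot_example.py further down in this
# diff). The feature names and statistics are made up.
import numpy as np
import xarray as xr

ds_stats = xr.Dataset(
    {
        "state_mean": ("state_feature", np.array([273.0, 1.0e5])),
        "state_std": ("state_feature", np.array([10.0, 1.0e3])),
    },
    coords={"state_feature": ["t2m", "pres0m"]},
)

da_state = xr.DataArray(
    np.array([[280.0, 1.01e5], [265.0, 0.99e5]]),
    dims=("grid_index", "state_feature"),
    coords={"state_feature": ["t2m", "pres0m"]},
)
# standardize: subtract mean and divide by std along `state_feature`
da_standardized = (da_state - ds_stats["state_mean"]) / ds_stats["state_std"]
# --- end of editor's sketch ---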
+ @functools.cached_property
+ def coords_projection(self) -> ccrs.Projection:
+ """The projection of the spatial coordinates.
+
+ Returns
+ -------
+ ccrs.Projection
+ The projection of the spatial coordinates.
+
+ """
+ proj_class_name = self.config.projection.class_name
+ ProjectionClass = getattr(ccrs, proj_class_name)
+ proj_params = self.config.projection.kwargs
+ return ProjectionClass(**proj_params)
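# --- Editor's illustrative sketch (not part of this diff) ---
# The projection is resolved purely by name with `getattr(ccrs, ...)`, so a
# hypothetical config entry such as
#
#   projection:
#     class_name: LambertConformal
#     kwargs:
#       central_longitude: 15.0
#       central_latitude: 63.3
#
# would be equivalent to:
import cartopy.crs as ccrs

proj = getattr(ccrs, "LambertConformal")(
    central_longitude=15.0, central_latitude=63.3
)
# --- end of editor's sketch ---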
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
new file mode 100644
index 00000000..2d477271
--- /dev/null
+++ b/neural_lam/datastore/plot_example.py
@@ -0,0 +1,189 @@
+# Third-party
+import matplotlib.pyplot as plt
+
+# Local
+from . import DATASTORES, init_datastore
+
+
+def plot_example_from_datastore(
+ category,
+ datastore,
+ col_dim,
+ split="train",
+ standardize=True,
+ selection={},
+ index_selection={},
+):
+ """
+ Create a plot of the data from the datastore.
+
+ Parameters
+ ----------
+ category : str
+ Category of data to plot, one of "state", "forcing", or "static".
+ datastore : Datastore
+ Datastore to retrieve data from.
+ col_dim : str
+ Dimension to use for plot facetting into columns. This can be a
+ template string that can be formatted with the category name.
+ split : str, optional
+ Split of data to plot, by default "train".
+ standardize : bool, optional
+ Whether to standardize the data before plotting, by default True.
+ selection : dict, optional
+ Selections to apply to the dataarray, for example
+        `time="1990-09-03T0:00"` would select this single timestep, by
+        default {}.
+ index_selection: dict, optional
+ Index-based selection to apply to the dataarray, for example
+ `time=0` would select the first item along the `time` dimension, by
+ default {}.
+
+ Returns
+ -------
+ Figure
+ Matplotlib figure object.
+ """
+ da = datastore.get_dataarray(category=category, split=split)
+ if standardize:
+ da_stats = datastore.get_standardization_dataarray(category=category)
+ da = (da - da_stats[f"{category}_mean"]) / da_stats[f"{category}_std"]
+ da = datastore.unstack_grid_coords(da)
+
+ if len(selection) > 0:
+ da = da.sel(**selection)
+ if len(index_selection) > 0:
+ da = da.isel(**index_selection)
+
+ col = col_dim.format(category=category)
+
+ # check that the column dimension exists and that the resulting shape is 2D
+ if col not in da.dims:
+ raise ValueError(f"Column dimension {col} not found in dataarray.")
+ da_col_item = da.isel({col: 0}).squeeze()
+ if not len(da_col_item.shape) == 2:
+ raise ValueError(
+            f"Column dimension {col} and selection {selection} do not "
+            "result in a 2D dataarray. Please adjust the column dimension "
+            "and/or selection. The resulting dataarray is instead:\n"
+ f"{da_col_item}"
+ )
+
+ crs = datastore.coords_projection
+ col_wrap = min(4, int(da[col].count()))
+ g = da.plot(
+ x="x",
+ y="y",
+ col=col,
+ col_wrap=col_wrap,
+ subplot_kws={"projection": crs},
+ transform=crs,
+ size=4,
+ )
+ for ax in g.axes.flat:
+ ax.coastlines()
+ ax.gridlines(draw_labels=["left", "bottom"])
+ ax.set_extent(datastore.get_xy_extent(category=category), crs=crs)
+
+ return g.fig
+
+
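# --- Editor's illustrative sketch (not part of this diff) ---
# Programmatic use of `plot_example_from_datastore` from outside this module,
# assuming the mdp example datastore config used elsewhere in this diff; it
# mirrors what the CLI entry point below does.
import matplotlib.pyplot as plt

from neural_lam.datastore import init_datastore
from neural_lam.datastore.plot_example import plot_example_from_datastore

datastore = init_datastore(
    datastore_kind="mdp",
    config_path="tests/datastore_examples/mdp/config.yaml",
)
fig = plot_example_from_datastore(
    category="state",
    datastore=datastore,
    col_dim="{category}_feature",
    split="train",
    index_selection=dict(time=0),
)
plt.show()
# --- end of editor's sketch ---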
+if __name__ == "__main__":
+ # Standard library
+ import argparse
+
+ def _parse_dict(arg_str):
+ key, value = arg_str.split("=")
+ for op in [int, float]:
+ try:
+ value = op(value)
+ break
+ except ValueError:
+ pass
+ return key, value
+
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--datastore_kind",
+ type=str,
+ choices=DATASTORES.keys(),
+ default="mdp",
+ help="Kind of datastore to use",
+ )
+ parser.add_argument(
+ "--datastore_config_path",
+ type=str,
+ default=None,
+ help="Path for the datastore config",
+ )
+ parser.add_argument(
+ "--category",
+ default="state",
+ help="Category of data to plot",
+ choices=["state", "forcing", "static"],
+ )
+ parser.add_argument(
+ "--split", default="train", help="Split of data to plot"
+ )
+ parser.add_argument(
+ "--col-dim",
+ default="{category}_feature",
+ help="Dimension to use for plot facetting into columns",
+ )
+ parser.add_argument(
+ "--disable-standardize",
+ dest="standardize",
+ action="store_false",
+ help="Disable standardization of data",
+ )
+ # add the ability to create dictionary of kwargs
+ parser.add_argument(
+ "--selection",
+ nargs="+",
+ default=[],
+ type=_parse_dict,
+ help="Selections to apply to the dataarray, for example "
+        "`time='1990-09-03T0:00'` would select this single timestep",
+ )
+ parser.add_argument(
+ "--index-selection",
+ nargs="+",
+ default=[],
+ type=_parse_dict,
+ help="Index-based selection to apply to the dataarray, for example "
+ "`time=0` would select the first item along the `time` dimension",
+ )
+ args = parser.parse_args()
+
+ assert (
+ args.datastore_config_path is not None
+ ), "Specify your datastore config with --datastore_config_path"
+
+ selection = dict(args.selection)
+ index_selection = dict(args.index_selection)
+
+ # check that column dimension is not in the selection
+ if args.col_dim.format(category=args.category) in selection:
+ raise ValueError(
+ f"Column dimension {args.col_dim.format(category=args.category)} "
+ f"cannot be in the selection ({selection}). Please adjust the "
+ "column dimension and/or selection."
+ )
+
+ datastore = init_datastore(
+ datastore_kind=args.datastore_kind,
+ config_path=args.datastore_config_path,
+ )
+
+ plot_example_from_datastore(
+ args.category,
+ datastore,
+ split=args.split,
+ col_dim=args.col_dim,
+ standardize=args.standardize,
+ selection=selection,
+ index_selection=index_selection,
+ )
+ plt.show()
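# --- Editor's illustrative usage note (not part of this diff) ---
# A hypothetical invocation of this script (the config path is an example):
#
#   python -m neural_lam.datastore.plot_example \
#       --datastore_kind mdp \
#       --datastore_config_path tests/datastore_examples/mdp/config.yaml \
#       --category forcing \
#       --index-selection time=0
#
# Each `key=value` token passed to --selection / --index-selection is parsed
# by `_parse_dict` above: int() and then float() are attempted on the value,
# so "time=0" becomes ("time", 0) while "time=1990-09-03T0:00" stays a string.
# --- end of editor's note ---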
diff --git a/neural_lam/loss_weighting.py b/neural_lam/loss_weighting.py
new file mode 100644
index 00000000..c842b202
--- /dev/null
+++ b/neural_lam/loss_weighting.py
@@ -0,0 +1,106 @@
+# Local
+from .config import (
+ ManualStateFeatureWeighting,
+ NeuralLAMConfig,
+ UniformFeatureWeighting,
+)
+from .datastore.base import BaseDatastore
+
+
+def get_manual_state_feature_weights(
+ weighting_config: ManualStateFeatureWeighting, datastore: BaseDatastore
+) -> list[float]:
+ """
+ Return the state feature weights as a list of floats in the order of the
+ state features in the datastore.
+
+ Parameters
+ ----------
+ weighting_config : ManualStateFeatureWeighting
+ Configuration object containing the manual state feature weights.
+ datastore : BaseDatastore
+ Datastore object containing the state features.
+
+ Returns
+ -------
+ list[float]
+ List of floats containing the state feature weights.
+ """
+ state_feature_names = datastore.get_vars_names(category="state")
+ feature_weight_names = weighting_config.weights.keys()
+
+ # Check that the state_feature_weights dictionary has a weight for each
+ # state feature in the datastore.
+ if set(feature_weight_names) != set(state_feature_names):
+ additional_features = set(feature_weight_names) - set(
+ state_feature_names
+ )
+ missing_features = set(state_feature_names) - set(feature_weight_names)
+ raise ValueError(
+            "State feature weights must be provided for each state feature "
+            f"in the datastore ({state_feature_names}). {missing_features} "
+            "are missing, and weights are defined for the features "
+            f"{additional_features} which are not in the datastore."
+ )
+
+ state_feature_weights = [
+ weighting_config.weights[feature] for feature in state_feature_names
+ ]
+ return state_feature_weights
+
+
+def get_uniform_state_feature_weights(datastore: BaseDatastore) -> list[float]:
+ """
+ Return the state feature weights as a list of floats in the order of the
+ state features in the datastore.
+
+ The weights are uniform, i.e. 1.0/n_features for each feature.
+
+ Parameters
+ ----------
+ datastore : BaseDatastore
+ Datastore object containing the state features.
+
+ Returns
+ -------
+ list[float]
+ List of floats containing the state feature weights.
+ """
+ state_feature_names = datastore.get_vars_names(category="state")
+ n_features = len(state_feature_names)
+ return [1.0 / n_features] * n_features
+
+
+def get_state_feature_weighting(
+ config: NeuralLAMConfig, datastore: BaseDatastore
+) -> list[float]:
+ """
+ Return the state feature weights as a list of floats in the order of the
+ state features in the datastore. The weights are determined based on the
+ configuration in the NeuralLAMConfig object.
+
+ Parameters
+ ----------
+ config : NeuralLAMConfig
+ Configuration object for neural-lam.
+ datastore : BaseDatastore
+ Datastore object containing the state features.
+
+ Returns
+ -------
+ list[float]
+ List of floats containing the state feature weights.
+ """
+ weighting_config = config.training.state_feature_weighting
+
+ if isinstance(weighting_config, ManualStateFeatureWeighting):
+ weights = get_manual_state_feature_weights(weighting_config, datastore)
+ elif isinstance(weighting_config, UniformFeatureWeighting):
+ weights = get_uniform_state_feature_weights(datastore)
+ else:
+ raise NotImplementedError(
+ "Unsupported state feature weighting configuration: "
+ f"{weighting_config}"
+ )
+
+ return weights
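# --- Editor's illustrative sketch (not part of this diff) ---
# What the two weighting modes produce for a hypothetical datastore whose
# state features are ["t2m", "u10m", "v10m"] (names made up):
#
#   manual weights {"u10m": 0.5, "t2m": 1.0, "v10m": 0.5}
#       -> [1.0, 0.5, 0.5]   (reordered to the datastore's feature order)
#   uniform weighting with n_features = 3
#       -> [1/3, 1/3, 1/3]
#
# In both cases the returned list follows
# datastore.get_vars_names(category="state"), not the order of the config.
# --- end of editor's sketch ---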
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index e94de8c6..bc4c6719 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -9,7 +9,10 @@
import wandb
# Local
-from .. import config, metrics, utils, vis
+from .. import metrics, vis
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
+from ..loss_weighting import get_state_feature_weighting
class ARModel(pl.LightningModule):
@@ -21,35 +24,78 @@ class ARModel(pl.LightningModule):
# pylint: disable=arguments-differ
# Disable to override args/kwargs from superclass
- def __init__(self, args):
+ def __init__(
+ self,
+ args,
+ config: NeuralLAMConfig,
+ datastore: BaseDatastore,
+ ):
super().__init__()
- self.save_hyperparameters()
+ self.save_hyperparameters(ignore=["datastore"])
self.args = args
- self.config_loader = config.Config.from_file(args.data_config)
+ self._datastore = datastore
+ num_state_vars = datastore.get_num_data_vars(category="state")
+ num_forcing_vars = datastore.get_num_data_vars(category="forcing")
+ da_static_features = datastore.get_dataarray(
+ category="static", split=None
+ )
+ da_state_stats = datastore.get_standardization_dataarray(
+ category="state"
+ )
+ da_boundary_mask = datastore.boundary_mask
+ num_past_forcing_steps = args.num_past_forcing_steps
+ num_future_forcing_steps = args.num_future_forcing_steps
+
+ # Load static features for grid/data, NB: self.predict_step assumes
+ # dimension order to be (grid_index, static_feature)
+ arr_static = da_static_features.transpose(
+ "grid_index", "static_feature"
+ ).values
+ self.register_buffer(
+ "grid_static_features",
+ torch.tensor(arr_static, dtype=torch.float32),
+ persistent=False,
+ )
+
+ state_stats = {
+ "state_mean": torch.tensor(
+ da_state_stats.state_mean.values, dtype=torch.float32
+ ),
+ "state_std": torch.tensor(
+ da_state_stats.state_std.values, dtype=torch.float32
+ ),
+ "diff_mean": torch.tensor(
+ da_state_stats.state_diff_mean.values, dtype=torch.float32
+ ),
+ "diff_std": torch.tensor(
+ da_state_stats.state_diff_std.values, dtype=torch.float32
+ ),
+ }
+
+ for key, val in state_stats.items():
+ self.register_buffer(key, val, persistent=False)
- # Load static features for grid/data
- static_data_dict = utils.load_static_data(
- self.config_loader.dataset.name
+ state_feature_weights = get_state_feature_weighting(
+ config=config, datastore=datastore
+ )
+ self.feature_weights = torch.tensor(
+ state_feature_weights, dtype=torch.float32
)
- for static_data_name, static_data_tensor in static_data_dict.items():
- self.register_buffer(
- static_data_name, static_data_tensor, persistent=False
- )
# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
if self.output_std:
# Pred. dim. in grid cell
- self.grid_output_dim = 2 * self.config_loader.num_data_vars()
+ self.grid_output_dim = 2 * num_state_vars
else:
# Pred. dim. in grid cell
- self.grid_output_dim = self.config_loader.num_data_vars()
+ self.grid_output_dim = num_state_vars
# Store constant per-variable std.-dev. weighting
- # Note that this is the inverse of the multiplicative weighting
+ # NOTE that this is the inverse of the multiplicative weighting
# in wMSE/wMAE
self.register_buffer(
"per_var_std",
- self.step_diff_std / torch.sqrt(self.param_weights),
+ self.diff_std / torch.sqrt(self.feature_weights),
persistent=False,
)
@@ -58,21 +104,29 @@ def __init__(self, args):
self.num_grid_nodes,
grid_static_dim,
) = self.grid_static_features.shape
+
self.grid_dim = (
- 2 * self.config_loader.num_data_vars()
+ 2 * self.grid_output_dim
+ grid_static_dim
- + self.config_loader.dataset.num_forcing_features
+ + num_forcing_vars
+ * (num_past_forcing_steps + num_future_forcing_steps + 1)
)
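        # --- Editor's illustrative arithmetic (not part of this diff) ---
        # With hypothetical sizes num_state_vars=17 (so grid_output_dim=17
        # when output_std is False), grid_static_dim=4, num_forcing_vars=5
        # and num_past_forcing_steps = num_future_forcing_steps = 1:
        #   grid_dim = 2*17 + 4 + 5*(1+1+1) = 34 + 4 + 15 = 53
        # i.e. two previous states, the static features, and a window of
        # three forcing steps per forcing variable.
        # --- end of editor's note ---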
# Instantiate loss function
self.loss = metrics.get_metric(args.loss)
+ boundary_mask = torch.tensor(
+ da_boundary_mask.values, dtype=torch.float32
+ ).unsqueeze(
+ 1
+ ) # add feature dim
+
+ self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
- "interior_mask", 1.0 - self.border_mask, persistent=False
+ "interior_mask", 1.0 - self.boundary_mask, persistent=False
) # (num_grid_nodes, 1), 1 for non-border
- self.step_length = args.step_length # Number of hours per pred. step
self.val_metrics = {
"mse": [],
}
@@ -116,18 +170,18 @@ def expand_to_batch(x, batch_size):
def predict_step(self, prev_state, prev_prev_state, forcing):
"""
Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1
- prev_state: (B, num_grid_nodes, feature_dim), X_t
- prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1}
- forcing: (B, num_grid_nodes, forcing_dim)
+ prev_state: (B, num_grid_nodes, feature_dim), X_t prev_prev_state: (B,
+ num_grid_nodes, feature_dim), X_{t-1} forcing: (B, num_grid_nodes,
+ forcing_dim)
"""
raise NotImplementedError("No prediction step implemented")
def unroll_prediction(self, init_states, forcing_features, true_states):
"""
Roll out prediction taking multiple autoregressive steps with model
- init_states: (B, 2, num_grid_nodes, d_f)
- forcing_features: (B, pred_steps, num_grid_nodes, d_static_f)
- true_states: (B, pred_steps, num_grid_nodes, d_f)
+ init_states: (B, 2, num_grid_nodes, d_f) forcing_features: (B,
+ pred_steps, num_grid_nodes, d_static_f) true_states: (B, pred_steps,
+ num_grid_nodes, d_f)
"""
prev_prev_state = init_states[:, 0]
prev_state = init_states[:, 1]
@@ -142,12 +196,12 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
pred_state, pred_std = self.predict_step(
prev_state, prev_prev_state, forcing
)
- # state: (B, num_grid_nodes, d_f)
- # pred_std: (B, num_grid_nodes, d_f) or None
+ # state: (B, num_grid_nodes, d_f) pred_std: (B, num_grid_nodes,
+ # d_f) or None
# Overwrite border with true state
new_state = (
- self.border_mask * border_state
+ self.boundary_mask * border_state
+ self.interior_mask * pred_state
)
@@ -173,32 +227,27 @@ def unroll_prediction(self, init_states, forcing_features, true_states):
def common_step(self, batch):
"""
- Predict on single batch
- batch consists of:
- init_states: (B, 2, num_grid_nodes, d_features)
- target_states: (B, pred_steps, num_grid_nodes, d_features)
- forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
+        Predict on a single batch. The batch consists of:
+        init_states: (B, 2, num_grid_nodes, d_features)
+        target_states: (B, pred_steps, num_grid_nodes, d_features)
+        forcing_features: (B, pred_steps, num_grid_nodes, d_forcing),
where index 0 corresponds to index 1 of init_states
"""
- (
- init_states,
- target_states,
- forcing_features,
- ) = batch
+ (init_states, target_states, forcing_features, batch_times) = batch
prediction, pred_std = self.unroll_prediction(
init_states, forcing_features, target_states
) # (B, pred_steps, num_grid_nodes, d_f)
- # prediction: (B, pred_steps, num_grid_nodes, d_f)
- # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
+ # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
+ # pred_steps, num_grid_nodes, d_f) or (d_f,)
- return prediction, target_states, pred_std
+ return prediction, target_states, pred_std, batch_times
def training_step(self, batch):
"""
Train on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
+ prediction, target, pred_std, _ = self.common_step(batch)
# Compute loss
batch_loss = torch.mean(
@@ -209,14 +258,19 @@ def training_step(self, batch):
log_dict = {"train_loss": batch_loss}
self.log_dict(
- log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
+ log_dict,
+ prog_bar=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=True,
+ batch_size=batch[0].shape[0],
)
return batch_loss
def all_gather_cat(self, tensor_to_gather):
"""
- Gather tensors across all ranks, and concatenate across dim. 0
- (instead of stacking in new dim. 0)
+ Gather tensors across all ranks, and concatenate across dim. 0 (instead
+ of stacking in new dim. 0)
tensor_to_gather: (d1, d2, ...), distributed over K ranks
@@ -230,7 +284,7 @@ def validation_step(self, batch, batch_idx):
"""
Run validation on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
+ prediction, target, pred_std, _ = self.common_step(batch)
time_step_loss = torch.mean(
self.loss(
@@ -244,10 +298,15 @@ def validation_step(self, batch, batch_idx):
val_log_dict = {
f"val_loss_unroll{step}": time_step_loss[step - 1]
for step in self.args.val_steps_to_log
+ if step <= len(time_step_loss)
}
val_log_dict["val_mean_loss"] = mean_loss
self.log_dict(
- val_log_dict, on_step=False, on_epoch=True, sync_dist=True
+ val_log_dict,
+ on_step=False,
+ on_epoch=True,
+ sync_dist=True,
+ batch_size=batch[0].shape[0],
)
# Store MSEs
@@ -276,9 +335,10 @@ def test_step(self, batch, batch_idx):
"""
Run test on single batch
"""
- prediction, target, pred_std = self.common_step(batch)
- # prediction: (B, pred_steps, num_grid_nodes, d_f)
- # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,)
+ # TODO Here batch_times can be used for plotting routines
+ prediction, target, pred_std, batch_times = self.common_step(batch)
+ # prediction: (B, pred_steps, num_grid_nodes, d_f) pred_std: (B,
+ # pred_steps, num_grid_nodes, d_f) or (d_f,)
time_step_loss = torch.mean(
self.loss(
@@ -296,13 +356,16 @@ def test_step(self, batch, batch_idx):
test_log_dict["test_mean_loss"] = mean_loss
self.log_dict(
- test_log_dict, on_step=False, on_epoch=True, sync_dist=True
+ test_log_dict,
+ on_step=False,
+ on_epoch=True,
+ sync_dist=True,
+ batch_size=batch[0].shape[0],
)
- # Compute all evaluation metrics for error maps
- # Note: explicitly list metrics here, as test_metrics can contain
- # additional ones, computed differently, but that should be aggregated
- # on_test_epoch_end
+        # Compute all evaluation metrics for error maps.
+        # NOTE: explicitly list metrics here, as test_metrics can contain
+        # additional ones, computed differently, but that should be
+        # aggregated in on_test_epoch_end
for metric_name in ("mse", "mae"):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
@@ -338,7 +401,8 @@ def test_step(self, batch, batch_idx):
):
# Need to plot more example predictions
n_additional_examples = min(
- prediction.shape[0], self.n_example_pred - self.plotted_examples
+ prediction.shape[0],
+ self.n_example_pred - self.plotted_examples,
)
self.plot_examples(
@@ -349,19 +413,19 @@ def plot_examples(self, batch, n_examples, prediction=None):
"""
Plot the first n_examples forecasts from batch
- batch: batch with data to plot corresponding forecasts for
- n_examples: number of forecasts to plot
- prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction.
+ batch: batch with data to plot corresponding forecasts for n_examples:
+ number of forecasts to plot prediction: (B, pred_steps, num_grid_nodes,
+ d_f), existing prediction.
Generate if None.
"""
if prediction is None:
- prediction, target = self.common_step(batch)
+ prediction, target, _, _ = self.common_step(batch)
target = batch[1]
# Rescale to original data scale
- prediction_rescaled = prediction * self.data_std + self.data_mean
- target_rescaled = target * self.data_std + self.data_mean
+ prediction_rescaled = prediction * self.state_std + self.state_mean
+ target_rescaled = target * self.state_std + self.state_mean
# Iterate over the examples
for pred_slice, target_slice in zip(
@@ -395,18 +459,17 @@ def plot_examples(self, batch, n_examples, prediction=None):
# Create one figure per variable at this time step
var_figs = [
vis.plot_prediction(
- pred_t[:, var_i],
- target_t[:, var_i],
- self.interior_mask[:, 0],
- self.config_loader,
+ pred=pred_t[:, var_i],
+ target=target_t[:, var_i],
+ datastore=self._datastore,
title=f"{var_name} ({var_unit}), "
- f"t={t_i} ({self.step_length * t_i} h)",
+ f"t={t_i} ({self._datastore.step_length * t_i} h)",
vrange=var_vrange,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
- self.config_loader.dataset.var_names,
- self.config_loader.dataset.var_units,
+ self._datastore.get_vars_names("state"),
+ self._datastore.get_vars_units("state"),
var_vranges,
)
)
@@ -417,7 +480,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
{
f"{var_name}_example_{example_i}": wandb.Image(fig)
for var_name, fig in zip(
- self.config_loader.dataset.var_names, var_figs
+ self._datastore.get_vars_names("state"), var_figs
)
}
)
@@ -441,19 +504,19 @@ def plot_examples(self, batch, n_examples, prediction=None):
def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
- Put together a dict with everything to log for one metric.
- Also saves plots as pdf and csv if using test prefix.
+ Put together a dict with everything to log for one metric. Also saves
+ plots as pdf and csv if using test prefix.
metric_tensor: (pred_steps, d_f), metric values per time and variable
- prefix: string, prefix to use for logging
- metric_name: string, name of the metric
+ prefix: string, prefix to use for logging metric_name: string, name of
+ the metric
- Return:
- log_dict: dict with everything to log for given metric
+ Return: log_dict: dict with everything to log for given metric
"""
log_dict = {}
metric_fig = vis.plot_error_map(
- metric_tensor, self.config_loader, step_length=self.step_length
+ errors=metric_tensor,
+ datastore=self._datastore,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
@@ -471,17 +534,13 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
)
# Check if metrics are watched, log exact values for specific vars
+ var_names = self._datastore.get_vars_names(category="state")
if full_log_name in self.args.metrics_watch:
for var_i, timesteps in self.args.var_leads_metrics_watch.items():
- var = self.config_loader.dataset.var_names[var_i]
- log_dict.update(
- {
- f"{full_log_name}_{var}_step_{step}": metric_tensor[
- step - 1, var_i
- ] # 1-indexed in data_config
- for step in timesteps
- }
- )
+ var_name = var_names[var_i]
+ for step in timesteps:
+ key = f"{full_log_name}_{var_name}_step_{step}"
+ log_dict[key] = metric_tensor[step - 1, var_i]
return log_dict
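        # --- Editor's illustrative note (not part of this diff) ---
        # Example of the watched-metric keys produced above, assuming a
        # hypothetical full_log_name "val_mse", var_leads_metrics_watch
        # '{"1": [1, 2]}' and var_names[1] == "u10m":
        #   "val_mse_u10m_step_1", "val_mse_u10m_step_2"
        # --- end of editor's note ---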
@@ -508,8 +567,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
metric_tensor_averaged = torch.sqrt(metric_tensor_averaged)
metric_name = metric_name.replace("mse", "rmse")
- # Note: we here assume rescaling for all metrics is linear
- metric_rescaled = metric_tensor_averaged * self.data_std
+ # NOTE: we here assume rescaling for all metrics is linear
+ metric_rescaled = metric_tensor_averaged * self.state_std
# (pred_steps, d_f)
log_dict.update(
self.create_metric_log_dict(
@@ -523,8 +582,8 @@ def aggregate_and_plot_metrics(self, metrics_dict, prefix):
def on_test_epoch_end(self):
"""
- Compute test metrics and make plots at the end of test epoch.
- Will gather stored tensors and perform plotting and logging on rank 0.
+ Compute test metrics and make plots at the end of test epoch. Will
+ gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")
@@ -540,10 +599,10 @@ def on_test_epoch_end(self):
loss_map_figs = [
vis.plot_spatial_error(
- loss_map,
- self.interior_mask[:, 0],
- self.config_loader,
- title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",
+ error=loss_map,
+ datastore=self._datastore,
+ title=f"Test loss, t={t_i} "
+ f"({self._datastore.step_length * t_i} h)",
)
for t_i, loss_map in zip(
self.args.val_steps_to_log, mean_spatial_loss
@@ -557,7 +616,7 @@ def on_test_epoch_end(self):
# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(
- loss_map, self.interior_mask[:, 0], self.config_loader
+ error=loss_map, datastore=self._datastore
)
for loss_map in mean_spatial_loss
]
diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py
index 99629073..6233b4d1 100644
--- a/neural_lam/models/base_graph_model.py
+++ b/neural_lam/models/base_graph_model.py
@@ -3,6 +3,8 @@
# Local
from .. import utils
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
from ..interaction_net import InteractionNet
from .ar_model import ARModel
@@ -13,13 +15,16 @@ class BaseGraphModel(ARModel):
the encode-process-decode idea.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
+ super().__init__(args, config=config, datastore=datastore)
# Load graph with static features
# NOTE: (IMPORTANT!) mesh nodes MUST have the first
# num_mesh_nodes indices,
- self.hierarchical, graph_ldict = utils.load_graph(args.graph)
+ graph_dir_path = datastore.root_path / "graph" / args.graph
+ self.hierarchical, graph_ldict = utils.load_graph(
+ graph_dir_path=graph_dir_path
+ )
for name, attr_value in graph_ldict.items():
# Make BufferLists module members and register tensors as buffers
if isinstance(attr_value, torch.Tensor):
@@ -157,7 +162,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
pred_delta_mean, pred_std_raw = net_output.chunk(
2, dim=-1
) # both (B, num_grid_nodes, d_f)
- # Note: The predicted std. is not scaled in any way here
+ # NOTE: The predicted std. is not scaled in any way here
# linter for some reason does not think softplus is callable
# pylint: disable-next=not-callable
pred_std = torch.nn.functional.softplus(pred_std_raw)
@@ -166,9 +171,7 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
pred_std = None
# Rescale with one-step difference statistics
- rescaled_delta_mean = (
- pred_delta_mean * self.step_diff_std + self.step_diff_mean
- )
+ rescaled_delta_mean = pred_delta_mean * self.diff_std + self.diff_mean
# Residual connection for full state
return prev_state + rescaled_delta_mean, pred_std
diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py
index a2ebcc1b..8ec46b4f 100644
--- a/neural_lam/models/base_hi_graph_model.py
+++ b/neural_lam/models/base_hi_graph_model.py
@@ -3,6 +3,8 @@
# Local
from .. import utils
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
from ..interaction_net import InteractionNet
from .base_graph_model import BaseGraphModel
@@ -12,8 +14,8 @@ class BaseHiGraphModel(BaseGraphModel):
Base class for hierarchical graph models.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
+ super().__init__(args, config=config, datastore=datastore)
# Track number of nodes, edges on each level
# Flatten lists for efficient embedding
diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py
index d73f7ad8..68b7d01e 100644
--- a/neural_lam/models/graph_lam.py
+++ b/neural_lam/models/graph_lam.py
@@ -3,6 +3,8 @@
# Local
from .. import utils
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
from ..interaction_net import InteractionNet
from .base_graph_model import BaseGraphModel
@@ -15,8 +17,8 @@ class GraphLAM(BaseGraphModel):
Oskarsson et al. (2023).
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
+ super().__init__(args, config=config, datastore=datastore)
assert (
not self.hierarchical
diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py
index 4f3aec05..c340c95d 100644
--- a/neural_lam/models/hi_lam.py
+++ b/neural_lam/models/hi_lam.py
@@ -2,6 +2,8 @@
from torch import nn
# Local
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
from ..interaction_net import InteractionNet
from .base_hi_graph_model import BaseHiGraphModel
@@ -13,8 +15,8 @@ class HiLAM(BaseHiGraphModel):
The Hi-LAM model from Oskarsson et al. (2023)
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
+ super().__init__(args, config=config, datastore=datastore)
# Make down GNNs, both for down edges and same level
self.mesh_down_gnns = nn.ModuleList(
@@ -200,5 +202,6 @@ def hi_processor_step(
up_same_gnns,
)
- # Note: We return all, even though only down edges really are used later
+ # NOTE: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py
index b40a9424..a0a84d29 100644
--- a/neural_lam/models/hi_lam_parallel.py
+++ b/neural_lam/models/hi_lam_parallel.py
@@ -3,6 +3,8 @@
import torch_geometric as pyg
# Local
+from ..config import NeuralLAMConfig
+from ..datastore import BaseDatastore
from ..interaction_net import InteractionNet
from .base_hi_graph_model import BaseHiGraphModel
@@ -16,8 +18,8 @@ class HiLAMParallel(BaseHiGraphModel):
of Hi-LAM.
"""
- def __init__(self, args):
- super().__init__(args)
+ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
+ super().__init__(args, config=config, datastore=datastore)
# Processor GNNs
# Create the complete edge_index combining all edges for processing
@@ -92,5 +94,6 @@ def hi_processor_step(
self.num_levels + (self.num_levels - 1) :
] # Last are down edges
- # Note: We return all, even though only down edges really are used later
+    # NOTE: We return all, even though only down edges really are used
+ # later
return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
diff --git a/plot_graph.py b/neural_lam/plot_graph.py
similarity index 88%
rename from plot_graph.py
rename to neural_lam/plot_graph.py
index e47e62c0..999c8e53 100644
--- a/plot_graph.py
+++ b/neural_lam/plot_graph.py
@@ -1,4 +1,5 @@
# Standard library
+import os
from argparse import ArgumentParser
# Third-party
@@ -6,8 +7,9 @@
import plotly.graph_objects as go
import torch_geometric as pyg
-# First-party
-from neural_lam import config, utils
+# Local
+from . import utils
+from .config import load_config_and_datastore
MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.2
@@ -15,15 +17,13 @@
def main():
- """
- Plot graph structure in 3D using plotly
- """
+ """Plot graph structure in 3D using plotly."""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
- "--data_config",
+ "--datastore_config_path",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ default="tests/datastore_examples/mdp/config.yaml",
+ help="Path for the datastore config",
)
parser.add_argument(
"--graph",
@@ -43,10 +43,17 @@ def main():
)
args = parser.parse_args()
- config_loader = config.Config.from_file(args.data_config)
+ _, datastore = load_config_and_datastore(
+ config_path=args.datastore_config_path
+ )
+
+ xy = datastore.get_xy("state", stacked=True) # (N_grid, 2)
+ pos_max = np.max(np.abs(xy))
+ grid_pos = xy / pos_max # Divide by maximum coordinate
# Load graph data
- hierarchical, graph_ldict = utils.load_graph(args.graph)
+ graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph)
+ hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
@@ -58,12 +65,6 @@ def main():
)
mesh_static_features = graph_ldict["mesh_static_features"]
- grid_static_features = utils.load_static_data(config_loader.dataset.name)[
- "grid_static_features"
- ]
-
- # Extract values needed, turn to numpy
- grid_pos = grid_static_features[:, :2].numpy()
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_pos = np.concatenate(
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index c1a6cb89..74146c89 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -8,10 +8,13 @@
import pytorch_lightning as pl
import torch
from lightning_fabric.utilities import seed
+from loguru import logger
# Local
-from . import WeatherDataset, config, utils
+from . import utils
+from .config import load_config_and_datastore
from .models import GraphLAM, HiLAM, HiLAMParallel
+from .weather_dataset import WeatherDataModule
MODELS = {
"graph_lam": GraphLAM,
@@ -20,18 +23,16 @@
}
+@logger.catch
def main(input_args=None):
- """
- Main function for training and evaluating models
- """
+ """Main function for training and evaluating models."""
parser = ArgumentParser(
description="Train or evaluate NeurWP models for LAM"
)
parser.add_argument(
- "--data_config",
+ "--config_path",
type=str,
- default="neural_lam/data_config.yaml",
- help="Path to data config file (default: neural_lam/data_config.yaml)",
+ help="Path to the configuration for neural-lam",
)
parser.add_argument(
"--model",
@@ -39,17 +40,11 @@ def main(input_args=None):
default="graph_lam",
help="Model architecture to train/evaluate (default: graph_lam)",
)
- parser.add_argument(
- "--subset_ds",
- action="store_true",
- help="Use only a small subset of the dataset, for debugging"
- "(default: false)",
- )
parser.add_argument(
"--seed", type=int, default=42, help="random seed (default: 42)"
)
parser.add_argument(
- "--n_workers",
+ "--num_workers",
type=int,
default=4,
help="Number of workers in data loader (default: 4)",
@@ -124,31 +119,18 @@ def main(input_args=None):
# Training options
parser.add_argument(
- "--ar_steps",
+ "--ar_steps_train",
type=int,
default=1,
- help="Number of steps to unroll prediction for in loss (1-19) "
+ help="Number of steps to unroll prediction for in loss function "
"(default: 1)",
)
- parser.add_argument(
- "--control_only",
- action="store_true",
- help="Train only on control member of ensemble data "
- "(default: False)",
- )
parser.add_argument(
"--loss",
type=str,
default="wmse",
help="Loss function to use, see metric.py (default: wmse)",
)
- parser.add_argument(
- "--step_length",
- type=int,
- default=3,
- help="Step length in hours to consider single time step 1-3 "
- "(default: 3)",
- )
parser.add_argument(
"--lr", type=float, default=1e-3, help="learning rate (default: 0.001)"
)
@@ -167,6 +149,13 @@ def main(input_args=None):
help="Eval model on given data split (val/test) "
"(default: None (train model))",
)
+ parser.add_argument(
+ "--ar_steps_eval",
+ type=int,
+ default=10,
+ help="Number of steps to unroll prediction for during evaluation "
+ "(default: 10)",
+ )
parser.add_argument(
"--n_example_pred",
type=int,
@@ -184,9 +173,10 @@ def main(input_args=None):
)
parser.add_argument(
"--val_steps_to_log",
- type=list,
+ nargs="+",
+ type=int,
default=[1, 2, 3, 5, 10, 15, 19],
- help="Steps to log val loss for (default: [1, 2, 3, 5, 10, 15, 19])",
+ help="Steps to log val loss for (default: 1 2 3 5 10 15 19)",
)
parser.add_argument(
"--metrics_watch",
@@ -201,15 +191,28 @@ def main(input_args=None):
help="""JSON string with variable-IDs and lead times to log watched
metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""",
)
+ parser.add_argument(
+ "--num_past_forcing_steps",
+ type=int,
+ default=1,
+ help="Number of past time steps to use as input for forcing data",
+ )
+ parser.add_argument(
+ "--num_future_forcing_steps",
+ type=int,
+ default=1,
+ help="Number of future time steps to use as input for forcing data",
+ )
args = parser.parse_args(input_args)
args.var_leads_metrics_watch = {
int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items()
}
- config_loader = config.Config.from_file(args.data_config)
# Asserts for arguments
+ assert (
+ args.config_path is not None
+ ), "Specify your config with --config_path"
assert args.model in MODELS, f"Unknown model: {args.model}"
- assert args.step_length <= 3, "Too high step length"
assert args.eval in (
None,
"val",
@@ -222,33 +225,19 @@ def main(input_args=None):
# Set seed
seed.seed_everything(args.seed)
- # Load data
- train_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- config_loader.dataset.name,
- pred_length=args.ar_steps,
- split="train",
- subsample_step=args.step_length,
- subset=args.subset_ds,
- control_only=args.control_only,
- ),
- args.batch_size,
- shuffle=True,
- num_workers=args.n_workers,
- )
- max_pred_length = (65 // args.step_length) - 2 # 19
- val_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- config_loader.dataset.name,
- pred_length=max_pred_length,
- split="val",
- subsample_step=args.step_length,
- subset=args.subset_ds,
- control_only=args.control_only,
- ),
- args.batch_size,
- shuffle=False,
- num_workers=args.n_workers,
+ # Load neural-lam configuration and datastore to use
+ config, datastore = load_config_and_datastore(config_path=args.config_path)
+
+ # Create datamodule
+ data_module = WeatherDataModule(
+ datastore=datastore,
+ ar_steps_train=args.ar_steps_train,
+ ar_steps_eval=args.ar_steps_eval,
+ standardize=True,
+ num_past_forcing_steps=args.num_past_forcing_steps,
+ num_future_forcing_steps=args.num_future_forcing_steps,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
)
# Instantiate model + trainer
@@ -261,12 +250,13 @@ def main(input_args=None):
device_name = "cpu"
# Load model parameters Use new args for model
- model_class = MODELS[args.model]
- model = model_class(args)
+ ModelClass = MODELS[args.model]
+ model = ModelClass(args, config=config, datastore=datastore)
- prefix = "subset-" if args.subset_ds else ""
if args.eval:
- prefix = prefix + f"eval-{args.eval}-"
+ prefix = f"eval-{args.eval}-"
+ else:
+ prefix = "train-"
run_name = (
f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"
f"{time.strftime('%m_%d_%H')}-{random_run_id:04d}"
@@ -279,7 +269,9 @@ def main(input_args=None):
save_last=True,
)
logger = pl.loggers.WandbLogger(
- project=args.wandb_project, name=run_name, config=args
+ project=args.wandb_project,
+ name=run_name,
+ config=dict(training=vars(args), datastore=datastore._config),
)
trainer = pl.Trainer(
max_epochs=args.epochs,
@@ -296,36 +288,12 @@ def main(input_args=None):
# Only init once, on rank 0 only
if trainer.global_rank == 0:
utils.init_wandb_metrics(
- logger, args.val_steps_to_log
+ logger, val_steps=args.val_steps_to_log
) # Do after wandb.init
-
if args.eval:
- if args.eval == "val":
- eval_loader = val_loader
- else: # Test
- eval_loader = torch.utils.data.DataLoader(
- WeatherDataset(
- config_loader.dataset.name,
- pred_length=max_pred_length,
- split="test",
- subsample_step=args.step_length,
- subset=args.subset_ds,
- ),
- args.batch_size,
- shuffle=False,
- num_workers=args.n_workers,
- )
-
- print(f"Running evaluation on {args.eval}")
- trainer.test(model=model, dataloaders=eval_loader, ckpt_path=args.load)
+ trainer.test(model=model, datamodule=data_module, ckpt_path=args.load)
else:
- # Train model
- trainer.fit(
- model=model,
- train_dataloaders=train_loader,
- val_dataloaders=val_loader,
- ckpt_path=args.load,
- )
+ trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load)
if __name__ == "__main__":
diff --git a/neural_lam/utils.py b/neural_lam/utils.py
index c47c44ff..4a0752e4 100644
--- a/neural_lam/utils.py
+++ b/neural_lam/utils.py
@@ -3,90 +3,11 @@
import shutil
# Third-party
-import numpy as np
import torch
from torch import nn
from tueplots import bundles, figsizes
-def load_dataset_stats(dataset_name, device="cpu"):
- """
- Load arrays with stored dataset statistics from pre-processing
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn),
- map_location=device,
- weights_only=True,
- )
-
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- flux_stats = loads_file("flux_stats.pt") # (2,)
- flux_mean, flux_std = flux_stats
-
- return {
- "data_mean": data_mean,
- "data_std": data_std,
- "flux_mean": flux_mean,
- "flux_std": flux_std,
- }
-
-
-def load_static_data(dataset_name, device="cpu"):
- """
- Load static files related to dataset
- """
- static_dir_path = os.path.join("data", dataset_name, "static")
-
- def loads_file(fn):
- return torch.load(
- os.path.join(static_dir_path, fn),
- map_location=device,
- weights_only=True,
- )
-
- # Load border mask, 1. if node is part of border, else 0.
- border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy"))
- border_mask = (
- torch.tensor(border_mask_np, dtype=torch.float32, device=device)
- .flatten(0, 1)
- .unsqueeze(1)
- ) # (N_grid, 1)
-
- grid_static_features = loads_file(
- "grid_features.pt"
- ) # (N_grid, d_grid_static)
-
- # Load step diff stats
- step_diff_mean = loads_file("diff_mean.pt") # (d_f,)
- step_diff_std = loads_file("diff_std.pt") # (d_f,)
-
- # Load parameter std for computing validation errors in original data scale
- data_mean = loads_file("parameter_mean.pt") # (d_features,)
- data_std = loads_file("parameter_std.pt") # (d_features,)
-
- # Load loss weighting vectors
- param_weights = torch.tensor(
- np.load(os.path.join(static_dir_path, "parameter_weights.npy")),
- dtype=torch.float32,
- device=device,
- ) # (d_f,)
-
- return {
- "border_mask": border_mask,
- "grid_static_features": grid_static_features,
- "step_diff_mean": step_diff_mean,
- "step_diff_std": step_diff_std,
- "data_mean": data_mean,
- "data_std": data_std,
- "param_weights": param_weights,
- }
-
-
class BufferList(nn.Module):
"""
A list of torch buffer tensors that sit together as a Module with no
@@ -112,12 +33,50 @@ def __iter__(self):
return (self[i] for i in range(len(self)))
-def load_graph(graph_name, device="cpu"):
- """
- Load all tensors representing the graph
+def load_graph(graph_dir_path, device="cpu"):
+ """Load all tensors representing the graph from `graph_dir_path`.
+
+ Needs the following files for all graphs:
+ - m2m_edge_index.pt
+ - g2m_edge_index.pt
+ - m2g_edge_index.pt
+ - m2m_features.pt
+ - g2m_features.pt
+ - m2g_features.pt
+ - mesh_features.pt
+
+ And in addition for hierarchical graphs:
+ - mesh_up_edge_index.pt
+ - mesh_down_edge_index.pt
+ - mesh_up_features.pt
+ - mesh_down_features.pt
+
+ Parameters
+ ----------
+ graph_dir_path : str
+ Path to directory containing the graph files.
+ device : str
+ Device to load tensors to.
+
+ Returns
+ -------
+ hierarchical : bool
+ Whether the graph is hierarchical.
+ graph : dict
+ Dictionary containing the graph tensors, with keys as follows:
+ - g2m_edge_index
+ - m2g_edge_index
+ - m2m_edge_index
+ - mesh_up_edge_index
+ - mesh_down_edge_index
+ - g2m_features
+ - m2g_features
+ - m2m_features
+ - mesh_up_features
+ - mesh_down_features
+ - mesh_static_features
+
"""
- # Define helper lambda function
- graph_dir_path = os.path.join("graphs", graph_name)
def loads_file(fn):
return torch.load(
@@ -137,7 +96,8 @@ def loads_file(fn):
hierarchical = n_levels > 1 # Nor just single level mesh graph
# Load static edge features
- m2m_features = loads_file("m2m_features.pt") # List of (M_m2m[l], d_edge_f)
+ # List of (M_m2m[l], d_edge_f)
+ m2m_features = loads_file("m2m_features.pt")
g2m_features = loads_file("g2m_features.pt") # (M_g2m, d_edge_f)
m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f)
@@ -259,9 +219,9 @@ def fractional_plot_bundle(fraction):
Get the tueplots bundle, but with figure width as a fraction of
the page width.
"""
- # If latex is not available, some visualizations might not render correctly,
- # but will at least not raise an error.
- # Alternatively, use unicode raised numbers.
+ # If latex is not available, some visualizations might not render
+ # correctly, but will at least not raise an error. Alternatively, use
+ # unicode raised numbers.
usetex = True if shutil.which("latex") else False
bundle = bundles.neurips2023(usetex=usetex, family="serif")
bundle.update(figsizes.neurips2023())
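# --- Editor's illustrative sketch (not part of this diff) ---
# `load_graph` now takes an explicit directory rather than a graph name;
# callers in this diff build the path from the datastore root, e.g. (the
# graph name "multiscale" is hypothetical):
#
#   graph_dir_path = datastore.root_path / "graph" / "multiscale"
#   hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
#   g2m_edge_index = graph_ldict["g2m_edge_index"]
# --- end of editor's sketch ---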
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index 2f22bef1..b9d18b39 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -5,10 +5,11 @@
# Local
from . import utils
+from .datastore.base import BaseRegularGridDatastore
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_error_map(errors, data_config, title=None, step_length=3):
+def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
@@ -16,6 +17,7 @@ def plot_error_map(errors, data_config, title=None, step_length=3):
"""
errors_np = errors.T.cpu().numpy() # (d_f, pred_steps)
d_f, pred_steps = errors_np.shape
+ step_length = datastore.step_length
# Normalize all errors to [0,1] for color map
max_errors = errors_np.max(axis=1) # d_f
@@ -48,11 +50,10 @@ def plot_error_map(errors, data_config, title=None, step_length=3):
ax.set_xlabel("Lead time (h)", size=label_size)
ax.set_yticks(np.arange(d_f))
+ var_names = datastore.get_vars_names(category="state")
+ var_units = datastore.get_vars_units(category="state")
y_ticklabels = [
- f"{name} ({unit})"
- for name, unit in zip(
- data_config.dataset.var_names, data_config.dataset.var_units
- )
+ f"{name} ({unit})" for name, unit in zip(var_names, var_units)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size)
@@ -64,11 +65,17 @@ def plot_error_map(errors, data_config, title=None, step_length=3):
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
- pred, target, obs_mask, data_config, title=None, vrange=None
+ pred,
+ target,
+ datastore: BaseRegularGridDatastore,
+ title=None,
+ vrange=None,
):
"""
     Plot example prediction and ground truth.
+
Each has shape (N_grid,)
+
"""
# Get common scale for values
if vrange is None:
@@ -77,8 +84,11 @@ def plot_prediction(
else:
vmin, vmax = vrange
+ extent = datastore.get_xy_extent("state")
+
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
+ da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+ mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
@@ -87,16 +97,21 @@ def plot_prediction(
1,
2,
figsize=(13, 7),
- subplot_kw={"projection": data_config.coords_projection},
+ subplot_kw={"projection": datastore.coords_projection},
)
# Plot pred and target
for ax, data in zip(axes, (target, pred)):
ax.coastlines() # Add coastline outlines
- data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
+ data_grid = (
+ data.reshape(list(datastore.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
+ )
im = ax.imshow(
data_grid,
origin="lower",
+ extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
@@ -116,7 +131,9 @@ def plot_prediction(
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
-def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
+def plot_spatial_error(
+ error, datastore: BaseRegularGridDatastore, title=None, vrange=None
+):
"""
Plot errors over spatial map
Error and obs_mask has shape (N_grid,)
@@ -128,23 +145,31 @@ def plot_spatial_error(error, obs_mask, data_config, title=None, vrange=None):
else:
vmin, vmax = vrange
+ extent = datastore.get_xy_extent("state")
+
# Set up masking of border region
- mask_reshaped = obs_mask.reshape(*data_config.grid_shape_state)
+ da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
+ mask_reshaped = da_mask.values
pixel_alpha = (
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region
fig, ax = plt.subplots(
figsize=(5, 4.8),
- subplot_kw={"projection": data_config.coords_projection},
+ subplot_kw={"projection": datastore.coords_projection},
)
ax.coastlines() # Add coastline outlines
- error_grid = error.reshape(*data_config.grid_shape_state).cpu().numpy()
+ error_grid = (
+ error.reshape(list(datastore.grid_shape_state.values.values()))
+ .cpu()
+ .numpy()
+ )
im = ax.imshow(
error_grid,
origin="lower",
+ extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 29977789..532e3c90 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -1,262 +1,695 @@
# Standard library
-import datetime as dt
-import glob
-import os
+import datetime
+import warnings
+from typing import Union
# Third-party
import numpy as np
+import pytorch_lightning as pl
import torch
+import xarray as xr
-# Local
-from . import utils
+# First-party
+from neural_lam.datastore.base import BaseDatastore
class WeatherDataset(torch.utils.data.Dataset):
- """
- For our dataset:
- N_t' = 65
- N_t = 65//subsample_step (= 21 for 3h steps)
- dim_y = 268
- dim_x = 238
- N_grid = 268x238 = 63784
- d_features = 17 (d_features' = 18)
- d_forcing = 5
+ """Dataset class for weather data.
+
+ This class loads and processes weather data from a given datastore.
+
+ Parameters
+ ----------
+ datastore : BaseDatastore
+ The datastore to load the data from (e.g. mdp).
+ split : str, optional
+ The data split to use ("train", "val" or "test"). Default is "train".
+ ar_steps : int, optional
+ The number of autoregressive steps. Default is 3.
+ num_past_forcing_steps: int, optional
+ Number of past time steps to include in forcing input. If set to i,
+ forcing from times t-i, t-i+1, ..., t-1, t (and potentially beyond,
+        given num_future_forcing_steps) are included as forcing inputs at
+        time t. Default is 1.
+ num_future_forcing_steps: int, optional
+ Number of future time steps to include in forcing input. If set to j,
+ forcing from times t, t+1, ..., t+j-1, t+j (and potentially times before
+ t, given num_past_forcing_steps) are included as forcing inputs at time
+ t. Default is 1.
+ standardize : bool, optional
+ Whether to standardize the data. Default is True.
"""
def __init__(
self,
- dataset_name,
- pred_length=19,
+ datastore: BaseDatastore,
split="train",
- subsample_step=3,
+ ar_steps=3,
+ num_past_forcing_steps=1,
+ num_future_forcing_steps=1,
standardize=True,
- subset=False,
- control_only=False,
):
super().__init__()
- assert split in ("train", "val", "test"), "Unknown dataset split"
- self.sample_dir_path = os.path.join(
- "data", dataset_name, "samples", split
- )
+ self.split = split
+ self.ar_steps = ar_steps
+ self.datastore = datastore
+ self.num_past_forcing_steps = num_past_forcing_steps
+ self.num_future_forcing_steps = num_future_forcing_steps
- member_file_regexp = (
- "nwp*mbr000.npy" if control_only else "nwp*mbr*.npy"
+ self.da_state = self.datastore.get_dataarray(
+ category="state", split=self.split
)
- sample_paths = glob.glob(
- os.path.join(self.sample_dir_path, member_file_regexp)
+ self.da_forcing = self.datastore.get_dataarray(
+ category="forcing", split=self.split
)
- self.sample_names = [path.split("/")[-1][4:-4] for path in sample_paths]
- # Now on form "yyymmddhh_mbrXXX"
- if subset:
- self.sample_names = self.sample_names[:50] # Limit to 50 samples
+        # check that, given the provided data-arrays and ar_steps, we have a
+        # non-zero number of samples
+ if self.__len__() <= 0:
+ raise ValueError(
+ "The provided datastore only provides "
+ f"{len(self.da_state.time)} total time steps, which is too few "
+ "to create a single sample for the WeatherDataset "
+ f"configuration used in the `{split}` split. You could try "
+ "either reducing the number of autoregressive steps "
+ "(`ar_steps`) and/or the forcing window size "
+ "(`num_past_forcing_steps` and `num_future_forcing_steps`)"
+ )
+
+ # Check the dimensions and their ordering
+ parts = dict(state=self.da_state)
+ if self.da_forcing is not None:
+ parts["forcing"] = self.da_forcing
- self.sample_length = pred_length + 2 # 2 init states
- self.subsample_step = subsample_step
- self.original_sample_length = (
- 65 // self.subsample_step
- ) # 21 for 3h steps
- assert (
- self.sample_length <= self.original_sample_length
- ), "Requesting too long time series samples"
+ for part, da in parts.items():
+ expected_dim_order = self.datastore.expected_dim_order(
+ category=part
+ )
+ if da.dims != expected_dim_order:
+ raise ValueError(
+ f"The dimension order of the `{part}` data ({da.dims}) "
+ f"does not match the expected dimension order "
+ f"({expected_dim_order}). Maybe you forgot to transpose "
+ "the data in `BaseDatastore.get_dataarray`?"
+ )
# Set up for standardization
+ # TODO: This will become part of ar_model.py soon!
self.standardize = standardize
if standardize:
- ds_stats = utils.load_dataset_stats(dataset_name, "cpu")
- self.data_mean, self.data_std, self.flux_mean, self.flux_std = (
- ds_stats["data_mean"],
- ds_stats["data_std"],
- ds_stats["flux_mean"],
- ds_stats["flux_std"],
+ self.ds_state_stats = self.datastore.get_standardization_dataarray(
+ category="state"
)
- # If subsample index should be sampled (only duing training)
- self.random_subsample = split == "train"
+ self.da_state_mean = self.ds_state_stats.state_mean
+ self.da_state_std = self.ds_state_stats.state_std
+
+ if self.da_forcing is not None:
+ self.ds_forcing_stats = (
+ self.datastore.get_standardization_dataarray(
+ category="forcing"
+ )
+ )
+ self.da_forcing_mean = self.ds_forcing_stats.forcing_mean
+ self.da_forcing_std = self.ds_forcing_stats.forcing_std
def __len__(self):
- return len(self.sample_names)
+ if self.datastore.is_forecast:
+ # for now we simply create a single sample for each analysis time
+ # and then take the first (2 + ar_steps) forecast times. In
+ # addition we only use the first ensemble member (if ensemble data
+ # has been provided).
+ # This means that for each analysis time we get a single sample
- def __getitem__(self, idx):
- # === Sample ===
- sample_name = self.sample_names[idx]
- sample_path = os.path.join(
- self.sample_dir_path, f"nwp_{sample_name}.npy"
+ if self.datastore.is_ensemble:
+ warnings.warn(
+ "only using first ensemble member, so dataset size is "
+ " effectively reduced by the number of ensemble members "
+ f"({self.da_state.ensemble_member.size})",
+ UserWarning,
+ )
+
+ # check that there are enough forecast steps available to create
+ # samples given the number of autoregressive steps requested
+ n_forecast_steps = self.da_state.elapsed_forecast_duration.size
+ if n_forecast_steps < 2 + self.ar_steps:
+ raise ValueError(
+ "The number of forecast steps available "
+ f"({n_forecast_steps}) is less than the required "
+ f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
+ "creating a sample with initial and target states."
+ )
+
+ return self.da_state.analysis_time.size
+ else:
+ # Calculate the number of samples in the dataset:
+ # n_samples = total time steps - (autoregressive steps + past forcing
+ # + future forcing)
+ #
+ # Where:
+ # - total time steps: len(self.da_state.time)
+ # - autoregressive steps: self.ar_steps
+ # - past forcing: max(2, self.num_past_forcing_steps) (at least 2
+ # time steps are required for the initial state)
+ # - future forcing: self.num_future_forcing_steps
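+ # For example, with 10 total time steps, ar_steps=3,
+ # num_past_forcing_steps=1 and num_future_forcing_steps=1 this gives
+ # 10 - 3 - max(2, 1) - 1 = 4 samples.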
+ return (
+ len(self.da_state.time)
+ - self.ar_steps
+ - max(2, self.num_past_forcing_steps)
+ - self.num_future_forcing_steps
+ )
+
+ def _slice_state_time(self, da_state, idx, n_steps: int):
+ """
+ Produce a time slice of the given dataarray `da_state` (state) starting
+ at `idx` and with `n_steps` steps. An `offset` is calculated based on the
+ `num_past_forcing_steps` class attribute. The offset is used to shift the
+ start of the sample, to ensure that enough previous time steps are
+ available for the 2 initial states and any corresponding forcings
+ (calculated in `_slice_forcing_time`).
+
+ Parameters
+ ----------
+ da_state : xr.DataArray
+ The dataarray to slice. This is expected to have a `time` dimension
+ if the datastore is providing analysis-only data, and
+ `analysis_time` and `elapsed_forecast_duration` dimensions if the
+ datastore is providing forecast data.
+ idx : int
+ The index of the time step to start the sample from.
+ n_steps : int
+ The number of time steps to include in the sample.
+
+ Returns
+ -------
+ da_sliced : xr.DataArray
+ The sliced dataarray with dims ('time', 'grid_index',
+ 'state_feature').
+ """
+ # The current implementation requires at least 2 time steps for the
+ # initial state (see GraphCast).
+ init_steps = 2
+ # slice the dataarray to include the required number of time steps
+ if self.datastore.is_forecast:
+ start_idx = max(0, self.num_past_forcing_steps - init_steps)
+ end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps
+ # this implies that the data will have both `analysis_time` and
+ # `elapsed_forecast_duration` dimensions for forecasts. We for now
+ simply select an analysis time and the first `n_steps` forecast
+ # times (given no offset). Note that this means that we get one
+ # sample per forecast, always starting at forecast time 2.
+ da_sliced = da_state.isel(
+ analysis_time=idx,
+ elapsed_forecast_duration=slice(start_idx, end_idx),
+ )
+ # create a new time dimension so that the produced sample has a
+ # `time` dimension, similarly to the analysis only data
+ da_sliced["time"] = (
+ da_sliced.analysis_time + da_sliced.elapsed_forecast_duration
+ )
+ da_sliced = da_sliced.swap_dims(
+ {"elapsed_forecast_duration": "time"}
+ )
+ else:
+ # For analysis data we slice the time dimension directly. The offset
+ # is only relevant for the very first (and last) samples in the
+ # dataset.
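+ # For example, with num_past_forcing_steps <= 2 the slice for idx=0 is
+ # time[0 : 2 + n_steps], i.e. the two initial states followed by the
+ # n_steps target states.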
+ start_idx = idx + max(0, self.num_past_forcing_steps - init_steps)
+ end_idx = (
+ idx + max(init_steps, self.num_past_forcing_steps) + n_steps
+ )
+ da_sliced = da_state.isel(time=slice(start_idx, end_idx))
+ return da_sliced
+
+ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
+ """
+ Produce a time slice of the given dataarray `da_forcing` (forcing)
+ starting at `idx` and with `n_steps` steps. An `offset` is calculated
+ based on the `num_past_forcing_steps` class attribute. It is used to
+ offset the start of the sample, to ensure that enough previous time
+ steps are available for the forcing data. The forcing data is windowed
+ around the current autoregressive time step to include the past and
+ future forcings.
+
+ Parameters
+ ----------
+ da_forcing : xr.DataArray
+ The forcing dataarray to slice. This is expected to have a `time`
+ dimension if the datastore is providing analysis-only data, and
+ `analysis_time` and `elapsed_forecast_duration` dimensions if the
+ datastore is providing forecast data.
+ idx : int
+ The index of the time step to start the sample from.
+ n_steps : int
+ The number of time steps to include in the sample.
+
+ Returns
+ -------
+ da_concat : xr.DataArray
+ The sliced dataarray with dims ('time', 'grid_index',
+ 'window', 'forcing_feature').
+ """
+ # The current implementation requires at least 2 time steps for the
+ # initial state (see GraphCast). The forcing data is windowed around the
+ # current autoregressive time step. The two `init_steps` can also be used
+ # as past forcings.
+ init_steps = 2
+ da_list = []
+
+ if self.datastore.is_forecast:
+ # This implies that the data will have both `analysis_time` and
+ # `elapsed_forecast_duration` dimensions for forecasts. We for now
+ # simply select an analysis time and the first `n_steps` forecast
+ # times (given no offset). Note that this means that we get one
+ # sample per forecast.
+ # Add a 'time' dimension using the actual forecast times
+ offset = max(init_steps, self.num_past_forcing_steps)
+ for step in range(n_steps):
+ start_idx = offset + step - self.num_past_forcing_steps
+ end_idx = offset + step + self.num_future_forcing_steps
+
+ current_time = (
+ da_forcing.analysis_time[idx]
+ + da_forcing.elapsed_forecast_duration[offset + step]
+ )
+
+ da_sliced = da_forcing.isel(
+ analysis_time=idx,
+ elapsed_forecast_duration=slice(start_idx, end_idx + 1),
+ )
+
+ da_sliced = da_sliced.rename(
+ {"elapsed_forecast_duration": "window"}
+ )
+
+ # Assign the 'window' coordinate to be relative positions
+ da_sliced = da_sliced.assign_coords(
+ window=np.arange(len(da_sliced.window))
+ )
+
+ da_sliced = da_sliced.expand_dims(
+ dim={"time": [current_time.values]}
+ )
+
+ da_list.append(da_sliced)
+
+ # Concatenate the list of DataArrays along the 'time' dimension
+ da_concat = xr.concat(da_list, dim="time")
+
+ else:
+ # For analysis data, we slice the time dimension directly. The
+ # offset is only relevant for the very first (and last) samples in
+ # the dataset.
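+ # Each window covers num_past_forcing_steps + 1 + num_future_forcing_steps
+ # time steps. For example, with idx=0, step=0 and one past and one
+ # future forcing step, the slice below is time[1:4], centred on the
+ # first target time.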
+ offset = idx + max(init_steps, self.num_past_forcing_steps)
+ for step in range(n_steps):
+ start_idx = offset + step - self.num_past_forcing_steps
+ end_idx = offset + step + self.num_future_forcing_steps
+
+ # Slice the data over the desired time window
+ da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1))
+
+ da_sliced = da_sliced.rename({"time": "window"})
+
+ # Assign the 'window' coordinate to be relative positions
+ da_sliced = da_sliced.assign_coords(
+ window=np.arange(len(da_sliced.window))
+ )
+
+ # Add a 'time' dimension to keep track of steps using actual
+ # time coordinates
+ current_time = da_forcing.time[offset + step]
+ da_sliced = da_sliced.expand_dims(
+ dim={"time": [current_time.values]}
+ )
+
+ da_list.append(da_sliced)
+
+ # Concatenate the list of DataArrays along the 'time' dimension
+ da_concat = xr.concat(da_list, dim="time")
+
+ return da_concat
+
+ def _build_item_dataarrays(self, idx):
+ """
+ Create the dataarrays for the initial states, target states and forcing
+ data for the sample at index `idx`.
+
+ Parameters
+ ----------
+ idx : int
+ The index of the sample to create the dataarrays for.
+
+ Returns
+ -------
+ da_init_states : xr.DataArray
+ The dataarray for the initial states.
+ da_target_states : xr.DataArray
+ The dataarray for the target states.
+ da_forcing_windowed : xr.DataArray
+ The dataarray for the forcing data, windowed for the sample.
+ da_target_times : xr.DataArray
+ The dataarray for the target times.
+ """
+ # handling ensemble data
+ if self.datastore.is_ensemble:
+ # for now the strategy is to only include the first ensemble
+ # member
+ # XXX: this could be changed to include all ensemble members by
+ # splitting `idx` into two parts, one for the analysis time and one
+ # for the ensemble member and then increasing self.__len__ to
+ # include all ensemble members
+ warnings.warn(
+ "only use of ensemble member 0 (the first member) is "
+ "implemented for ensemble data"
+ )
+ i_ensemble = 0
+ da_state = self.da_state.isel(ensemble_member=i_ensemble)
+ else:
+ da_state = self.da_state
+
+ if self.da_forcing is not None:
+ if "ensemble_member" in self.da_forcing.dims:
+ raise NotImplementedError(
+ "Ensemble member not yet supported for forcing data"
+ )
+ da_forcing = self.da_forcing
+ else:
+ da_forcing = None
+
+ # handle time sampling in a way that is compatible with both analysis
+ # and forecast data
+ da_state = self._slice_state_time(
+ da_state=da_state, idx=idx, n_steps=self.ar_steps
)
- try:
- full_sample = torch.tensor(
- np.load(sample_path), dtype=torch.float32
- ) # (N_t', dim_y, dim_x, d_features')
- except ValueError:
- print(f"Failed to load {sample_path}")
-
- # Only use every ss_step:th time step, sample which of ss_step
- # possible such time series
- if self.random_subsample:
- subsample_index = torch.randint(0, self.subsample_step, ()).item()
+ if da_forcing is not None:
+ da_forcing_windowed = self._slice_forcing_time(
+ da_forcing=da_forcing, idx=idx, n_steps=self.ar_steps
+ )
+
+ # load the data into memory
+ da_state.load()
+ if da_forcing is not None:
+ da_forcing_windowed.load()
+
+ da_init_states = da_state.isel(time=slice(0, 2))
+ da_target_states = da_state.isel(time=slice(2, None))
+ da_target_times = da_target_states.time
+
+ if self.standardize:
+ da_init_states = (
+ da_init_states - self.da_state_mean
+ ) / self.da_state_std
+ da_target_states = (
+ da_target_states - self.da_state_mean
+ ) / self.da_state_std
+
+ if da_forcing is not None:
+ # XXX: Here we implicitly assume that the last dimension of the
+ # forcing data is the forcing feature dimension. To standardize
+ # on `.device` we need a different implementation. (e.g. a
+ # tensor with repeated means and stds for each "windowed" time.)
+ da_forcing_windowed = (
+ da_forcing_windowed - self.da_forcing_mean
+ ) / self.da_forcing_std
+
+ if da_forcing is not None:
+ # stack the `forcing_feature` and `window` dimensions into a
+ # single `forcing_feature_windowed` dimension
+ da_forcing_windowed = da_forcing_windowed.stack(
+ forcing_feature_windowed=("forcing_feature", "window")
+ )
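+ # the stacked feature dimension has size n_forcing_features *
+ # (num_past_forcing_steps + 1 + num_future_forcing_steps)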
else:
- subsample_index = 0
- subsample_end_index = self.original_sample_length * self.subsample_step
- sample = full_sample[
- subsample_index : subsample_end_index : self.subsample_step
- ]
- # (N_t, dim_y, dim_x, d_features')
-
- # Remove feature 15, "z_height_above_ground"
- sample = torch.cat(
- (sample[:, :, :, :15], sample[:, :, :, 16:]), dim=3
- ) # (N_t, dim_y, dim_x, d_features)
-
- # Accumulate solar radiation instead of just subsampling
- rad_features = full_sample[:, :, :, 2:4] # (N_t', dim_y, dim_x, 2)
- # Accumulate for first time step
- init_accum_rad = torch.sum(
- rad_features[: (subsample_index + 1)], dim=0, keepdim=True
- ) # (1, dim_y, dim_x, 2)
- # Accumulate for rest of subsampled sequence
- in_subsample_len = (
- subsample_end_index - self.subsample_step + subsample_index + 1
+ # create an empty forcing dataarray with the right shape
+ da_forcing_windowed = xr.DataArray(
+ data=np.empty(
+ (self.ar_steps, da_state.grid_index.size, 0),
+ ),
+ dims=("time", "grid_index", "forcing_feature"),
+ coords={
+ "time": da_target_times,
+ "grid_index": da_state.grid_index,
+ "forcing_feature": [],
+ },
+ )
+
+ return (
+ da_init_states,
+ da_target_states,
+ da_forcing_windowed,
+ da_target_times,
)
- rad_features_in_subsample = rad_features[
- (subsample_index + 1) : in_subsample_len
- ] # (N_t*, dim_y, dim_x, 2), N_t* = (N_t-1)*ss_step
- _, dim_y, dim_x, _ = sample.shape
- rest_accum_rad = torch.sum(
- rad_features_in_subsample.view(
- self.original_sample_length - 1,
- self.subsample_step,
- dim_y,
- dim_x,
- 2,
- ),
- dim=1,
- ) # (N_t-1, dim_y, dim_x, 2)
- accum_rad = torch.cat(
- (init_accum_rad, rest_accum_rad), dim=0
- ) # (N_t, dim_y, dim_x, 2)
- # Replace in sample
- sample[:, :, :, 2:4] = accum_rad
-
- # Flatten spatial dim
- sample = sample.flatten(1, 2) # (N_t, N_grid, d_features)
-
- # Uniformly sample time id to start sample from
- init_id = torch.randint(
- 0, 1 + self.original_sample_length - self.sample_length, ()
+
+ def __getitem__(self, idx):
+ """
+ Return a single training sample, which consists of the initial states,
+ target states, forcing and batch times.
+
+ The implementation currently uses xarray.DataArray objects for the
+ standardization (scaling to mean 0.0 and standard deviation of 1.0) so
+ that we can make use of xarray's broadcasting capabilities. This makes
+ it possible to standardize with global means, but also, for example,
+ with means computed per grid point. This code will have to be replaced
+ if standardization is to be done on the GPU, to handle the different
+ shapes of the standardization statistics.
+
+ Parameters
+ ----------
+ idx : int
+ The index of the sample to return, this will refer to the time of
+ the initial state.
+
+ Returns
+ -------
+ init_states, target_states, forcing, target_times : tuple of torch.Tensor
+ A tuple containing the initial states, target states, forcing and
+ target times. The target times are the times of the target steps.
+
+ """
+ (
+ da_init_states,
+ da_target_states,
+ da_forcing_windowed,
+ da_target_times,
+ ) = self._build_item_dataarrays(idx=idx)
+
+ tensor_dtype = torch.float32
+
+ init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype)
+ target_states = torch.tensor(
+ da_target_states.values, dtype=tensor_dtype
)
- sample = sample[init_id : (init_id + self.sample_length)]
- # (sample_length, N_grid, d_features)
- if self.standardize:
- # Standardize sample
- sample = (sample - self.data_mean) / self.data_std
-
- # Split up sample in init. states and target states
- init_states = sample[:2] # (2, N_grid, d_features)
- target_states = sample[2:] # (sample_length-2, N_grid, d_features)
-
- # === Forcing features ===
- # Now batch-static features are just part of forcing,
- # repeated over temporal dimension
- # Load water coverage
- sample_datetime = sample_name[:10]
- water_path = os.path.join(
- self.sample_dir_path, f"wtr_{sample_datetime}.npy"
+ target_times = torch.tensor(
+ da_target_times.astype("datetime64[ns]").astype("int64").values,
+ dtype=torch.int64,
)
- water_cover_features = torch.tensor(
- np.load(water_path), dtype=torch.float32
- ).unsqueeze(
- -1
- ) # (dim_y, dim_x, 1)
- # Flatten
- water_cover_features = water_cover_features.flatten(0, 1) # (N_grid, 1)
- # Expand over temporal dimension
- water_cover_expanded = water_cover_features.unsqueeze(0).expand(
- self.sample_length - 2, -1, -1 # -2 as added on after windowing
- ) # (sample_len, N_grid, 1)
-
- # TOA flux
- flux_path = os.path.join(
- self.sample_dir_path,
- f"nwp_toa_downwelling_shortwave_flux_{sample_datetime}.npy",
+
+ forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype)
+
+ # init_states: (2, N_grid, d_features)
+ # target_states: (ar_steps, N_grid, d_features)
+ # forcing: (ar_steps, N_grid, d_windowed_forcing)
+ # target_times: (ar_steps,)
+
+ return init_states, target_states, forcing, target_times
+
+ def __iter__(self):
+ """
+ Convenience method to iterate over the dataset.
+
+ This isn't used by pytorch DataLoader which itself implements an
+ iterator that uses Dataset.__getitem__ and Dataset.__len__.
+
+ """
+ for i in range(len(self)):
+ yield self[i]
+
+ def create_dataarray_from_tensor(
+ self,
+ tensor: torch.Tensor,
+ time: Union[datetime.datetime, list[datetime.datetime]],
+ category: str,
+ ):
+ """
+ Construct a xarray.DataArray from a `pytorch.Tensor` with coordinates
+ for `grid_index`, `time` and `{category}_feature` matching the shape
+ and number of times provided and add the x/y coordinates from the
+ datastore.
+
+ The number of times provided is expected to match the shape of the
+ tensor. For a 2D tensor, the dimensions are assumed to be (grid_index,
+ {category}_feature) and only a single time should be provided. For a 3D
+ tensor, the dimensions are assumed to be (time, grid_index,
+ {category}_feature) and a list of times should be provided.
+
+ Parameters
+ ----------
+ tensor : torch.Tensor
+ The tensor to construct the DataArray from, this is assumed to have
+ the same dimension ordering as returned by the __getitem__ method
+ (i.e. time, grid_index, {category}_feature).
+ time : datetime.datetime or list[datetime.datetime]
+ The time or times of the tensor.
+ category : str
+ The category of the tensor, either "state", "forcing" or "static".
+
+ Returns
+ -------
+ da : xr.DataArray
+ The constructed DataArray.
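+
+ Example
+ -------
+ A sketch of the intended usage (mirroring the accompanying tests), where
+ `target_states` and `target_times` come from `__getitem__`:
+
+ >>> da = dataset.create_dataarray_from_tensor(
+ ...     tensor=target_states, time=target_times, category="state"
+ ... )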
+ """
+
+ def _is_listlike(obj):
+ # match list, tuple, numpy array
+ return hasattr(obj, "__iter__") and not isinstance(obj, str)
+
+ add_time_as_dim = False
+ if len(tensor.shape) == 2:
+ dims = ["grid_index", f"{category}_feature"]
+ if _is_listlike(time):
+ raise ValueError(
+ "Expected a single time for a 2D tensor with assumed "
+ "dimensions (grid_index, {category}_feature), but got "
+ f"{len(time)} times"
+ )
+ elif len(tensor.shape) == 3:
+ add_time_as_dim = True
+ dims = ["time", "grid_index", f"{category}_feature"]
+ if not _is_listlike(time):
+ raise ValueError(
+ "Expected a list of times for a 3D tensor with assumed "
+ "dimensions (time, grid_index, {category}_feature), but "
+ "got a single time"
+ )
+ else:
+ raise ValueError(
+ "Expected tensor to have 2 or 3 dimensions, but got "
+ f"{len(tensor.shape)}"
+ )
+
+ da_datastore_state = getattr(self, f"da_{category}")
+ da_grid_index = da_datastore_state.grid_index
+ da_category_feature = da_datastore_state[f"{category}_feature"]
+
+ coords = {
+ f"{category}_feature": da_category_feature,
+ "grid_index": da_grid_index,
+ }
+ if add_time_as_dim:
+ coords["time"] = time
+
+ da = xr.DataArray(
+ tensor.numpy(),
+ dims=dims,
+ coords=coords,
)
- flux = torch.tensor(np.load(flux_path), dtype=torch.float32).unsqueeze(
- -1
- ) # (N_t', dim_y, dim_x, 1)
- if self.standardize:
- flux = (flux - self.flux_mean) / self.flux_std
-
- # Flatten and subsample flux forcing
- flux = flux.flatten(1, 2) # (N_t, N_grid, 1)
- flux = flux[subsample_index :: self.subsample_step] # (N_t, N_grid, 1)
- flux = flux[
- init_id : (init_id + self.sample_length)
- ] # (sample_len, N_grid, 1)
-
- # Time of day and year
- dt_obj = dt.datetime.strptime(sample_datetime, "%Y%m%d%H")
- dt_obj = dt_obj + dt.timedelta(
- hours=2 + subsample_index
- ) # Offset for first index
- # Extract for initial step
- init_hour_in_day = dt_obj.hour
- start_of_year = dt.datetime(dt_obj.year, 1, 1)
- init_seconds_into_year = (dt_obj - start_of_year).total_seconds()
-
- # Add increments for all steps
- hour_inc = (
- torch.arange(self.sample_length) * self.subsample_step
- ) # (sample_len,)
- hour_of_day = (
- init_hour_in_day + hour_inc
- ) # (sample_len,), Can be > 24 but ok
- second_into_year = (
- init_seconds_into_year + hour_inc * 3600
- ) # (sample_len,)
- # can roll over to next year, ok because periodicity
-
- # Encode as sin/cos
- # ! Make this more flexible in a separate create_forcings.py script
- seconds_in_year = 365 * 24 * 3600
- hour_angle = (hour_of_day / 12) * torch.pi # (sample_len,)
- year_angle = (
- (second_into_year / seconds_in_year) * 2 * torch.pi
- ) # (sample_len,)
- datetime_forcing = torch.stack(
- (
- torch.sin(hour_angle),
- torch.cos(hour_angle),
- torch.sin(year_angle),
- torch.cos(year_angle),
- ),
- dim=1,
- ) # (N_t, 4)
- datetime_forcing = (datetime_forcing + 1) / 2 # Rescale to [0,1]
- datetime_forcing = datetime_forcing.unsqueeze(1).expand(
- -1, flux.shape[1], -1
- ) # (sample_len, N_grid, 4)
-
- # Put forcing features together
- forcing_features = torch.cat(
- (flux, datetime_forcing), dim=-1
- ) # (sample_len, N_grid, d_forcing)
-
- # Combine forcing over each window of 3 time steps
- forcing_windowed = torch.cat(
- (
- forcing_features[:-2],
- forcing_features[1:-1],
- forcing_features[2:],
- ),
- dim=2,
- ) # (sample_len-2, N_grid, 3*d_forcing)
- # Now index 0 of ^ corresponds to forcing at index 0-2 of sample
-
- # batch-static water cover is added after windowing,
- # as it is static over time
- forcing = torch.cat((water_cover_expanded, forcing_windowed), dim=2)
- # (sample_len-2, N_grid, forcing_dim)
-
- return init_states, target_states, forcing
+ for grid_coord in ["x", "y"]:
+ if (
+ grid_coord in da_datastore_state.coords
+ and grid_coord not in da.coords
+ ):
+ da.coords[grid_coord] = da_datastore_state[grid_coord]
+
+ if not add_time_as_dim:
+ da.coords["time"] = time
+
+ return da
+
+
+class WeatherDataModule(pl.LightningDataModule):
+ """DataModule for weather data."""
+
+ def __init__(
+ self,
+ datastore: BaseDatastore,
+ ar_steps_train=3,
+ ar_steps_eval=25,
+ standardize=True,
+ num_past_forcing_steps=1,
+ num_future_forcing_steps=1,
+ batch_size=4,
+ num_workers=16,
+ ):
+ super().__init__()
+ self._datastore = datastore
+ self.num_past_forcing_steps = num_past_forcing_steps
+ self.num_future_forcing_steps = num_future_forcing_steps
+ self.ar_steps_train = ar_steps_train
+ self.ar_steps_eval = ar_steps_eval
+ self.standardize = standardize
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.train_dataset = None
+ self.val_dataset = None
+ self.test_dataset = None
+ if num_workers > 0:
+ # default to spawn for now, as the default on linux "fork" hangs
+ # when using dask (which the npyfilesmeps datastore uses)
+ self.multiprocessing_context = "spawn"
+ else:
+ self.multiprocessing_context = None
+
+ def setup(self, stage=None):
+ if stage == "fit" or stage is None:
+ self.train_dataset = WeatherDataset(
+ datastore=self._datastore,
+ split="train",
+ ar_steps=self.ar_steps_train,
+ standardize=self.standardize,
+ num_past_forcing_steps=self.num_past_forcing_steps,
+ num_future_forcing_steps=self.num_future_forcing_steps,
+ )
+ self.val_dataset = WeatherDataset(
+ datastore=self._datastore,
+ split="val",
+ ar_steps=self.ar_steps_eval,
+ standardize=self.standardize,
+ num_past_forcing_steps=self.num_past_forcing_steps,
+ num_future_forcing_steps=self.num_future_forcing_steps,
+ )
+
+ if stage == "test" or stage is None:
+ self.test_dataset = WeatherDataset(
+ datastore=self._datastore,
+ split="test",
+ ar_steps=self.ar_steps_eval,
+ standardize=self.standardize,
+ num_past_forcing_steps=self.num_past_forcing_steps,
+ num_future_forcing_steps=self.num_future_forcing_steps,
+ )
+
+ def train_dataloader(self):
+ """Load train dataset."""
+ return torch.utils.data.DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=True,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
+ )
+
+ def val_dataloader(self):
+ """Load validation dataset."""
+ return torch.utils.data.DataLoader(
+ self.val_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
+ )
+
+ def test_dataloader(self):
+ """Load test dataset."""
+ return torch.utils.data.DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ shuffle=False,
+ multiprocessing_context=self.multiprocessing_context,
+ persistent_workers=True,
+ )
diff --git a/pyproject.toml b/pyproject.toml
index 14b7e69a..f0bc0851 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,9 +3,9 @@ name = "neural-lam"
version = "0.2.0"
description = "LAM-based data-driven forecasting"
authors = [
- {name = "Joel Oskarsson", email = "joel.oskarsson@liu.se"},
- {name = "Simon Adamov", email = "Simon.Adamov@meteoswiss.ch"},
- {name = "Leif Denby", email = "lcd@dmi.dk"},
+ { name = "Joel Oskarsson", email = "joel.oskarsson@liu.se" },
+ { name = "Simon Adamov", email = "Simon.Adamov@meteoswiss.ch" },
+ { name = "Leif Denby", email = "lcd@dmi.dk" },
]
# PEP 621 project metadata
@@ -24,15 +24,15 @@ dependencies = [
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
+ "parse>=1.20.2",
+ "dataclass-wizard>=0.22.3",
+ "mllam-data-prep>=0.5.0",
]
requires-python = ">=3.9"
[project.optional-dependencies]
-dev = [
- "pre-commit>=3.8.0",
- "pytest>=8.3.2",
- "pooch>=1.8.2",
-]
+dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"]
+
[tool.setuptools]
py-modules = ["neural_lam"]
@@ -59,6 +59,7 @@ known_first_party = [
# Add first-party modules that may be misclassified by isort
"neural_lam",
]
+line_length = 80
[tool.flake8]
max-line-length = 80
@@ -80,12 +81,9 @@ ignore = [
"create_mesh.py", # Disable linting for now, as major rework is planned/expected
]
# Temporary fix for import neural_lam statements until set up as proper package
-init-hook='import sys; sys.path.append(".")'
+init-hook = 'import sys; sys.path.append(".")'
[tool.pylint.TYPECHECK]
-generated-members = [
- "numpy.*",
- "torch.*",
-]
+generated-members = ["numpy.*", "torch.*"]
[tool.pylint.'MESSAGES CONTROL']
disable = [
"C0114", # 'missing-module-docstring', Do not require module docstrings
@@ -96,11 +94,11 @@ disable = [
"W0223", # 'abstract-method', Subclasses do not have to override all abstract methods
]
[tool.pylint.DESIGN]
-max-statements=100 # Allow for some more involved functions
+max-statements = 100 # Allow for some more involved functions
[tool.pylint.IMPORTS]
-allow-any-import-level="neural_lam"
+allow-any-import-level = "neural_lam"
[tool.pylint.SIMILARITIES]
-min-similarity-lines=10
+min-similarity-lines = 10
[tool.pdm]
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 00000000..6f579621
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,106 @@
+# Standard library
+import os
+from pathlib import Path
+
+# Third-party
+import pooch
+import yaml
+
+# First-party
+from neural_lam.datastore import DATASTORES, init_datastore
+from neural_lam.datastore.npyfilesmeps import (
+ compute_standardization_stats as compute_standardization_stats_meps,
+)
+
+# Local
+from .dummy_datastore import DummyDatastore
+
+# Disable weights and biases to avoid unnecessary logging
+# and to avoid having to deal with authentication
+os.environ["WANDB_DISABLED"] = "true"
+
+DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples")
+
+# Initializing variables for the s3 client
+S3_BUCKET_NAME = "mllam-testdata"
+S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
+S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.2.0.zip"
+S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
+TEST_DATA_KNOWN_HASH = (
+ "7ff2e07e04cfcd77631115f800c9d49188bb2a7c2a2777da3cea219f926d0c86"
+)
+
+
+def download_meps_example_reduced_dataset():
+ # Download and unzip test data into
+ # tests/datastore_examples/npyfilesmeps/meps_example_reduced
+ root_path = DATASTORE_EXAMPLES_ROOT_PATH / "npyfilesmeps"
+ dataset_path = root_path / "meps_example_reduced"
+
+ pooch.retrieve(
+ url=S3_FULL_PATH,
+ known_hash=TEST_DATA_KNOWN_HASH,
+ processor=pooch.Unzip(extract_dir=""),
+ path=root_path,
+ fname="meps_example_reduced.zip",
+ )
+
+ config_path = dataset_path / "meps_example_reduced.datastore.yaml"
+
+ with open(config_path, "r") as f:
+ config = yaml.safe_load(f)
+
+ if "class" in config["projection"]:
+ # XXX: should update the dataset stored on S3 with the change below
+ #
+ # rename the `projection.class` key to `projection.class_name` in the
+ # config. This is because `class` is a reserved keyword in Python and
+ # so can't be used as a field name when defining a python dataclass
+ config["projection"]["class_name"] = config["projection"].pop("class")
+
+ with open(config_path, "w") as f:
+ yaml.dump(config, f)
+
+ # create parameters, only run if the files we expect are not present
+ expected_parameter_files = [
+ "parameter_mean.pt",
+ "parameter_std.pt",
+ "diff_mean.pt",
+ "diff_std.pt",
+ ]
+ expected_parameter_filepaths = [
+ dataset_path / "static" / fn for fn in expected_parameter_files
+ ]
+ if any(not p.exists() for p in expected_parameter_filepaths):
+ compute_standardization_stats_meps.main(
+ datastore_config_path=config_path,
+ batch_size=8,
+ step_length=3,
+ n_workers=0,
+ distributed=False,
+ )
+
+ return config_path
+
+
+DATASTORES_EXAMPLES = dict(
+ mdp=(
+ DATASTORE_EXAMPLES_ROOT_PATH
+ / "mdp"
+ / "danra_100m_winds"
+ / "danra.datastore.yaml"
+ ),
+ npyfilesmeps=download_meps_example_reduced_dataset(),
+ dummydata=None,
+)
+
+DATASTORES[DummyDatastore.SHORT_NAME] = DummyDatastore
+
+
+def init_datastore_example(datastore_kind):
+ datastore = init_datastore(
+ datastore_kind=datastore_kind,
+ config_path=DATASTORES_EXAMPLES[datastore_kind],
+ )
+
+ return datastore
diff --git a/tests/datastore_examples/.gitignore b/tests/datastore_examples/.gitignore
new file mode 100644
index 00000000..e84e6493
--- /dev/null
+++ b/tests/datastore_examples/.gitignore
@@ -0,0 +1,2 @@
+npyfilesmeps/*.zip
+npyfilesmeps/meps_example_reduced/
diff --git a/tests/datastore_examples/mdp/danra_100m_winds/.gitignore b/tests/datastore_examples/mdp/danra_100m_winds/.gitignore
new file mode 100644
index 00000000..f2828f46
--- /dev/null
+++ b/tests/datastore_examples/mdp/danra_100m_winds/.gitignore
@@ -0,0 +1,2 @@
+*.zarr/
+graph/
diff --git a/tests/datastore_examples/mdp/danra_100m_winds/config.yaml b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
new file mode 100644
index 00000000..0bb5c5ec
--- /dev/null
+++ b/tests/datastore_examples/mdp/danra_100m_winds/config.yaml
@@ -0,0 +1,9 @@
+datastore:
+ kind: mdp
+ config_path: danra.datastore.yaml
+training:
+ state_feature_weighting:
+ __config_class__: ManualStateFeatureWeighting
+ weights:
+ u100m: 1.0
+ v100m: 1.0
diff --git a/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml
new file mode 100644
index 00000000..3edf1267
--- /dev/null
+++ b/tests/datastore_examples/mdp/danra_100m_winds/danra.datastore.yaml
@@ -0,0 +1,99 @@
+schema_version: v0.5.0
+dataset_version: v0.1.0
+
+output:
+ variables:
+ static: [grid_index, static_feature]
+ state: [time, grid_index, state_feature]
+ forcing: [time, grid_index, forcing_feature]
+ coord_ranges:
+ time:
+ start: 1990-09-03T00:00
+ end: 1990-09-09T00:00
+ step: PT3H
+ chunking:
+ time: 1
+ splitting:
+ dim: time
+ splits:
+ train:
+ start: 1990-09-03T00:00
+ end: 1990-09-06T00:00
+ compute_statistics:
+ ops: [mean, std, diff_mean, diff_std]
+ dims: [grid_index, time]
+ val:
+ start: 1990-09-06T00:00
+ end: 1990-09-07T00:00
+ test:
+ start: 1990-09-07T00:00
+ end: 1990-09-09T00:00
+
+inputs:
+ danra_height_levels:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/height_levels.zarr
+ dims: [time, x, y, altitude]
+ variables:
+ u:
+ altitude:
+ values: [100,]
+ units: m
+ v:
+ altitude:
+ values: [100, ]
+ units: m
+ dim_mapping:
+ time:
+ method: rename
+ dim: time
+ state_feature:
+ method: stack_variables_by_var_name
+ dims: [altitude]
+ name_format: "{var_name}{altitude}m"
+ grid_index:
+ method: stack
+ dims: [x, y]
+ target_output_variable: state
+
+ danra_surface:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/single_levels.zarr
+ dims: [time, x, y]
+ variables:
+ # use surface incoming shortwave radiation as forcing
+ - swavr0m
+ dim_mapping:
+ time:
+ method: rename
+ dim: time
+ grid_index:
+ method: stack
+ dims: [x, y]
+ forcing_feature:
+ method: stack_variables_by_var_name
+ name_format: "{var_name}"
+ target_output_variable: forcing
+
+ danra_lsm:
+ path: https://mllam-test-data.s3.eu-north-1.amazonaws.com/lsm.zarr
+ dims: [x, y]
+ variables:
+ - lsm
+ dim_mapping:
+ grid_index:
+ method: stack
+ dims: [x, y]
+ static_feature:
+ method: stack_variables_by_var_name
+ name_format: "{var_name}"
+ target_output_variable: static
+
+extra:
+ projection:
+ class_name: LambertConformal
+ kwargs:
+ central_longitude: 25.0
+ central_latitude: 56.7
+ standard_parallels: [56.7, 56.7]
+ globe:
+ semimajor_axis: 6367470.0
+ semiminor_axis: 6367470.0
diff --git a/tests/dummy_datastore.py b/tests/dummy_datastore.py
new file mode 100644
index 00000000..9075d404
--- /dev/null
+++ b/tests/dummy_datastore.py
@@ -0,0 +1,449 @@
+# Standard library
+import datetime
+import tempfile
+from functools import cached_property
+from pathlib import Path
+from typing import List, Union
+
+# Third-party
+import isodate
+import numpy as np
+import xarray as xr
+from cartopy import crs as ccrs
+from numpy import ndarray
+
+# First-party
+from neural_lam.datastore.base import (
+ BaseRegularGridDatastore,
+ CartesianGridShape,
+)
+
+
+class DummyDatastore(BaseRegularGridDatastore):
+ """
+ Datastore that creates some dummy data for testing purposes. The data
+ consists of state, forcing, and static variables, and is stored in a
+ regular grid (using Lambert Azimuthal Equal Area projection). The domain
+ is centered on Denmark and has a size of 500x500 km.
+ """
+
+ SHORT_NAME = "dummydata"
+ T0 = isodate.parse_datetime("2021-01-01T00:00:00")
+ N_FEATURES = dict(state=5, forcing=2, static=1)
+ CARTESIAN_COORDS = ["x", "y"]
+
+ # center the domain on Denmark
+ latlon_center = [56, 10] # latitude, longitude
+ bbox_size_km = [500, 500] # km
+
+ def __init__(
+ self, config_path=None, n_grid_points=10000, n_timesteps=10
+ ) -> None:
+ """
+ Create a dummy datastore with random data.
+
+ Parameters
+ ----------
+ config_path : None
+ No config file is needed for the dummy datastore. This argument is
+ only present to match the signature of the other datastores.
+ n_grid_points : int
+ The number of grid points in the dataset. Must be a perfect square.
+ n_timesteps : int
+ The number of timesteps in the dataset.
+ """
+ assert (
+ config_path is None
+ ), "No config file is needed for the dummy datastore"
+
+ # Ensure n_grid_points is a perfect square
+ n_points_1d = int(np.sqrt(n_grid_points))
+ assert (
+ n_points_1d * n_points_1d == n_grid_points
+ ), "n_grid_points must be a perfect square"
+
+ # create equal area grid
+ lx, ly = self.bbox_size_km
+ x = np.linspace(-lx / 2.0 * 1.0e3, lx / 2.0 * 1.0e3, n_points_1d)
+ y = np.linspace(-ly / 2.0 * 1.0e3, ly / 2.0 * 1.0e3, n_points_1d)
+
+ xs, ys = np.meshgrid(x, y)
+
+ # Create lat/lon coordinates using equal area projection
+ lon_mesh, lat_mesh = (
+ ccrs.PlateCarree()
+ .transform_points(
+ src_crs=self.coords_projection,
+ x=xs.flatten(),
+ y=ys.flatten(),
+ )[:, :2]
+ .T
+ )
+
+ # Create base dataset with proper coordinates
+ self.ds = xr.Dataset(
+ coords={
+ "x": (
+ "x",
+ x,
+ {"units": "m"},
+ ), # 1D x coordinates in metres
+ "y": (
+ "y",
+ y,
+ {"units": "m"},
+ ), # 1D y coordinates in metres
+ "longitude": (
+ "grid_index",
+ lon_mesh.flatten(),
+ {"units": "degrees_east"},
+ ),
+ "latitude": (
+ "grid_index",
+ lat_mesh.flatten(),
+ {"units": "degrees_north"},
+ ),
+ }
+ )
+ # Create data variables with proper dimensions
+ for category, n in self.N_FEATURES.items():
+ feature_names = [f"{category}_feat_{i}" for i in range(n)]
+ feature_units = ["-" for _ in range(n)] # Placeholder units
+ feature_long_names = [
+ f"Long name for {name}" for name in feature_names
+ ]
+
+ self.ds[f"{category}_feature"] = feature_names
+ self.ds[f"{category}_feature_units"] = (
+ f"{category}_feature",
+ feature_units,
+ )
+ self.ds[f"{category}_feature_long_name"] = (
+ f"{category}_feature",
+ feature_long_names,
+ )
+
+ # Define dimensions and create random data
+ dims = ["grid_index", f"{category}_feature"]
+ if category != "static":
+ dims.append("time")
+ shape = (n_grid_points, n, n_timesteps)
+ else:
+ shape = (n_grid_points, n)
+
+ # Create random data
+ data = np.random.randn(*shape)
+
+ # Create DataArray with proper dimensions
+ self.ds[category] = xr.DataArray(
+ data,
+ dims=dims,
+ coords={
+ f"{category}_feature": feature_names,
+ },
+ )
+
+ if category != "static":
+ dt = datetime.timedelta(hours=self.step_length)
+ times = [self.T0 + dt * i for i in range(n_timesteps)]
+ self.ds.coords["time"] = times
+
+ # Add boundary mask
+ self.ds["boundary_mask"] = xr.DataArray(
+ np.random.choice([0, 1], size=(n_points_1d, n_points_1d)),
+ dims=["x", "y"],
+ )
+
+ # Stack the spatial dimensions into grid_index
+ self.ds = self.ds.stack(grid_index=self.CARTESIAN_COORDS)
+
+ # Create temporary directory for storing derived files
+ self._tempdir = tempfile.TemporaryDirectory()
+ self._root_path = Path(self._tempdir.name)
+ self._num_grid_points = n_grid_points
+
+ @property
+ def root_path(self) -> Path:
+ """
+ The root path to the datastore. It is relative to this that any derived
+ files (for example the graph components) are stored.
+
+ Returns
+ -------
+ pathlib.Path
+ The root path to the datastore.
+
+ """
+ return self._root_path
+
+ @property
+ def config(self) -> dict:
+ """The configuration of the datastore.
+
+ Returns
+ -------
+ collections.abc.Mapping
+ The configuration of the datastore, any dict like object can be
+ returned.
+
+ """
+ return {}
+
+ @property
+ def step_length(self) -> int:
+ """The step length of the dataset in hours.
+
+ Returns
+ -------
+ int
+ The step length in hours.
+
+ """
+ return 1
+
+ def get_vars_names(self, category: str) -> list[str]:
+ """Get the names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The names of the variables.
+
+ """
+ return self.ds[f"{category}_feature"].values.tolist()
+
+ def get_vars_units(self, category: str) -> list[str]:
+ """Get the units of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The units of the variables.
+
+ """
+ return self.ds[f"{category}_feature_units"].values.tolist()
+
+ def get_vars_long_names(self, category: str) -> List[str]:
+ """Get the long names of the variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ List[str]
+ The long names of the variables.
+
+ """
+ return self.ds[f"{category}_feature_long_name"].values.tolist()
+
+ def get_num_data_vars(self, category: str) -> int:
+ """Get the number of data variables in the given category.
+
+ Parameters
+ ----------
+ category : str
+ The category of the variables (state/forcing/static).
+
+ Returns
+ -------
+ int
+ The number of data variables.
+
+ """
+ return self.ds[f"{category}_feature"].size
+
+ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
+ """
+ Return the standardization (i.e. scaling to mean of 0.0 and standard
+ deviation of 1.0) dataarray for the given category. This should contain
+ a `{category}_mean` and `{category}_std` variable for each variable in
+ the category. For `category=="state"`, the dataarray should also
+ contain a `state_diff_mean` and `state_diff_std` variable for the one-
+ step differences of the state variables. The returned dataarray should
+ at least have dimensions of `({category}_feature)`, but can also
+ include for example `grid_index` (if the standardization is done per
+ grid point for example).
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+
+ Returns
+ -------
+ xr.Dataset
+ The standardization dataarray for the given category, with variables
+ for the mean and standard deviation of the variables (and
+ differences for state variables).
+
+ """
+ ds_standardization = xr.Dataset()
+
+ ops = ["mean", "std"]
+ if category == "state":
+ ops += ["diff_mean", "diff_std"]
+
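+ # fill every statistic with 1.0; the exact values are arbitrary for the
+ # random dummy data and only need to have the right shape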
+ for op in ops:
+ da_op = xr.ones_like(self.ds[f"{category}_feature"]).astype(float)
+ ds_standardization[f"{category}_{op}"] = da_op
+
+ return ds_standardization
+
+ def get_dataarray(
+ self, category: str, split: str
+ ) -> Union[xr.DataArray, None]:
+ """
+ Return the processed data (as a single `xr.DataArray`) for the given
+ category of data and test/train/val-split that covers all the data (in
+ space and time) of a given category (state/forcing/static). A
+ datastore must be able to return data for the "state" category, but
+ "forcing" and "static" are optional (in which case the method should
+ return `None`). For the "static" category the `split` is allowed to be
+ `None` because the static data is the same for all splits.
+
+ The returned dataarray is expected to at minimum have dimensions of
+ `(grid_index, {category}_feature)` so that any spatial dimensions have
+ been stacked into a single dimension and all variables and levels have
+ been stacked into a single feature dimension named by the `category` of
+ data being loaded.
+
+ For categories of data that have a time dimension (i.e. not static
+ data), the dataarray is additionally expected to have `(analysis_time,
+ elapsed_forecast_duration)` dimensions if `is_forecast` is True, or
+ `(time)` if `is_forecast` is False.
+
+ If the data is ensemble data, the dataarray is expected to have an
+ additional `ensemble_member` dimension.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ split : str
+ The time split to filter the dataset (train/val/test).
+
+ Returns
+ -------
+ xr.DataArray or None
+ The xarray DataArray object with processed dataset.
+
+ """
+ dim_order = self.expected_dim_order(category=category)
+ return self.ds[category].transpose(*dim_order)
+
+ @cached_property
+ def boundary_mask(self) -> xr.DataArray:
+ """
+ Return the boundary mask for the dataset, with spatial dimensions
+ stacked. Where the value is 1, the grid point is a boundary point, and
+ where the value is 0, the grid point is not a boundary point.
+
+ Returns
+ -------
+ xr.DataArray
+ The boundary mask for the dataset, with dimensions
+ `('grid_index',)`.
+
+ """
+ return self.ds["boundary_mask"]
+
+ def get_xy(self, category: str, stacked: bool) -> ndarray:
+ """Return the x, y coordinates of the dataset.
+
+ Parameters
+ ----------
+ category : str
+ The category of the dataset (state/forcing/static).
+ stacked : bool
+ Whether to stack the x, y coordinates.
+
+ Returns
+ -------
+ np.ndarray
+ The x, y coordinates of the dataset, returned differently based on
+ the value of `stacked`:
+ - `stacked==True`: shape `(n_grid_points, 2)` where
+ n_grid_points=N_x*N_y.
+ - `stacked==False`: shape `(N_x, N_y, 2)`
+
+ """
+ # assume variables are stored in dimensions [grid_index, ...]
+ ds_category = self.unstack_grid_coords(da_or_ds=self.ds[category])
+
+ da_xs = ds_category.x
+ da_ys = ds_category.y
+
+ assert da_xs.ndim == da_ys.ndim == 1, "x and y coordinates must be 1D"
+
+ da_x, da_y = xr.broadcast(da_xs, da_ys)
+ da_xy = xr.concat([da_x, da_y], dim="grid_coord")
+
+ if stacked:
+ da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose(
+ "grid_index",
+ "grid_coord",
+ )
+ else:
+ dims = [
+ "x",
+ "y",
+ "grid_coord",
+ ]
+ da_xy = da_xy.transpose(*dims)
+
+ return da_xy.values
+
+ @property
+ def coords_projection(self) -> ccrs.Projection:
+ """Return the projection object for the coordinates.
+
+ The projection object is used to plot the coordinates on a map.
+
+ Returns
+ -------
+ cartopy.crs.Projection:
+ The projection object.
+
+ """
+ # make a projection centered on Denmark
+ lat_center, lon_center = self.latlon_center
+ return ccrs.LambertAzimuthalEqualArea(
+ central_latitude=lat_center, central_longitude=lon_center
+ )
+
+ @property
+ def num_grid_points(self) -> int:
+ """Return the number of grid points in the dataset.
+
+ Returns
+ -------
+ int
+ The number of grid points in the dataset.
+
+ """
+ return self._num_grid_points
+
+ @cached_property
+ def grid_shape_state(self) -> CartesianGridShape:
+ """The shape of the grid for the state variables.
+
+ Returns
+ -------
+ CartesianGridShape:
+ The shape of the grid for the state variables, which has `x` and
+ `y` attributes.
+ """
+
+ n_points_1d = int(np.sqrt(self.num_grid_points))
+ return CartesianGridShape(x=n_points_1d, y=n_points_1d)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index e90daa04..0dbd04a1 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,18 +1,12 @@
# First-party
import neural_lam
-import neural_lam.create_grid_features
-import neural_lam.create_mesh
-import neural_lam.create_parameter_weights
+import neural_lam.create_graph
import neural_lam.train_model
def test_import():
- """
- This test just ensures that each cli entry-point can be imported for now,
- eventually we should test their execution too
- """
+ """This test just ensures that each cli entry-point can be imported for now,
+ eventually we should test their execution too."""
assert neural_lam is not None
- assert neural_lam.create_mesh is not None
- assert neural_lam.create_grid_features is not None
- assert neural_lam.create_parameter_weights is not None
+ assert neural_lam.create_graph is not None
assert neural_lam.train_model is not None
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 00000000..1ff40bc6
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,72 @@
+# Third-party
+import pytest
+
+# First-party
+import neural_lam.config as nlconfig
+
+
+@pytest.mark.parametrize(
+ "state_weighting_config",
+ [
+ nlconfig.ManualStateFeatureWeighting(
+ weights=dict(u100m=1.0, v100m=0.5)
+ ),
+ nlconfig.UniformFeatureWeighting(),
+ ],
+)
+def test_config_serialization(state_weighting_config):
+ c = nlconfig.NeuralLAMConfig(
+ datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""),
+ training=nlconfig.TrainingConfig(
+ state_feature_weighting=state_weighting_config
+ ),
+ )
+
+ assert c == c.from_json(c.to_json())
+ assert c == c.from_yaml(c.to_yaml())
+
+
+yaml_training_defaults = """
+datastore:
+ kind: mdp
+ config_path: ""
+"""
+
+default_config = nlconfig.NeuralLAMConfig(
+ datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""),
+ training=nlconfig.TrainingConfig(
+ state_feature_weighting=nlconfig.UniformFeatureWeighting()
+ ),
+)
+
+yaml_training_manual_weights = """
+datastore:
+ kind: mdp
+ config_path: ""
+training:
+ state_feature_weighting:
+ __config_class__: ManualStateFeatureWeighting
+ weights:
+ u100m: 1.0
+ v100m: 1.0
+"""
+
+manual_weights_config = nlconfig.NeuralLAMConfig(
+ datastore=nlconfig.DatastoreSelection(kind="mdp", config_path=""),
+ training=nlconfig.TrainingConfig(
+ state_feature_weighting=nlconfig.ManualStateFeatureWeighting(
+ weights=dict(u100m=1.0, v100m=1.0)
+ )
+ ),
+)
+
+yaml_samples = zip(
+ [yaml_training_defaults, yaml_training_manual_weights],
+ [default_config, manual_weights_config],
+)
+
+
+@pytest.mark.parametrize("yaml_str, config_expected", yaml_samples)
+def test_config_load_from_yaml(yaml_str, config_expected):
+ c = nlconfig.NeuralLAMConfig.from_yaml(yaml_str)
+ assert c == config_expected
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
new file mode 100644
index 00000000..419aece0
--- /dev/null
+++ b/tests/test_datasets.py
@@ -0,0 +1,261 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import numpy as np
+import pytest
+import torch
+from torch.utils.data import DataLoader
+
+# First-party
+from neural_lam import config as nlconfig
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
+from neural_lam.datastore.base import BaseRegularGridDatastore
+from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.weather_dataset import WeatherDataset
+from tests.conftest import init_datastore_example
+from tests.dummy_datastore import DummyDatastore
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataset_item_shapes(datastore_name):
+ """Check that the `datastore.get_dataarray` method is implemented.
+
+ Validate the shapes of the tensors match between the different
+ components of the training sample.
+
+ init_states: (2, N_grid, d_features)
+ target_states: (ar_steps, N_grid, d_features)
+ forcing: (ar_steps, N_grid, d_windowed_forcing)
+ target_times: (ar_steps,)
+
+ """
+ datastore = init_datastore_example(datastore_name)
+ N_gridpoints = datastore.num_grid_points
+
+ N_pred_steps = 4
+ num_past_forcing_steps = 1
+ num_future_forcing_steps = 1
+ dataset = WeatherDataset(
+ datastore=datastore,
+ split="train",
+ ar_steps=N_pred_steps,
+ num_past_forcing_steps=num_past_forcing_steps,
+ num_future_forcing_steps=num_future_forcing_steps,
+ )
+
+ item = dataset[0]
+
+ # unpack the item, this is the current return signature for
+ # WeatherDataset.__getitem__
+ init_states, target_states, forcing, target_times = item
+
+ # initial states
+ assert init_states.ndim == 3
+ assert init_states.shape[0] == 2 # two time steps go into the input
+ assert init_states.shape[1] == N_gridpoints
+ assert init_states.shape[2] == datastore.get_num_data_vars("state")
+
+ # output states
+ assert target_states.ndim == 3
+ assert target_states.shape[0] == N_pred_steps
+ assert target_states.shape[1] == N_gridpoints
+ assert target_states.shape[2] == datastore.get_num_data_vars("state")
+
+ # forcing
+ assert forcing.ndim == 3
+ assert forcing.shape[0] == N_pred_steps
+ assert forcing.shape[1] == N_gridpoints
+ assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * (
+ num_past_forcing_steps + num_future_forcing_steps + 1
+ )
+
+ # batch times
+ assert target_times.ndim == 1
+ assert target_times.shape[0] == N_pred_steps
+
+ # try to get the last item of the dataset to ensure slicing and stacking
+ # operations are working as expected and are consistent with the dataset
+ # length
+ dataset[len(dataset) - 1]
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataset_item_create_dataarray_from_tensor(datastore_name):
+ datastore = init_datastore_example(datastore_name)
+
+ N_pred_steps = 4
+ num_past_forcing_steps = 1
+ num_future_forcing_steps = 1
+ dataset = WeatherDataset(
+ datastore=datastore,
+ split="train",
+ ar_steps=N_pred_steps,
+ num_past_forcing_steps=num_past_forcing_steps,
+ num_future_forcing_steps=num_future_forcing_steps,
+ )
+
+ idx = 0
+
+ # unpack the item, this is the current return signature for
+ # WeatherDataset.__getitem__
+ _, target_states, _, target_times_arr = dataset[idx]
+ _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays(
+ idx=idx
+ )
+
+ target_times = np.array(target_times_arr, dtype="datetime64[ns]")
+ np.testing.assert_equal(target_times, da_target_times_true.values)
+
+ da_target = dataset.create_dataarray_from_tensor(
+ tensor=target_states, category="state", time=target_times
+ )
+
+ # conversion to torch.float32 may lead to loss of precision
+ np.testing.assert_allclose(
+ da_target.values, da_target_true.values, rtol=1e-6
+ )
+ assert da_target.dims == da_target_true.dims
+ for dim in da_target.dims:
+ np.testing.assert_equal(
+ da_target[dim].values, da_target_true[dim].values
+ )
+
+ if isinstance(datastore, BaseRegularGridDatastore):
+ # test unstacking the grid coordinates
+ da_target_unstacked = datastore.unstack_grid_coords(da_target)
+ assert all(
+ coord_name in da_target_unstacked.coords
+ for coord_name in ["x", "y"]
+ )
+
+ # check construction of a single time
+ da_target_single = dataset.create_dataarray_from_tensor(
+ tensor=target_states[0], category="state", time=target_times[0]
+ )
+
+ # check that the content is the same
+ # conversion to torch.float32 may lead to loss of precision
+ np.testing.assert_allclose(
+ da_target_single.values, da_target_true[0].values, rtol=1e-6
+ )
+ assert da_target_single.dims == da_target_true[0].dims
+ for dim in da_target_single.dims:
+ np.testing.assert_equal(
+ da_target_single[dim].values, da_target_true[0][dim].values
+ )
+
+ if isinstance(datastore, BaseRegularGridDatastore):
+ # test unstacking the grid coordinates
+ da_target_single_unstacked = datastore.unstack_grid_coords(
+ da_target_single
+ )
+ assert all(
+ coord_name in da_target_single_unstacked.coords
+ for coord_name in ["x", "y"]
+ )
+
+
+@pytest.mark.parametrize("split", ["train", "val", "test"])
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_single_batch(datastore_name, split):
+ """Check that the `datastore.get_dataarray` method is implemented.
+
+ And that it returns an xarray DataArray with the correct dimensions.
+
+ """
+ datastore = init_datastore_example(datastore_name)
+
+ device_name = (
+ torch.device("cuda") if torch.cuda.is_available() else "cpu"
+ ) # noqa
+
+ graph_name = "1level"
+
+ class ModelArgs:
+ output_std = False
+ loss = "mse"
+ restore_opt = False
+ n_example_pred = 1
+ graph = graph_name
+ hidden_dim = 4
+ hidden_layers = 1
+ processor_layers = 2
+ mesh_aggr = "sum"
+ num_past_forcing_steps = 1
+ num_future_forcing_steps = 1
+
+ args = ModelArgs()
+
+ graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+ def _create_graph():
+ if not graph_dir_path.exists():
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ n_max_levels=1,
+ )
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+ with pytest.raises(NotImplementedError):
+ _create_graph()
+ pytest.skip("Skipping on model-run on non-regular grid datastores")
+
+ _create_graph()
+
+ config = nlconfig.NeuralLAMConfig(
+ datastore=nlconfig.DatastoreSelection(
+ kind=datastore.SHORT_NAME, config_path=datastore.root_path
+ )
+ )
+
+ dataset = WeatherDataset(datastore=datastore, split=split, ar_steps=2)
+
+ model = GraphLAM(args=args, datastore=datastore, config=config) # noqa
+
+ model_device = model.to(device_name)
+ data_loader = DataLoader(dataset, batch_size=2)
+ batch = next(iter(data_loader))
+ batch_device = [part.to(device_name) for part in batch]
+ model_device.common_step(batch_device)
+ model_device.training_step(batch_device)
+
+
+@pytest.mark.parametrize(
+ "dataset_config",
+ [
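+ # exp_len_reduction = ar_steps + max(2, past) + future, matching the
+ # length calculation in WeatherDataset.__len__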
+ {"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
+ {"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
+ {"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 5},
+ {"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 5},
+ {"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 7},
+ {"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 8},
+ ],
+)
+def test_dataset_length(dataset_config):
+ """Check that correct number of samples can be extracted from the dataset,
+ given a specific configuration of forcing windowing and ar_steps.
+ """
+ # Use a dummy datastore of length 10 here; we only want to test the
+ # slicing in the dataset class
+ ds_len = 10
+ datastore = DummyDatastore(n_timesteps=ds_len)
+
+ dataset = WeatherDataset(
+ datastore=datastore,
+ split="train",
+ ar_steps=dataset_config["ar_steps"],
+ num_past_forcing_steps=dataset_config["past"],
+ num_future_forcing_steps=dataset_config["future"],
+ )
+
+ # We expect dataset to contain this many samples
+ expected_len = ds_len - dataset_config["exp_len_reduction"]
+
+ # Check that the dataset has the correct length
+ assert len(dataset) == expected_len
+
+ # Check that we can actually get last and first sample
+ dataset[0]
+ dataset[expected_len - 1]
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
new file mode 100644
index 00000000..4a4b1100
--- /dev/null
+++ b/tests/test_datastores.py
@@ -0,0 +1,384 @@
+"""List of methods and attributes that should be implemented in a subclass of
+`` (these are all decorated with `@abc.abstractmethod`):
+
+- `root_path` (property): Root path of the datastore.
+- `step_length` (property): Length of the time step in hours.
+- `grid_shape_state` (property): Shape of the grid for the state variables.
+- `get_xy` (method): Return the x, y coordinates of the dataset.
+- `coords_projection` (property): Projection object for the coordinates.
+- `get_vars_units` (method): Get the units of the variables in the given
+ category.
+- `get_vars_names` (method): Get the names of the variables in the given
+ category.
+- `get_vars_long_names` (method): Get the long names of the variables in
+ the given category.
+- `get_num_data_vars` (method): Get the number of data variables in the
+ given category.
+- `get_normalization_dataarray` (method): Return the normalization
+ dataarray for the given category.
+- `get_dataarray` (method): Return the processed data (as a single
+ `xr.DataArray`) for the given category and test/train/val-split.
+- `boundary_mask` (property): Return the boundary mask for the dataset,
+ with spatial dimensions stacked.
+- `config` (property): Return the configuration of the datastore.
+
+In addition, `BaseRegularGridDatastore` must have the following methods and
+attributes:
+- `get_xy_extent` (method): Return the extent of the x, y coordinates for a
+ given category of data.
+- `get_xy` (method): Return the x, y coordinates of the dataset.
+- `coords_projection` (property): Projection object for the coordinates.
+- `grid_shape_state` (property): Shape of the grid for the state variables.
+- `stack_grid_coords` (method): Stack the grid coordinates of the dataset.
+
+"""
+
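+# For orientation, a minimal and purely illustrative sketch of such a
+# subclass (the class and member bodies below are hypothetical; the real
+# contract is defined in `neural_lam.datastore.base`):
+#
+#     class MyDatastore(BaseDatastore):
+#         @property
+#         def root_path(self) -> Path: ...
+#
+#         @property
+#         def step_length(self) -> int: ...
+#
+#         def get_dataarray(self, category, split) -> xr.DataArray: ...
+#
+#         # ... remaining abstract members listed above
+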
+# Standard library
+import collections
+import dataclasses
+from pathlib import Path
+
+# Third-party
+import cartopy.crs as ccrs
+import numpy as np
+import pytest
+import torch
+import xarray as xr
+
+# First-party
+from neural_lam.datastore import DATASTORES
+from neural_lam.datastore.base import BaseRegularGridDatastore
+from neural_lam.datastore.plot_example import plot_example_from_datastore
+from tests.conftest import init_datastore_example
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_root_path(datastore_name):
+ """Check that the `datastore.root_path` property is implemented."""
+ datastore = init_datastore_example(datastore_name)
+ assert isinstance(datastore.root_path, Path)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_config(datastore_name):
+ """Check that the `datastore.config` property is implemented."""
+ datastore = init_datastore_example(datastore_name)
+ # check the config is a mapping or a dataclass
+ config = datastore.config
+ assert isinstance(
+ config, collections.abc.Mapping
+ ) or dataclasses.is_dataclass(config)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_step_length(datastore_name):
+ """Check that the `datastore.step_length` property is implemented."""
+ datastore = init_datastore_example(datastore_name)
+ step_length = datastore.step_length
+ assert isinstance(step_length, int)
+ assert step_length > 0
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_datastore_grid_xy(datastore_name):
+ """Use the `datastore.get_xy` method to get the x, y coordinates of the
+    dataset and check that the shape is correct against the
+    `datastore.grid_shape_state` property."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+ pytest.skip(
+ "Skip grid_shape_state test for non-regular grid datastores"
+ )
+
+ # check the shapes of the xy grid
+ grid_shape = datastore.grid_shape_state
+ nx, ny = grid_shape.x, grid_shape.y
+ for stacked in [True, False]:
+ xy = datastore.get_xy("static", stacked=stacked)
+ if stacked:
+ assert xy.shape == (nx * ny, 2)
+ else:
+ assert xy.shape == (nx, ny, 2)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_vars(datastore_name):
+ """Check that results of.
+
+ - `datastore.get_vars_units`
+ - `datastore.get_vars_names`
+ - `datastore.get_vars_long_names`
+ - `datastore.get_num_data_vars`
+
+    are consistent (i.e. that the number of variables is the same) and that
+    the return types of each are correct.
+
+ """
+ datastore = init_datastore_example(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ units = datastore.get_vars_units(category)
+ names = datastore.get_vars_names(category)
+ long_names = datastore.get_vars_long_names(category)
+ num_vars = datastore.get_num_data_vars(category)
+
+ assert len(units) == len(names) == num_vars
+ assert isinstance(units, list)
+ assert isinstance(names, list)
+ assert isinstance(long_names, list)
+ assert isinstance(num_vars, int)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_normalization_dataarray(datastore_name):
+ """Check that the `datastore.get_normalization_dataa rray` method is
+ implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ ds_stats = datastore.get_standardization_dataarray(category=category)
+
+        # check that the returned object is an xarray Dataset
+        # and that it has the correct variables
+ assert isinstance(ds_stats, xr.Dataset)
+
+ if category == "state":
+ ops = ["mean", "std", "diff_mean", "diff_std"]
+ elif category == "forcing":
+ ops = ["mean", "std"]
+ elif category == "static":
+ ops = []
+ else:
+ raise NotImplementedError(category)
+
+ for op in ops:
+ var_name = f"{category}_{op}"
+ assert var_name in ds_stats.data_vars
+ da_val = ds_stats[var_name]
+ assert set(da_val.dims) == {f"{category}_feature"}
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_dataarray(datastore_name):
+ """Check that the `datastore.get_dataarray` method is implemented.
+
+ And that it returns an xarray DataArray with the correct dimensions.
+
+ """
+
+ datastore = init_datastore_example(datastore_name)
+
+ for category in ["state", "forcing", "static"]:
+ n_features = {}
+ if category in ["state", "forcing"]:
+ splits = ["train", "val", "test"]
+ elif category == "static":
+ # static data should be the same for all splits, so split
+ # should be allowed to be None
+ splits = ["train", "val", "test", None]
+ else:
+ raise NotImplementedError(category)
+
+ for split in splits:
+ expected_dims = ["grid_index", f"{category}_feature"]
+ if category != "static":
+ if not datastore.is_forecast:
+ expected_dims.append("time")
+ else:
+ expected_dims += [
+ "analysis_time",
+ "elapsed_forecast_duration",
+ ]
+
+ if datastore.is_ensemble and category == "state":
+ # assume that only state variables change with ensemble members
+ expected_dims.append("ensemble_member")
+
+ # XXX: for now we only have a single attribute to get the shape of
+ # the grid which uses the shape from the "state" category, maybe
+ # this should change?
+
+ da = datastore.get_dataarray(category=category, split=split)
+
+ assert isinstance(da, xr.DataArray)
+ assert set(da.dims) == set(expected_dims)
+ if isinstance(datastore, BaseRegularGridDatastore):
+ grid_shape = datastore.grid_shape_state
+ assert da.grid_index.size == grid_shape.x * grid_shape.y
+
+ n_features[split] = da[category + "_feature"].size
+
+ # check that the number of features is the same for all splits
+ assert n_features["train"] == n_features["val"] == n_features["test"]
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_boundary_mask(datastore_name):
+ """Check that the `datastore.boundary_mask` property is implemented and
+ that the returned object is an xarray DataArray with the correct shape."""
+ datastore = init_datastore_example(datastore_name)
+ da_mask = datastore.boundary_mask
+
+ assert isinstance(da_mask, xr.DataArray)
+ assert set(da_mask.dims) == {"grid_index"}
+ assert da_mask.dtype == "int"
+ assert set(da_mask.values) == {0, 1}
+ assert da_mask.sum() > 0
+ assert da_mask.sum() < da_mask.size
+
+ if isinstance(datastore, BaseRegularGridDatastore):
+ grid_shape = datastore.grid_shape_state
+ assert datastore.boundary_mask.size == grid_shape.x * grid_shape.y
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_xy_extent(datastore_name):
+ """Check that the `datastore.get_xy_extent` method is implemented and that
+ the returned object is a tuple of the correct length."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+        pytest.skip("Datastore does not implement `BaseRegularGridDatastore`")
+
+ extents = {}
+    # get the extents for each category and check that they are all the same
+ for category in ["state", "forcing", "static"]:
+ extent = datastore.get_xy_extent(category)
+ assert isinstance(extent, list)
+ assert len(extent) == 4
+ assert all(isinstance(e, (int, float)) for e in extent)
+ extents[category] = extent
+
+ # check that the extents are the same for all categories
+ for category in ["forcing", "static"]:
+ assert extents["state"] == extents[category]
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_xy(datastore_name):
+ """Check that the `datastore.get_xy` method is implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+        pytest.skip("Datastore does not implement `BaseRegularGridDatastore`")
+
+ for category in ["state", "forcing", "static"]:
+ xy_stacked = datastore.get_xy(category=category, stacked=True)
+ xy_unstacked = datastore.get_xy(category=category, stacked=False)
+
+ assert isinstance(xy_stacked, np.ndarray)
+ assert isinstance(xy_unstacked, np.ndarray)
+
+ nx, ny = datastore.grid_shape_state.x, datastore.grid_shape_state.y
+
+ # for stacked=True, the shape should be (n_grid_points, 2)
+ assert xy_stacked.ndim == 2
+ assert xy_stacked.shape[0] == nx * ny
+ assert xy_stacked.shape[1] == 2
+
+ # for stacked=False, the shape should be (nx, ny, 2)
+ assert xy_unstacked.ndim == 3
+ assert xy_unstacked.shape[0] == nx
+ assert xy_unstacked.shape[1] == ny
+ assert xy_unstacked.shape[2] == 2
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_get_projection(datastore_name):
+ """Check that the `datastore.coords_projection` property is implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+        pytest.skip("Datastore does not implement `BaseRegularGridDatastore`")
+
+ assert isinstance(datastore.coords_projection, ccrs.Projection)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_grid_shape_state(datastore_name):
+ """Check that the `datastore.grid_shape_state` property is implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+        pytest.skip("Datastore does not implement `BaseRegularGridDatastore`")
+
+ grid_shape = datastore.grid_shape_state
+ assert isinstance(grid_shape, tuple)
+ assert len(grid_shape) == 2
+ assert all(isinstance(e, int) for e in grid_shape)
+ assert all(e > 0 for e in grid_shape)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+@pytest.mark.parametrize("category", ["state", "forcing", "static"])
+def test_stacking_grid_coords(datastore_name, category):
+ """Check that the `datastore.stack_grid_coords` method is implemented."""
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+        pytest.skip("Datastore does not implement `BaseRegularGridDatastore`")
+
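+    # round-trip check: unstacking and then re-stacking the grid coordinates
+    # should give back an identical dataarray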
+ da_static = datastore.get_dataarray(category=category, split="train")
+
+ da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
+ da_static_test = datastore.stack_grid_coords(da_static_unstacked)
+
+ assert da_static.dims == da_static_test.dims
+ xr.testing.assert_equal(da_static, da_static_test)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataarray_shapes(datastore_name):
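+    """Check that unstacking the grid coordinates of a stacked static
+    dataarray gives the same 2D field as reshaping the stacked values to
+    shape (grid_shape_state.x, grid_shape_state.y)."""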
+ datastore = init_datastore_example(datastore_name)
+ static_da = datastore.get_dataarray("static", split=None)
+ static_da = datastore.stack_grid_coords(static_da)
+ static_da = static_da.isel(static_feature=0)
+
+ # Convert the unstacked grid coordinates and static data array to tensors
+ unstacked_tensor = torch.tensor(
+ datastore.unstack_grid_coords(static_da).to_numpy(), dtype=torch.float32
+ ).squeeze()
+
+ reshaped_tensor = (
+ torch.tensor(static_da.to_numpy(), dtype=torch.float32)
+ .reshape(datastore.grid_shape_state.x, datastore.grid_shape_state.y)
+ .squeeze()
+ )
+
+ # Compute the difference
+ diff = unstacked_tensor - reshaped_tensor
+
+ # Check the shapes
+ assert unstacked_tensor.shape == (
+ datastore.grid_shape_state.x,
+ datastore.grid_shape_state.y,
+ )
+ assert reshaped_tensor.shape == (
+ datastore.grid_shape_state.x,
+ datastore.grid_shape_state.y,
+ )
+ assert diff.shape == (
+ datastore.grid_shape_state.x,
+ datastore.grid_shape_state.y,
+ )
+ # assert diff == 0 with tolerance 1e-6
+ assert torch.allclose(diff, torch.zeros_like(diff), atol=1e-6)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_plot_example_from_datastore(datastore_name):
+ """Check that the `plot_example_from_datastore` function is implemented."""
+ datastore = init_datastore_example(datastore_name)
+ fig = plot_example_from_datastore(
+ category="static",
+ datastore=datastore,
+ col_dim="{category}_feature",
+ split="train",
+ standardize=True,
+ selection={},
+ index_selection={},
+ )
+
+ assert fig is not None
+ assert fig.get_axes()
diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py
new file mode 100644
index 00000000..93a7a55f
--- /dev/null
+++ b/tests/test_graph_creation.py
@@ -0,0 +1,119 @@
+# Standard library
+import tempfile
+from pathlib import Path
+
+# Third-party
+import pytest
+import torch
+
+# First-party
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
+from neural_lam.datastore.base import BaseRegularGridDatastore
+from tests.conftest import init_datastore_example
+
+
+@pytest.mark.parametrize("graph_name", ["1level", "multiscale", "hierarchical"])
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_graph_creation(datastore_name, graph_name):
+ """Check that the `create_ graph_from_datastore` function is implemented.
+
+ And that the graph is created in the correct location.
+
+ """
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+ pytest.skip(
+ f"Skipping test for {datastore_name} as it is not a regular "
+ "grid datastore."
+ )
+
+ if graph_name == "hierarchical":
+ hierarchical = True
+ n_max_levels = 3
+ elif graph_name == "multiscale":
+ hierarchical = False
+ n_max_levels = 3
+ elif graph_name == "1level":
+ hierarchical = False
+ n_max_levels = 1
+ else:
+ raise ValueError(f"Unknown graph_name: {graph_name}")
+
+ required_graph_files = [
+ "m2m_edge_index.pt",
+ "g2m_edge_index.pt",
+ "m2g_edge_index.pt",
+ "m2m_features.pt",
+ "g2m_features.pt",
+ "m2g_features.pt",
+ "mesh_features.pt",
+ ]
+ if hierarchical:
+ required_graph_files.extend(
+ [
+ "mesh_up_edge_index.pt",
+ "mesh_down_edge_index.pt",
+ "mesh_up_features.pt",
+ "mesh_down_features.pt",
+ ]
+ )
+
+ # TODO: check that the number of edges is consistent over the files, for
+ # now we just check the number of features
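+    # expected feature dimensionalities checked below: 3 features per graph
+    # edge and 2 static features per mesh node (presumably the mesh x/y
+    # coordinates)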
+ d_features = 3
+ d_mesh_static = 2
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ graph_dir_path = Path(tmpdir) / "graph" / graph_name
+
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ hierarchical=hierarchical,
+ n_max_levels=n_max_levels,
+ )
+
+ assert graph_dir_path.exists()
+
+ # check that all the required files are present
+ for file_name in required_graph_files:
+ assert (graph_dir_path / file_name).exists()
+
+ # try to load each and ensure they have the right shape
+ for file_name in required_graph_files:
+ file_id = Path(file_name).stem # remove the extension
+ result = torch.load(graph_dir_path / file_name)
+
+ if file_id.startswith("g2m") or file_id.startswith("m2g"):
+ assert isinstance(result, torch.Tensor)
+
+ if file_id.endswith("_index"):
+ assert (
+ result.shape[0] == 2
+ ) # adjacency matrix uses two rows
+ elif file_id.endswith("_features"):
+ assert result.shape[1] == d_features
+
+ elif file_id.startswith("m2m") or file_id.startswith("mesh"):
+ assert isinstance(result, list)
+ if not hierarchical:
+ assert len(result) == 1
+ else:
+ if file_id.startswith("mesh_up") or file_id.startswith(
+ "mesh_down"
+ ):
+ assert len(result) == n_max_levels - 1
+ else:
+ assert len(result) == n_max_levels
+
+ for r in result:
+ assert isinstance(r, torch.Tensor)
+
+ if file_id == "mesh_features":
+ assert r.shape[1] == d_mesh_static
+ elif file_id.endswith("_index"):
+ assert r.shape[0] == 2 # adjacency matrix uses two rows
+ elif file_id.endswith("_features"):
+ assert r.shape[1] == d_features
diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py
deleted file mode 100644
index 5c8b7aa1..00000000
--- a/tests/test_mllam_dataset.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Standard library
-import os
-from pathlib import Path
-
-# Third-party
-import pooch
-import pytest
-
-# First-party
-from neural_lam.config import Config
-from neural_lam.create_mesh import main as create_mesh
-from neural_lam.train_model import main as train_model
-from neural_lam.utils import load_static_data
-from neural_lam.weather_dataset import WeatherDataset
-
-# Disable weights and biases to avoid unnecessary logging
-# and to avoid having to deal with authentication
-os.environ["WANDB_DISABLED"] = "true"
-
-# Initializing variables for the s3 client
-S3_BUCKET_NAME = "mllam-testdata"
-S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
-S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
-S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
-TEST_DATA_KNOWN_HASH = (
- "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7"
-)
-
-
-@pytest.fixture(scope="module")
-def meps_example_reduced_filepath():
- # Download and unzip test data into data/meps_example_reduced
- pooch.retrieve(
- url=S3_FULL_PATH,
- known_hash=TEST_DATA_KNOWN_HASH,
- processor=pooch.Unzip(extract_dir=""),
- path="data",
- fname="meps_example_reduced.zip",
- )
- return Path("data/meps_example_reduced")
-
-
-def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
- # The data_config.yaml file is downloaded and extracted in
- # test_retrieve_data_ewc together with the dataset itself
- data_config_file = meps_example_reduced_filepath / "data_config.yaml"
- dataset_name = meps_example_reduced_filepath.name
-
- dataset = WeatherDataset(dataset_name=dataset_name)
- config = Config.from_file(str(data_config_file))
-
- var_names = config.values["dataset"]["var_names"]
- var_units = config.values["dataset"]["var_units"]
- var_longnames = config.values["dataset"]["var_longnames"]
-
- assert len(var_names) == len(var_longnames)
- assert len(var_names) == len(var_units)
-
- # in future the number of grid static features
- # will be provided by the Dataset class itself
- n_grid_static_features = 4
- # Hardcoded in model
- n_input_steps = 2
-
- n_forcing_features = config.values["dataset"]["num_forcing_features"]
- n_state_features = len(var_names)
- n_prediction_timesteps = dataset.sample_length - n_input_steps
-
- nx, ny = config.values["grid_shape_state"]
- n_grid = nx * ny
-
- # check that the dataset is not empty
- assert len(dataset) > 0
-
- # get the first item
- init_states, target_states, forcing = dataset[0]
-
- # check that the shapes of the tensors are correct
- assert init_states.shape == (n_input_steps, n_grid, n_state_features)
- assert target_states.shape == (
- n_prediction_timesteps,
- n_grid,
- n_state_features,
- )
- assert forcing.shape == (
- n_prediction_timesteps,
- n_grid,
- n_forcing_features,
- )
-
- static_data = load_static_data(dataset_name=dataset_name)
-
- required_props = {
- "border_mask",
- "grid_static_features",
- "step_diff_mean",
- "step_diff_std",
- "data_mean",
- "data_std",
- "param_weights",
- }
-
- # check the sizes of the props
- assert static_data["border_mask"].shape == (n_grid, 1)
- assert static_data["grid_static_features"].shape == (
- n_grid,
- n_grid_static_features,
- )
- assert static_data["step_diff_mean"].shape == (n_state_features,)
- assert static_data["step_diff_std"].shape == (n_state_features,)
- assert static_data["data_mean"].shape == (n_state_features,)
- assert static_data["data_std"].shape == (n_state_features,)
- assert static_data["param_weights"].shape == (n_state_features,)
-
- assert set(static_data.keys()) == required_props
-
-
-def test_create_graph_reduced_meps_dataset():
- args = [
- "--graph=hierarchical",
- "--hierarchical",
- "--data_config=data/meps_example_reduced/data_config.yaml",
- "--levels=2",
- ]
- create_mesh(args)
-
-
-def test_train_model_reduced_meps_dataset():
- args = [
- "--model=hi_lam",
- "--data_config=data/meps_example_reduced/data_config.yaml",
- "--n_workers=4",
- "--epochs=1",
- "--graph=hierarchical",
- "--hidden_dim=16",
- "--hidden_layers=1",
- "--processor_layers=1",
- "--ar_steps=1",
- "--eval=val",
- "--n_example_pred=0",
- ]
- train_model(args)
diff --git a/tests/test_time_slicing.py b/tests/test_time_slicing.py
new file mode 100644
index 00000000..29161505
--- /dev/null
+++ b/tests/test_time_slicing.py
@@ -0,0 +1,146 @@
+# Third-party
+import numpy as np
+import pytest
+import xarray as xr
+
+# First-party
+from neural_lam.datastore.base import BaseDatastore
+from neural_lam.weather_dataset import WeatherDataset
+
+
+class SinglePointDummyDatastore(BaseDatastore):
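+    """Dummy datastore with a single grid point and a single variable per
+    category, used for testing the time slicing logic in `WeatherDataset`."""
+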
+ step_length = 1
+ config = None
+ coords_projection = None
+ num_grid_points = 1
+ root_path = None
+
+ def __init__(self, time_values, state_data, forcing_data, is_forecast):
+ self._time_values = np.array(time_values)
+ self._state_data = np.array(state_data)
+ self._forcing_data = np.array(forcing_data)
+ self.is_forecast = is_forecast
+
+ if is_forecast:
+ assert self._state_data.ndim == 2
+ else:
+ assert self._state_data.ndim == 1
+
+ def get_num_data_vars(self, category):
+ return 1
+
+ def get_dataarray(self, category, split):
+ if category == "state":
+ values = self._state_data
+ elif category == "forcing":
+ values = self._forcing_data
+ else:
+ raise NotImplementedError(category)
+
+ if self.is_forecast:
+ raise NotImplementedError()
+ else:
+ da = xr.DataArray(
+ values, dims=["time"], coords={"time": self._time_values}
+ )
+ # add `{category}_feature` and `grid_index` dimensions
+
+ da = da.expand_dims("grid_index")
+ da = da.expand_dims(f"{category}_feature")
+
+ dim_order = self.expected_dim_order(category=category)
+ return da.transpose(*dim_order)
+
+ def get_standardization_dataarray(self, category):
+ raise NotImplementedError()
+
+ def get_xy(self, category):
+ raise NotImplementedError()
+
+ def get_vars_units(self, category):
+ raise NotImplementedError()
+
+ def get_vars_names(self, category):
+ raise NotImplementedError()
+
+ def get_vars_long_names(self, category):
+ raise NotImplementedError()
+
+
+ANALYSIS_STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+
+
+@pytest.mark.parametrize(
+ "ar_steps,num_past_forcing_steps,num_future_forcing_steps",
+ [[3, 0, 0], [3, 1, 0], [3, 2, 0], [3, 3, 0]],
+)
+def test_time_slicing_analysis(
+ ar_steps, num_past_forcing_steps, num_future_forcing_steps
+):
+    # state and forcing variables have only one dimension, `time`
+ time_values = np.datetime64("2020-01-01") + np.arange(
+ len(ANALYSIS_STATE_VALUES)
+ )
+ assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values)
+
+ datastore = SinglePointDummyDatastore(
+ state_data=ANALYSIS_STATE_VALUES,
+ forcing_data=FORCING_VALUES,
+ time_values=time_values,
+ is_forecast=False,
+ )
+
+ dataset = WeatherDataset(
+ datastore=datastore,
+ ar_steps=ar_steps,
+ num_future_forcing_steps=num_future_forcing_steps,
+ num_past_forcing_steps=num_past_forcing_steps,
+ standardize=False,
+ )
+
+ sample = dataset[0]
+
+ init_states, target_states, forcing, _ = [
+ tensor.numpy() for tensor in sample
+ ]
+
+ expected_init_states = [0, 1]
+ if ar_steps == 3:
+ expected_target_states = [2, 3, 4]
+ else:
+ raise NotImplementedError()
+
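+    # The forcing window for a target step t is assumed to span
+    # t - num_past_forcing_steps .. t + num_future_forcing_steps (inclusive),
+    # giving 1 + past + future values per step. With 3 past forcing steps the
+    # first usable target index moves from 2 to 3, which is why the init and
+    # target states shift by one step in that case.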
+ if num_past_forcing_steps == num_future_forcing_steps == 0:
+ expected_forcing_values = [[12], [13], [14]]
+ elif num_past_forcing_steps == 1 and num_future_forcing_steps == 0:
+ expected_forcing_values = [[11, 12], [12, 13], [13, 14]]
+ elif num_past_forcing_steps == 2 and num_future_forcing_steps == 0:
+ expected_forcing_values = [[10, 11, 12], [11, 12, 13], [12, 13, 14]]
+ elif num_past_forcing_steps == 3 and num_future_forcing_steps == 0:
+ expected_init_states = [1, 2]
+ expected_target_states = [3, 4, 5]
+ expected_forcing_values = [
+ [10, 11, 12, 13],
+ [11, 12, 13, 14],
+ [12, 13, 14, 15],
+ ]
+ else:
+ raise NotImplementedError()
+
+ # init_states: (2, N_grid, d_features)
+ # target_states: (ar_steps, N_grid, d_features)
+ # forcing: (ar_steps, N_grid, d_windowed_forcing)
+ # target_times: (ar_steps,)
+ assert init_states.shape == (2, 1, 1)
+ assert init_states[:, 0, 0].tolist() == expected_init_states
+
+ assert target_states.shape == (3, 1, 1)
+ assert target_states[:, 0, 0].tolist() == expected_target_states
+
+ assert forcing.shape == (
+ 3,
+ 1,
+ 1 + num_past_forcing_steps + num_future_forcing_steps,
+ )
+ np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values))
diff --git a/tests/test_training.py b/tests/test_training.py
new file mode 100644
index 00000000..1ed1847d
--- /dev/null
+++ b/tests/test_training.py
@@ -0,0 +1,103 @@
+# Standard library
+from pathlib import Path
+
+# Third-party
+import pytest
+import pytorch_lightning as pl
+import torch
+import wandb
+
+# First-party
+from neural_lam import config as nlconfig
+from neural_lam.create_graph import create_graph_from_datastore
+from neural_lam.datastore import DATASTORES
+from neural_lam.datastore.base import BaseRegularGridDatastore
+from neural_lam.models.graph_lam import GraphLAM
+from neural_lam.weather_dataset import WeatherDataModule
+from tests.conftest import init_datastore_example
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_training(datastore_name):
+ datastore = init_datastore_example(datastore_name)
+
+ if not isinstance(datastore, BaseRegularGridDatastore):
+ pytest.skip(
+ f"Skipping test for {datastore_name} as it is not a regular "
+ "grid datastore."
+ )
+
+ if torch.cuda.is_available():
+ device_name = "cuda"
+ torch.set_float32_matmul_precision(
+ "high"
+ ) # Allows using Tensor Cores on A100s
+ else:
+ device_name = "cpu"
+
+ trainer = pl.Trainer(
+ max_epochs=1,
+ deterministic=True,
+ accelerator=device_name,
+ # XXX: `devices` has to be set to 2 otherwise
+ # neural_lam.models.ar_model.ARModel.aggregate_and_plot_metrics fails
+ # because it expects to aggregate over multiple devices
+ devices=2,
+ log_every_n_steps=1,
+ )
+
+ graph_name = "1level"
+
+ graph_dir_path = Path(datastore.root_path) / "graph" / graph_name
+
+ if not graph_dir_path.exists():
+ create_graph_from_datastore(
+ datastore=datastore,
+ output_root_path=str(graph_dir_path),
+ n_max_levels=1,
+ )
+
+ data_module = WeatherDataModule(
+ datastore=datastore,
+ ar_steps_train=3,
+ ar_steps_eval=5,
+ standardize=True,
+ batch_size=2,
+ num_workers=1,
+ num_past_forcing_steps=1,
+ num_future_forcing_steps=1,
+ )
+
+ class ModelArgs:
+ output_std = False
+ loss = "mse"
+ restore_opt = False
+ n_example_pred = 1
+ # XXX: this should be superfluous when we have already defined the
+ # model object no?
+ graph = graph_name
+ hidden_dim = 4
+ hidden_layers = 1
+ processor_layers = 2
+ mesh_aggr = "sum"
+ lr = 1.0e-3
+ val_steps_to_log = [1, 3]
+ metrics_watch = []
+ num_past_forcing_steps = 1
+ num_future_forcing_steps = 1
+
+ model_args = ModelArgs()
+
+ config = nlconfig.NeuralLAMConfig(
+ datastore=nlconfig.DatastoreSelection(
+ kind=datastore.SHORT_NAME, config_path=datastore.root_path
+ )
+ )
+
+ model = GraphLAM( # noqa
+ args=model_args,
+ datastore=datastore,
+ config=config,
+ )
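+    # start an explicit wandb run so that any wandb logging triggered during
+    # training has an active run to attach to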
+ wandb.init()
+ trainer.fit(model=model, datamodule=data_module)