Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix evaluation example visualisation plots #91

Merged
merged 10 commits into from
Dec 4, 2024

Conversation

leifdenby
Copy link
Member

@leifdenby leifdenby commented Nov 26, 2024

Describe your changes

Fix bugs in recently introduced datastore functionality #66 (error in calculation in BaseDatastore.get_xy_extent() and overlooked in-place modification of config dict in MDPDatastore.coords_projection), and also fix issue in ARModel.plot_examples by using newly introduced (#66) WeatherDataset.create_dataarray_from_tensor() to create xr.DataArray from prediction tensor and calling plot methods directly on xr.DataArray rather than using bare numpy arrays with matplotlib.

Without the latter fix @khintz noticed that the example plots are incorrect (the field is transposed, and the interior rather than the boundary is alpha blended):

Screenshot 2024-11-25 at 14 18 12

Steps to reproduce (eval CLI run is not currently part of testing suite):

# train model
WANDB_DISABLED=1 pdm run python -m neural_lam.train_model --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml --hidden_dim 2 --epochs 1 --ar_steps_train 1 --ar_steps_eval 1
# run eval with trained model
WANDB_DISABLED=1 pdm run python -m neural_lam.train_model --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml --hidden_dim 2 --epochs 1 --ar_steps_train 1 --ar_steps_eval 1 --eval val --load saved_models/train-graph_lam-4x2-11_25_15-9215/min_val_loss.ckpt --val_steps_to_log 1

After the fixes in this PR the eval plot is as follows (not same sample as above, but exemplary):

example_0_1_t1

NB: This PR introduces the creation of a WeatherDataset instance on every tensor to xr.DataArray creation in ARModel. This isn't ideal, but since we will soon be refactoring the whole plotting functionality in neural-lam I thought this would be a step in that direction (since to me at least plotting from xr.DataArray objects rather than bare numpy arrays with matplotlib is a lot less error-prone).

No change of dependencies for these fixes.

Issue Link

No related link, issue was brought up on slack

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • author has added an entry to the changelog (and designated the change as added, changed or fixed)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@leifdenby leifdenby added this to the v0.3.0 (proposed) milestone Nov 26, 2024
neural_lam/vis.py Outdated Show resolved Hide resolved
neural_lam/vis.py Outdated Show resolved Hide resolved
@khintz
Copy link
Contributor

khintz commented Nov 28, 2024

When --ar_steps_eval > 1 it breaks.

ValueError: A 3-dimensional array was passed to imshow(), but there is no dimension that could be color.  At least one dimension must be of size 3 (RGB) or 4 (RGBA), and not given as x or y.

A solution could be to use .isel() with the time index in ar_model.py:

            # Iterate over prediction horizon time steps
            for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
                # Create one figure per variable at this time step
                var_figs = [
                    vis.plot_prediction(
                        datastore=self._datastore,
                        title=f"{var_name} ({var_unit}), "
                        f"t={t_i} ({self._datastore.step_length * t_i} h)",
                        vrange=var_vrange,
                        da_prediction=da_prediction.isel(
                            state_feature=var_i,
                            time=t_i - 1
                        ).squeeze(),
                        da_target=da_target.isel(
                            state_feature=var_i,
                            time=t_i - 1
                        ).squeeze(),
                    )
                    for var_i, (var_name, var_unit, var_vrange) in enumerate(
                        zip(
                            self._datastore.get_vars_names("state"),
                            self._datastore.get_vars_units("state"),
                            var_vranges,
                        )
                    )
                ]

@leifdenby
Copy link
Member Author

A solution could be to use .isel() with the time index in ar_model.py:

Thanks for supplying me with a fix! Addressed in leifdenby@52c4528

@leifdenby leifdenby added the bug Something isn't working label Nov 29, 2024
@khintz
Copy link
Contributor

khintz commented Dec 2, 2024

I am re-running two tests, which failed because there were no online runners for some reason.
Please also remember to update the changelog with an entry.

@leifdenby
Copy link
Member Author

I am re-running two tests, which failed because there were no online runners for some reason. Please also remember to update the changelog with an entry.

It looks like this might be an issue with dataclass-wizard. I can see the CPU-based test runs are ending up with version 0.30.1 (https://github.com/mllam/neural-lam/actions/runs/12086137733/job/33704822660?pr=91#step:6:114) whereas the GPU-based test runs are using 0.32.0 (https://github.com/mllam/neural-lam/actions/runs/12086137715/job/33769807411?pr=91#step:8:126). I will try and see if I can reproduce the issue locally with 0.32.0

@leifdenby
Copy link
Member Author

Yes, there is an issue with 0.32.0 of dataclass-wizard:

Traceback
Adding packages to default dependencies: dataclass-wizard==0.32.0
  0:00:07 🔒 Lock successful.
Changes are written to pyproject.toml.
Synchronizing working set with resolved packages: 0 to add, 1 to update, 0 to remove

✔ Update dataclass-wizard 0.28.0 -> 0.32.0 successful
✔ Update neural-lam 0.2.0 -> 0.2.0 successful

0:00:00 🎉 All complete! 1/1
INFO: PDM 2.19.1 is installed, while 2.21.0 is available.
Please run pdm self update to upgrade.
Run pdm config check_update false to disable the check.
✝  git-repos/mllam/neural-lam   fix/eval-vis-plots±  pdm run python -m pytest -s --sw -vv -k mdp
===================================================================================================================================================== test session starts ======================================================================================================================================================
platform darwin -- Python 3.12.4, pytest-8.3.3, pluggy-1.5.0 -- /Users/B280936/git-repos/mllam/neural-lam/.venv/bin/python
cachedir: .pytest_cache
rootdir: /Users/B280936/git-repos/mllam/neural-lam
configfile: pyproject.toml
collected 91 items / 64 deselected / 27 selected
stepwise: no previously failed tests, not skipping.

tests/test_config.py::test_config_load_from_yaml[\ndatastore:\n kind: mdp\n config_path: ""\n-config_expected0] PASSED
tests/test_config.py::test_config_load_from_yaml[\ndatastore:\n kind: mdp\n config_path: ""\ntraining:\n state_feature_weighting:\n config_class: ManualStateFeatureWeighting\n weights:\n u100m: 1.0\n v100m: 1.0\n-config_expected1] PASSED
tests/test_datasets.py::test_dataset_item_shapes[mdp] FAILED

=========================================================================================================================================================== FAILURES ===========================================================================================================================================================
________________________________________________________________________________________________________________________________________________ test_dataset_item_shapes[mdp] _________________________________________________________________________________________________________________________________________________

cls = <class 'mllam_data_prep.config.Config'>, d = {'dataset_version': 'v0.1.0', 'extra': {'projection': {'class_name': 'LambertConformal', 'kwargs': {'central_latitude'...forcing_feature'], 'state': ['time', 'grid_index', 'state_feature'], 'static': ['grid_index', 'static_feature']}}, ...}

def fromdict(cls: Type[T], d: JSONObject) -> T:
    """
    Converts a Python dictionary object to a dataclass instance.

    Iterates over each dataclass field recursively; lists, dicts, and nested
    dataclasses will likewise be initialized as expected.

    When directly invoking this function, an optional Meta configuration for
    the dataclass can be specified via ``LoadMeta``; by default, this will
    apply recursively to any nested dataclasses. Here's a sample usage of this
    below::

        >>> LoadMeta(key_transform='CAMEL').bind_to(MyClass)
        >>> fromdict(MyClass, {"myStr": "value"})

    """
    try:
      load = CLASS_TO_LOAD_FUNC[cls]

E KeyError: <class 'mllam_data_prep.config.Config'>

.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:592: KeyError

During handling of the above exception, another exception occurred:

cls = <class 'abc.ConfigLoadMixin'>, ann_type = <class 'mllam_data_prep.config.Output'>, base_cls = <class 'mllam_data_prep.config.Config'>, extras = {'config': <class 'abc.Meta'>}

@classmethod
def get_parser_for_annotation(cls, ann_type: Type[T],
                              base_cls: Type = None,
                              extras: Extras = None) -> 'AbstractParser | Callable[[dict[str, Any]], T]':
    """Returns the Parser (dispatcher) for a given annotation type."""
    hooks = cls.__LOAD_HOOKS__
    ann_type = eval_forward_ref_if_needed(ann_type, base_cls)
    load_hook = hooks.get(ann_type)
    base_type = ann_type

    # TODO: I'll need to refactor the code below to remove the nested `if`
    #   statements, when time allows. Right now the branching logic is
    #   unseemly and there's really no need for that, as any such
    #   performance gains (if they do exist) are minimal at best.

    if 'pattern' in extras and is_subclass_safe(
            ann_type, (date, time, datetime)):
        # Check for a field that was initially annotated like:
        #   Annotated[List[time], Pattern('%H:%M:%S')]
        return PatternedDTParser(base_cls, extras, base_type)

    if load_hook is None:
        # Need to check this first, because the `Literal` type in Python
        # 3.6 behaves a bit differently (doesn't have an `__origin__`
        # attribute for example)
        if is_literal(ann_type):
            return LiteralParser(base_cls, extras, ann_type)

        if is_annotated(ann_type):
            # Given `Annotated[T, MaxValue(10), ...]`, we only need `T`
            ann_type = get_args(ann_type)[0]
            return cls.get_parser_for_annotation(
                ann_type, base_cls, extras)

        # This property will be available for most generic types in the
        # `typing` library.
        try:
          base_type = get_origin(ann_type, raise_=True)

.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:300:


.venv/lib/python3.12/site-packages/dataclass_wizard/utils/typing_compat.py:170: in get_origin
return get_origin(cls, raise=raise_)


cls = <class 'mllam_data_prep.config.Output'>, raise_ = True

def _get_origin(cls, raise_=False):
    if isinstance(cls, types.UnionType):
        return typing.Union

    try:
      return cls.__origin__

E AttributeError: type object 'Output' has no attribute 'origin'

.venv/lib/python3.12/site-packages/dataclass_wizard/utils/typing_compat.py:82: AttributeError

During handling of the above exception, another exception occurred:

datastore_name = 'mdp'

@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_dataset_item_shapes(datastore_name):
    """Check that the `datastore.get_dataarray` method is implemented.

    Validate the shapes of the tensors match between the different
    components of the training sample.

    init_states: (2, N_grid, d_features)
    target_states: (ar_steps, N_grid, d_features)
    forcing: (ar_steps, N_grid, d_windowed_forcing) # batch_times: (ar_steps,)

    """
  datastore = init_datastore_example(datastore_name)

tests/test_datasets.py:33:


tests/conftest.py:101: in init_datastore_example
datastore = init_datastore(
neural_lam/datastore/init.py:24: in init_datastore
datastore = DatastoreClass(config_path=config_path)
neural_lam/datastore/mdp.py:54: in init
self._config = mdp.Config.from_yaml_file(self._config_path)
.venv/lib/python3.12/site-packages/dataclass_wizard/wizard_mixins.py:266: in from_yaml_file
return cls.from_yaml(in_file, decoder=decoder,
.venv/lib/python3.12/site-packages/dataclass_wizard/wizard_mixins.py:255: in from_yaml
return fromdict(cls, o) if isinstance(o, dict) else fromlist(cls, o)
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:594: in fromdict
load = load_func_for_dataclass(cls)
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:650: in load_func_for_dataclass
field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config)
.venv/lib/python3.12/site-packages/dataclass_wizard/class_helper.py:131: in dataclass_field_to_load_parser
return _setup_load_config_for_cls(cls_loader, cls, config, save)
.venv/lib/python3.12/site-packages/dataclass_wizard/class_helper.py:211: in _setup_load_config_for_cls
name_to_parser[f.name] = getattr(p := cls_loader.get_parser_for_annotation(
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:326: in get_parser_for_annotation
return cls.load_func_for_dataclass(
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:260: in load_func_for_dataclass
return load_func_for_dataclass(
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:650: in load_func_for_dataclass
field_to_parser = dataclass_field_to_load_parser(cls_loader, cls, config)
.venv/lib/python3.12/site-packages/dataclass_wizard/class_helper.py:131: in dataclass_field_to_load_parser
return _setup_load_config_for_cls(cls_loader, cls, config, save)
.venv/lib/python3.12/site-packages/dataclass_wizard/class_helper.py:211: in _setup_load_config_for_cls
name_to_parser[f.name] = getattr(p := cls_loader.get_parser_for_annotation(
.venv/lib/python3.12/site-packages/dataclass_wizard/loaders.py:416: in get_parser_for_annotation
return MappingParser(
:5: in init
???


self = MappingParser(base_type=<class 'dict'>, hook=<function LoadMixin.load_to_dict at 0x146b00cc0>), cls = <class 'mllam_data_prep.config.Output'>, extras = {'config': <class 'abc.Meta'>}, get_parser = <bound method LoadMixin.get_parser_for_annotation of <class 'abc.OutputLoadMixin'>>

def __post_init__(self, cls: Type,
                  extras: Extras,
                  get_parser: GetParserType):
    try:
        key_type, val_type = get_args(self.base_type)
    except ValueError:
        key_type = val_type = Any

    # Base type of the object which is instantiable
    #   ex. `Dict[str, Any]` -> `dict`
    self.base_type: Type[M] = get_origin(self.base_type)

    val_parser = get_parser(val_type, cls, extras)

    self.key_parser = getattr(p := get_parser(key_type, cls, extras), '__call__', p)
    self.val_parser = getattr(val_parser, '__call__', val_parser)
  self.val_base_type = val_parser.base_type

E AttributeError: 'function' object has no attribute 'base_type'

.venv/lib/python3.12/site-packages/dataclass_wizard/parsers.py:559: AttributeError
=================================================================================================================================================== short test summary info ====================================================================================================================================================
FAILED tests/test_datasets.py::test_dataset_item_shapes[mdp] - AttributeError: 'function' object has no attribute 'base_type'

I suggest we skip 0.32.0 for now

@joeloskarsson
Copy link
Collaborator

This is all good and ready for merge now right, @khintz ?

@khintz
Copy link
Contributor

khintz commented Dec 3, 2024

Nearly, I need @leifdenby to push an entry to the changelog :) Other than that, all looks good now.

@leifdenby
Copy link
Member Author

Nearly, I need @leifdenby to push an entry to the changelog :) Other than that, all looks good now.

Done! Is this detailed enough for the changelog? I don't think adding too much detail in the changelog is a good idea, but I could list the bugs I fixed.

@khintz
Copy link
Contributor

khintz commented Dec 3, 2024

Great! I'm happy.

@leifdenby leifdenby merged commit 71cfdf9 into mllam:main Dec 4, 2024
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants