From 3378faa5e9a7b14798938506c6825df54c902c7c Mon Sep 17 00:00:00 2001 From: Joel Oskarsson Date: Mon, 11 Nov 2024 15:04:18 +0100 Subject: [PATCH] Fix `attribute` keyword typo in save function (#35) * Fix attribute keyword typo in save function and add test for it * Add changelog entry --- CHANGELOG.md | 6 ++++++ src/weather_model_graphs/save.py | 2 +- tests/test_save.py | 11 +++++++++-- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2bcd96..aefa9c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [\#19](https://github.com/mllam/weather-model-graphs/pull/19) @joeloskarsson +### Fixed + +- Fix `attribute` keyword bug in save function + [\#35](https://github.com/mllam/weather-model-graphs/pull/35) + @joeloskarsson + ### Maintenance - Ensure that cell execution doesn't time out when building jupyterbook based diff --git a/src/weather_model_graphs/save.py b/src/weather_model_graphs/save.py index 74e1a53..5253a6a 100644 --- a/src/weather_model_graphs/save.py +++ b/src/weather_model_graphs/save.py @@ -108,7 +108,7 @@ def _concat_pyg_features( try: sub_graphs = list( split_graph_by_edge_attribute( - graph=graph, attribute=list_from_attribute + graph=graph, attr=list_from_attribute ).values() ) except MissingEdgeAttributeError: diff --git a/tests/test_save.py b/tests/test_save.py index e6e311f..0edb940 100644 --- a/tests/test_save.py +++ b/tests/test_save.py @@ -1,6 +1,7 @@ import tempfile import numpy as np +import pytest from loguru import logger import weather_model_graphs as wmg @@ -15,7 +16,8 @@ def _create_fake_xy(N=10): return xy -def test_save_to_pyg(): +@pytest.mark.parametrize("list_from_attribute", [None, "level"]) +def test_save_to_pyg(list_from_attribute): if not HAS_PYG: logger.warning( "Skipping test_save_to_pyg because weather-model-graphs[pytorch] is not installed." @@ -40,4 +42,9 @@ def test_save_to_pyg(): with tempfile.TemporaryDirectory() as tmpdir: for name, graph in graph_components.items(): - wmg.save.to_pyg(graph=graph, output_directory=tmpdir, name=name) + wmg.save.to_pyg( + graph=graph, + output_directory=tmpdir, + name=name, + list_from_attribute=list_from_attribute, + )