Skip to content

Commit

Permalink
Fix attribute keyword typo in save function (#35)
Browse files Browse the repository at this point in the history
* Fix attribute keyword typo in save function and add test for it

* Add changelog entry
  • Loading branch information
joeloskarsson authored Nov 11, 2024
1 parent c3ba212 commit 3378faa
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#19](https://github.com/mllam/weather-model-graphs/pull/19)
@joeloskarsson

### Fixed

- Fix `attribute` keyword bug in save function
[\#35](https://github.com/mllam/weather-model-graphs/pull/35)
@joeloskarsson

### Maintenance

- Ensure that cell execution doesn't time out when building jupyterbook based
Expand Down
2 changes: 1 addition & 1 deletion src/weather_model_graphs/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def _concat_pyg_features(
try:
sub_graphs = list(
split_graph_by_edge_attribute(
graph=graph, attribute=list_from_attribute
graph=graph, attr=list_from_attribute
).values()
)
except MissingEdgeAttributeError:
Expand Down
11 changes: 9 additions & 2 deletions tests/test_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tempfile

import numpy as np
import pytest
from loguru import logger

import weather_model_graphs as wmg
Expand All @@ -15,7 +16,8 @@ def _create_fake_xy(N=10):
return xy


def test_save_to_pyg():
@pytest.mark.parametrize("list_from_attribute", [None, "level"])
def test_save_to_pyg(list_from_attribute):
if not HAS_PYG:
logger.warning(
"Skipping test_save_to_pyg because weather-model-graphs[pytorch] is not installed."
Expand All @@ -40,4 +42,9 @@ def test_save_to_pyg():

with tempfile.TemporaryDirectory() as tmpdir:
for name, graph in graph_components.items():
wmg.save.to_pyg(graph=graph, output_directory=tmpdir, name=name)
wmg.save.to_pyg(
graph=graph,
output_directory=tmpdir,
name=name,
list_from_attribute=list_from_attribute,
)

0 comments on commit 3378faa

Please sign in to comment.