Skip to content

Commit

Permalink
Add test and code cleanup
Browse files Browse the repository at this point in the history
- Fix `_get_bounds_ensure_dtype` to determine `bounds` with axis that has `standard_name` attr (in addition to `axis` attr check)
- Remove unused `dst_lat_bnds` and `dst_lon_bnds` args for `_build_dataset()`
- Add unit test to cover `ValueError` in `regrid2.py` `_build_dataset()`
  • Loading branch information
tomvothecoder committed Dec 4, 2024
1 parent e067a6c commit e2d259e
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
12 changes: 12 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,18 @@ def test_unknown_variable(self):
with pytest.raises(KeyError):
regridder.horizontal("unknown", self.coarse_2d_ds)

def test_raises_error_if_axis_name_for_dim_cannot_be_determined(self):
ds = self.coarse_2d_ds.copy()
ds["lat"].attrs["standard_name"] = "latitude"
ds["lat"].attrs.pop("axis")

regridder = regrid2.Regrid2Regridder(ds, self.fine_2d_ds)

with pytest.raises(
ValueError, match="Could not determine axis name for dimension"
):
regridder.horizontal("ts", ds)

@pytest.mark.filterwarnings("ignore:.*invalid value.*true_divide.*:RuntimeWarning")
def test_regrid_input_mask(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)
Expand Down
34 changes: 20 additions & 14 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import xarray as xr

from xcdat.axis import get_dim_keys
from xcdat.axis import CF_ATTR_MAP, get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds


Expand Down Expand Up @@ -105,8 +105,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
ds,
data_var,
output_data,
dst_lat_bnds,
dst_lon_bnds,
self._input_grid,
self._output_grid,
)
Expand Down Expand Up @@ -228,8 +226,6 @@ def _build_dataset(
ds: xr.Dataset,
data_var: str,
output_data: np.ndarray,
dst_lat_bnds,
dst_lon_bnds,
input_grid: xr.Dataset,
output_grid: xr.Dataset,
) -> xr.Dataset:
Expand All @@ -242,11 +238,13 @@ def _build_dataset(
dim = str(dim)

try:
axis_name = [x for x, y in ds.cf.axes.items() if dim in y][0]
except Exception:
axis_name = [
cf_axis for cf_axis, dims in ds.cf.axes.items() if dim in dims
][0]
except IndexError as e:
raise ValueError(
f"Could not determine axis name for dimension {dim}"
) from None
) from e

if axis_name in ["X", "Y"]:
output_coords[dim] = output_grid.cf[axis_name]
Expand Down Expand Up @@ -566,12 +564,20 @@ def _get_dimension(input_data_var, cf_axis_name):


def _get_bounds_ensure_dtype(ds, axis):
try:
name = ds.cf.bounds[axis][0]
except (KeyError, IndexError) as e:
raise RuntimeError(f"Could not determine {axis!r} bounds") from e
else:
bounds = ds[name]
cf_keys = CF_ATTR_MAP[axis].values()

bounds = None

for key in cf_keys:
try:
name = ds.cf.bounds[key][0]
except (KeyError, IndexError):
pass
else:
bounds = ds[name]

if bounds is None:
raise RuntimeError(f"Could not determine {axis!r} bounds")

if bounds.dtype != np.float32:
bounds = bounds.astype(np.float32)
Expand Down

0 comments on commit e2d259e

Please sign in to comment.