Skip to content

Commit

Permalink
initial attempt at #531 (for spatial averaging)
Browse files Browse the repository at this point in the history
  • Loading branch information
pochedls authored and tomvothecoder committed Nov 21, 2024
1 parent 27396e5 commit ea08915
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
34 changes: 34 additions & 0 deletions tests/test_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)

def test_raises_error_if_required_weight_not_between_zero_and_one(
self,
):
# ensure error if required_weight less than zero
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=-0.01)

# ensure error if required_weight greater than 1
with pytest.raises(ValueError):
self.ds.spatial.average("ts", axis=["X", "Y"], required_weight=1.01)

def test_spatial_average_for_lat_region_and_keep_weights(self):
ds = self.ds.copy()

Expand Down Expand Up @@ -254,6 +265,29 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):

xr.testing.assert_allclose(result, expected)

def test_spatial_average_with_required_weight(self):
ds = self.ds.copy()

# insert a nan
ds["ts"][0, :, 2] = np.nan

result = ds.spatial.average(
"ts",
axis=["X", "Y"],
lat_bounds=(-5.0, 5),
lon_bounds=(-170, -120.1),
required_weight=1.0,
)

expected = self.ds.copy()
expected["ts"] = xr.DataArray(
data=np.array([np.nan, 1.0, 1.0]),
coords={"time": expected.time},
dims="time",
)

xr.testing.assert_allclose(result, expected)

def test_spatial_average_for_lat_and_lon_region_with_custom_weights(self):
ds = self.ds.copy()

Expand Down
51 changes: 49 additions & 2 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def average(
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
required_weight: Optional[float] = 0.0,
) -> xr.Dataset:
"""
Calculates the spatial average for a rectilinear grid over an optionally
Expand Down Expand Up @@ -125,6 +126,9 @@ def average(
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.
required_weight : optional, float
Fraction of data coverage (i..e, weight) needed to return a
spatial average value. Value must range from 0 to 1.
Returns
-------
Expand Down Expand Up @@ -196,7 +200,7 @@ def average(
self._weights = weights

self._validate_weights(dv, axis)
ds[dv.name] = self._averager(dv, axis)
ds[dv.name] = self._averager(dv, axis, required_weight=required_weight)

if keep_weights:
ds[self._weights.name] = self._weights
Expand Down Expand Up @@ -702,7 +706,10 @@ def _validate_weights(
)

def _averager(
self, data_var: xr.DataArray, axis: List[SpatialAxis] | Tuple[SpatialAxis, ...]
self,
data_var: xr.DataArray,
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
required_weight: Optional[float] = 0.0,
):
"""Perform a weighted average of a data variable.
Expand All @@ -721,6 +728,9 @@ def _averager(
Data variable inside a Dataset.
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
List of axis dimensions to average over.
required_weight : optional, float
Fraction of data coverage (i..e, weight) needed to return a
spatial average value. Value must range from 0 to 1.
Returns
-------
Expand All @@ -734,11 +744,48 @@ def _averager(
"""
weights = self._weights.fillna(0)

# ensure required weight is between 0 and 1
if required_weight is None:
required_weight = 0.0

if required_weight < 0.0:
raise ValueError(
"required_weight argment is less than zero. "
"required_weight must be between 0 and 1."
)

if required_weight > 1.0:
raise ValueError(
"required_weight argment is greater than zero. "
"required_weight must be between 0 and 1."
)

# need weights to match data_var dimensionality
if required_weight > 0.0:
weights, data_var = xr.broadcast(weights, data_var)

# get averaging dimensions
dim = []
for key in axis:
dim.append(get_dim_keys(data_var, key))

# compute weighed mean
with xr.set_options(keep_attrs=True):
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)

# if weight thresholds applied, calculate fraction of data availability
# replace values that do not meet minimum weight with nan
if required_weight > 0.0:
# sum all weights (assuming no missing values exist)
print(dim)
weight_sum_all = weights.sum(dim=dim) # type: ignore
# zero out cells with missing values in data_var
weights = xr.where(~np.isnan(data_var), weights, 0)
# sum all weights (including zero for missing values)
weight_sum_masked = weights.sum(dim=dim) # type: ignore
# get fraction of weight available
frac = weight_sum_masked / weight_sum_all
# nan out values that don't meet specified weight threshold
weighted_mean = xr.where(frac >= required_weight, weighted_mean, np.nan)

return weighted_mean

0 comments on commit ea08915

Please sign in to comment.