From 636366b2ae3bd19c0cbe2846b8dc8e93dcf2140c Mon Sep 17 00:00:00 2001
From: Tom Vo
Date: Thu, 5 Sep 2024 10:46:47 -0700
Subject: [PATCH] Updates from code review

- Rename arg `minimum_weight` to `min_weight`
- Add `_get_masked_weights()` and `_validate_min_weight()` to `utils.py`
- Update `SpatialAccessor` to use `_get_masked_weights()` and `_validate_min_weight()`
- Replace type annotation `Optional` with `|`
- Extract `_mask_var_with_weight_threshold()` from `_averager()` for readability
---
 tests/test_spatial.py |  18 ++---
 tests/test_utils.py   |  22 +++++-
 xcdat/spatial.py      | 160 +++++++++++++++++++++++++-----------------
 xcdat/utils.py        |  60 ++++++++++++++++
 4 files changed, 185 insertions(+), 75 deletions(-)

diff --git a/tests/test_spatial.py b/tests/test_spatial.py
index ea2ad27f..244bcf31 100644
--- a/tests/test_spatial.py
+++ b/tests/test_spatial.py
@@ -140,16 +140,16 @@ def test_raises_error_if_weights_lat_and_lon_dims_dont_align_with_data_var_dims(
         with pytest.raises(ValueError):
             self.ds.spatial.average("ts", axis=["X", "Y"], weights=weights)
 
-    def test_raises_error_if_minimum_weight_not_between_zero_and_one(
+    def test_raises_error_if_min_weight_not_between_zero_and_one(
         self,
     ):
-        # ensure error if minimum_weight less than zero
+        # ensure error if min_weight less than zero
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=-0.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=-0.01)
 
-        # ensure error if minimum_weight greater than 1
+        # ensure error if min_weight greater than 1
         with pytest.raises(ValueError):
-            self.ds.spatial.average("ts", axis=["X", "Y"], minimum_weight=1.01)
+            self.ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.01)
 
     def test_spatial_average_for_lat_region_and_keep_weights(self):
         ds = self.ds.copy()
@@ -265,7 +265,7 @@ def test_spatial_average_for_lat_and_lon_region_and_keep_weights(self):
 
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_minimum_weight(self):
+    def test_spatial_average_with_min_weight(self):
         ds = self.ds.copy()
 
         # insert a nan
@@ -276,7 +276,7 @@
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            minimum_weight=1.0,
+            min_weight=1.0,
         )
 
         expected = self.ds.copy()
@@ -288,7 +288,7 @@
         xr.testing.assert_allclose(result, expected)
 
-    def test_spatial_average_with_minimum_weight_as_None(self):
+    def test_spatial_average_with_min_weight_as_None(self):
         ds = self.ds.copy()
 
         result = ds.spatial.average(
             "ts",
@@ -296,7 +296,7 @@
             axis=["X", "Y"],
             lat_bounds=(-5.0, 5),
             lon_bounds=(-170, -120.1),
-            minimum_weight=None,
+            min_weight=None,
         )
 
         expected = self.ds.copy()
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 1d4dcbe8..30d3cbfb 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,7 +1,7 @@
 import pytest
 import xarray as xr
 
-from xcdat.utils import compare_datasets, str_to_bool
+from xcdat.utils import _validate_min_weight, compare_datasets, str_to_bool
 
 
 class TestCompareDatasets:
@@ -103,3 +103,23 @@ def test_raises_error_if_str_is_not_a_python_bool(self):
 
         with pytest.raises(ValueError):
             str_to_bool("1")
+
+
+class TestValidateMinWeight:
+    def test_pass_None_returns_0(self):
+        result = _validate_min_weight(None)
+
+        assert result == 0
+
+    def test_returns_error_if_less_than_0(self):
+        with pytest.raises(ValueError):
+            _validate_min_weight(-1)
+
+    def test_returns_error_if_greater_than_1(self):
+        with pytest.raises(ValueError):
+            _validate_min_weight(1.1)
+
+    def test_returns_valid_min_weight(self):
+        result = _validate_min_weight(1)
+
+        assert result == 1
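The behavior exercised by the new `min_weight` tests above can also be sketched outside the test suite. The snippet below is illustrative only and is not part of the patch: the toy grid, values, and bounds setup are made up, and it assumes a build of xcdat that includes this change (importing xcdat registers the `.spatial` and `.bounds` accessors). With `min_weight=1.0`, an average that includes a cell with missing data falls below full coverage and is masked with `np.nan`; `min_weight=None` applies no threshold.

    import numpy as np
    import xarray as xr
    import xcdat  # noqa: F401 (registers the .spatial and .bounds accessors)

    # Toy 2x2 rectilinear grid with one missing value in "ts".
    lat = xr.DataArray(
        [-45.0, 45.0], dims="lat", attrs={"axis": "Y", "units": "degrees_north"}
    )
    lon = xr.DataArray(
        [90.0, 270.0], dims="lon", attrs={"axis": "X", "units": "degrees_east"}
    )
    ts = xr.DataArray([[1.0, np.nan], [2.0, 3.0]], dims=("lat", "lon"))

    ds = xr.Dataset({"ts": ts}, coords={"lat": lat, "lon": lon})
    ds = ds.bounds.add_missing_bounds(axes=["X", "Y"])

    # The missing cell drops coverage below 1, so a 100% threshold masks the result.
    ds.spatial.average("ts", axis=["X", "Y"], min_weight=1.0)["ts"]   # -> nan

    # min_weight=None (or 0.0) applies no coverage threshold.
    ds.spatial.average("ts", axis=["X", "Y"], min_weight=None)["ts"]  # -> weighted mean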
diff --git a/xcdat/spatial.py b/xcdat/spatial.py
index 7dc4075f..20d8cc0c 100644
--- a/xcdat/spatial.py
+++ b/xcdat/spatial.py
@@ -9,7 +12,6 @@
     Hashable,
     List,
     Literal,
-    Optional,
     Tuple,
     TypedDict,
     Union,
@@ -27,7 +29,11 @@
     get_dim_keys,
 )
 from xcdat.dataset import _get_data_var
-from xcdat.utils import _if_multidim_dask_array_then_load
+from xcdat.utils import (
+    _get_masked_weights,
+    _if_multidim_dask_array_then_load,
+    _validate_min_weight,
+)
 
 #: Type alias for a dictionary of axis keys mapped to their bounds.
 AxisWeights = Dict[Hashable, xr.DataArray]
@@ -74,9 +80,9 @@ def average(
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"),
         weights: Union[Literal["generate"], xr.DataArray] = "generate",
         keep_weights: bool = False,
-        lat_bounds: Optional[RegionAxisBounds] = None,
-        lon_bounds: Optional[RegionAxisBounds] = None,
-        minimum_weight: Optional[float] = None,
+        lat_bounds: RegionAxisBounds | None = None,
+        lon_bounds: RegionAxisBounds | None = None,
+        min_weight: float | None = None,
     ) -> xr.Dataset:
         """
         Calculates the spatial average for a rectilinear grid over an optionally
@@ -115,21 +121,21 @@
         keep_weights : bool, optional
             If calculating averages using weights, keep the weights in the
             final dataset output, by default False.
-        lat_bounds : Optional[RegionAxisBounds], optional
+        lat_bounds : RegionAxisBounds | None, optional
            A tuple of floats/ints for the regional latitude lower and upper
            boundaries. This arg is used when calculating axis weights, but is
            ignored if ``weights`` are supplied. The lower bound cannot be
            larger than the upper bound, by default None.
-        lon_bounds : Optional[RegionAxisBounds], optional
+        lon_bounds : RegionAxisBounds | None, optional
            A tuple of floats/ints for the regional longitude lower and upper
            boundaries. This arg is used when calculating axis weights, but is
            ignored if ``weights`` are supplied. The lower bound can be larger
            than the upper bound (e.g., across the prime meridian, dateline),
            by default None.
-        minimum_weight : optional, float
-            Fraction of data coverage (i..e, weight) needed to return a
+        min_weight : float, optional
+            Fraction of data coverage (i.e., weight) needed to return a
            spatial average value. Value must range from 0 to 1, by default None
-            (equivalent to minimum_weight=0.0).
+            (equivalent to ``min_weight=0.0``).
 
        Returns
        -------
@@ -189,7 +195,9 @@
         """
         ds = self._dataset.copy()
         dv = _get_data_var(ds, data_var)
+
         self._validate_axis_arg(axis)
+        min_weight = _validate_min_weight(min_weight)
 
         if isinstance(weights, str) and weights == "generate":
             if lat_bounds is not None:
@@ -201,7 +209,7 @@
             self._weights = weights
 
         self._validate_weights(dv, axis)
-        ds[dv.name] = self._averager(dv, axis, minimum_weight=minimum_weight)
+        ds[dv.name] = self._averager(dv, axis, min_weight=min_weight)
 
         if keep_weights:
             ds[self._weights.name] = self._weights
@@ -211,9 +219,9 @@
     def get_weights(
         self,
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
-        lat_bounds: Optional[RegionAxisBounds] = None,
-        lon_bounds: Optional[RegionAxisBounds] = None,
-        data_var: Optional[str] = None,
+        lat_bounds: RegionAxisBounds | None = None,
+        lon_bounds: RegionAxisBounds | None = None,
+        data_var: str | None = None,
     ) -> xr.DataArray:
         """
         Get area weights for specified axis keys and an optional target domain.
@@ -232,13 +240,13 @@
        ----------
        axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
            List of axis dimensions to average over.
-        lat_bounds : Optional[RegionAxisBounds]
+        lat_bounds : RegionAxisBounds | None
            Tuple of latitude boundaries for regional selection, by default None.
-        lon_bounds : Optional[RegionAxisBounds]
+        lon_bounds : RegionAxisBounds | None
            Tuple of longitude boundaries for regional selection, by default None.
-        data_var: Optional[str]
+        data_var: str | None
            The key of the data variable, by default None. Pass this argument
            when the dataset has more than one bounds per axis (e.g., "lon" and
            "zlon_bnds" for the "X" axis), or you want weights for a
@@ -259,7 +267,7 @@
            and pressure).
         """
         Bounds = TypedDict(
-            "Bounds", {"weights_method": Callable, "region": Optional[np.ndarray]}
+            "Bounds", {"weights_method": Callable, "region": np.ndarray | None}
         )
 
         axis_bounds: Dict[SpatialAxis, Bounds] = {
@@ -382,7 +390,7 @@ def _validate_region_bounds(self, axis: SpatialAxis, bounds: RegionAxisBounds):
             )
 
     def _get_longitude_weights(
-        self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
+        self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
     ) -> xr.DataArray:
         """Gets weights for the longitude axis.
 
@@ -409,7 +417,7 @@
        ----------
        domain_bounds : xr.DataArray
            The array of bounds for the longitude domain.
-        region_bounds : Optional[np.ndarray]
+        region_bounds : np.ndarray | None
            The array of bounds for longitude regional selection.
 
        Returns
@@ -423,7 +431,7 @@
            If the there are multiple instances in which the
            domain_bounds[:, 0] > domain_bounds[:, 1]
         """
-        p_meridian_index: Optional[np.ndarray] = None
+        p_meridian_index: np.ndarray | None = None
         d_bounds = domain_bounds.copy()
 
         pm_cells = np.where(domain_bounds[:, 1] - domain_bounds[:, 0] < 0)[0]
@@ -455,7 +463,7 @@
 
         return weights
 
     def _get_latitude_weights(
-        self, domain_bounds: xr.DataArray, region_bounds: Optional[np.ndarray]
+        self, domain_bounds: xr.DataArray, region_bounds: np.ndarray | None
     ) -> xr.DataArray:
         """Gets weights for the latitude axis.
 
@@ -467,7 +475,7 @@
        ----------
        domain_bounds : xr.DataArray
            The array of bounds for the latitude domain.
-        region_bounds : Optional[np.ndarray]
+        region_bounds : np.ndarray | None
            The array of bounds for latitude regional selection.
 
        Returns
        -------
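A note on the `Optional[...]` to `... | None` replacements in these hunks: `spatial.py` (and, with this patch, `utils.py`) uses `from __future__ import annotations`, so the PEP 604 union syntax inside annotations is never evaluated at runtime and keeps working on Python versions older than 3.10. A minimal sketch of this (the function name is made up for illustration):

    from __future__ import annotations

    # With postponed evaluation of annotations (PEP 563), the union below is
    # stored as a string rather than evaluated, so `float | None` is valid in
    # annotations even on Python < 3.10.
    def _example(min_weight: float | None = None) -> float:
        return 0.0 if min_weight is None else min_weight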
@@ -710,7 +718,7 @@
     def _averager(
         self,
         data_var: xr.DataArray,
         axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
-        minimum_weight: Optional[float] = None,
+        min_weight: float,
     ):
         """Perform a weighted average of a data variable.
 
@@ -729,10 +737,9 @@
            Data variable inside a Dataset.
        axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
            List of axis dimensions to average over.
-        minimum_weight : optional, float
-            Fraction of data coverage (i..e, weight) needed to return a
-            spatial average value. Value must range from 0 to 1, by default None
-            (equivalent to minimum_weight=0.0).
+        min_weight : float
+            Fraction of data coverage (i.e., weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
 
        Returns
        -------
@@ -746,45 +753,68 @@
         """
         weights = self._weights.fillna(0)
 
-        # ensure required weight is between 0 and 1
-        if minimum_weight is None:
-            minimum_weight = 0.0
-        elif minimum_weight < 0.0:
-            raise ValueError(
-                "minimum_weight argument is less than 0. "
-                "minimum_weight must be between 0 and 1."
-            )
-        elif minimum_weight > 1.0:
-            raise ValueError(
-                "minimum_weight argument is greater than 1. "
-                "minimum_weight must be between 0 and 1."
-            )
-
-        # need weights to match data_var dimensionality
-        if minimum_weight > 0.0:
+        # TODO: This conditional might not be needed because Xarray will
+        # automatically broadcast the weights to the data variable for
+        # operations such as .mean() and .where().
+        if min_weight > 0.0:
             weights, data_var = xr.broadcast(weights, data_var)
 
-        # get averaging dimensions
-        dim = []
+        dim: List[str] = []
         for key in axis:
-            dim.append(get_dim_keys(data_var, key))
+            dim.append(get_dim_keys(data_var, key))  # type: ignore
 
-        # compute weighed mean
         with xr.set_options(keep_attrs=True):
-            weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)
-
-        # if weight thresholds applied, calculate fraction of data availability
-        # replace values that do not meet minimum weight with nan
-        if minimum_weight > 0.0:
-            # sum all weights (assuming no missing values exist)
-            weight_sum_all = weights.sum(dim=dim)  # type: ignore
-            # zero out cells with missing values in data_var
-            weights = xr.where(~np.isnan(data_var), weights, 0)
-            # sum all weights (including zero for missing values)
-            weight_sum_masked = weights.sum(dim=dim)  # type: ignore
-            # get fraction of weight available
-            frac = weight_sum_masked / weight_sum_all
-            # nan out values that don't meet specified weight threshold
-            weighted_mean = xr.where(frac >= minimum_weight, weighted_mean, np.nan)
-
-        return weighted_mean
+            dv_mean = data_var.cf.weighted(weights).mean(dim=dim)
+
+        if min_weight > 0.0:
+            dv_mean = self._mask_var_with_weight_threshold(
+                dv_mean, dim, weights, min_weight
+            )
+
+        return dv_mean
+
+    def _mask_var_with_weight_threshold(
+        self, dv: xr.DataArray, dim: List[str], weights: xr.DataArray, min_weight: float
+    ) -> xr.DataArray:
+        """Mask values that do not meet the minimum weight threshold with np.nan.
+
+        This function is useful for cases where the weighting of data might be
+        skewed based on the availability of data. For example, if a portion of
+        cells in a region has significantly more missing data than other
+        regions, it can result in inaccurate spatial averages. Masking values
+        that do not meet the minimum weight threshold ensures more accurate
+        calculations.
+
+        Parameters
+        ----------
+        dv : xr.DataArray
+            The weighted variable.
+        dim : List[str]
+            List of axis dimensions to average over.
+        weights : xr.DataArray
+            A DataArray containing the regional weights used for weighted
+            averaging. ``weights`` must include the same axis dimensions and
+            dimensional sizes as the data variable.
+        min_weight : float
+            Fraction of data coverage (i.e., weight) needed to return a
+            spatial average value. Value must range from 0 to 1.
+
+        Returns
+        -------
+        xr.DataArray
+            The variable with the minimum weight threshold applied.
+        """
+        # Sum of all weights, which assumes no data is missing.
+        weight_sum_all = weights.sum(dim=dim)
+
+        masked_weights = _get_masked_weights(dv, weights)
+        weight_sum_masked = masked_weights.sum(dim=dim)
+
+        # Get fraction of the available weight.
+        frac = weight_sum_masked / weight_sum_all
+
+        # Nan out values that don't meet specified weight threshold.
+        dv_new = xr.where(frac >= min_weight, dv, np.nan, keep_attrs=True)
+        dv_new.name = dv.name
+
+        return dv_new
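The threshold logic that `_mask_var_with_weight_threshold()` now encapsulates reduces to comparing the fraction of weight backed by actual data against `min_weight`. The standalone sketch below uses made-up numbers and plain xarray calls (it does not use the accessor) purely to illustrate that arithmetic:

    import numpy as np
    import xarray as xr

    # Four equal-area cells, one of which has missing data.
    weights = xr.DataArray([0.25, 0.25, 0.25, 0.25], dims="cell")
    dv = xr.DataArray([1.0, 2.0, np.nan, 4.0], dims="cell")

    # Zero out the weight of missing cells (what _get_masked_weights() does).
    masked_weights = xr.where(dv.isnull(), 0.0, weights)

    # Fraction of the total weight backed by data: 0.75 here.
    frac = masked_weights.sum(dim="cell") / weights.sum(dim="cell")

    # The weighted mean is kept only where coverage meets the threshold.
    weighted_mean = dv.weighted(weights).mean(dim="cell")
    result = xr.where(frac >= 0.8, weighted_mean, np.nan)  # -> nan, since 0.75 < 0.8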
diff --git a/xcdat/utils.py b/xcdat/utils.py
index 83596561..a2f674fa 100644
--- a/xcdat/utils.py
+++ b/xcdat/utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 import json
 from typing import Dict, List, Optional, Union
@@ -132,3 +134,61 @@ def _if_multidim_dask_array_then_load(
         return obj.load()
 
     return None
+
+
+def _get_masked_weights(dv: xr.DataArray, weights: xr.DataArray) -> xr.DataArray:
+    """Get weights with missing data (`np.nan`) receiving no weight (zero).
+
+    Parameters
+    ----------
+    dv : xr.DataArray
+        The variable.
+    weights : xr.DataArray
+        A DataArray containing either the regional or temporal weights used for
+        weighted averaging. ``weights`` must include the same axis dimensions
+        and dimensional sizes as the data variable.
+
+    Returns
+    -------
+    xr.DataArray
+        The masked weights.
+    """
+    masked_weights = xr.where(dv.copy().isnull(), 0.0, weights)
+
+    return masked_weights
+
+
+def _validate_min_weight(min_weight: float | None) -> float:
+    """Validate the ``min_weight`` value.
+
+    Parameters
+    ----------
+    min_weight : float | None
+        Fraction of data coverage (i.e., weight) needed to return a
+        spatial average value. Value must range from 0 to 1.
+
+    Returns
+    -------
+    float
+        The required weight fraction.
+
+    Raises
+    ------
+    ValueError
+        If the `min_weight` argument is less than 0.
+    ValueError
+        If the `min_weight` argument is greater than 1.
+    """
+    if min_weight is None:
+        return 0.0
+    elif min_weight < 0.0:
+        raise ValueError(
+            "min_weight argument is less than 0. " "min_weight must be between 0 and 1."
+        )
+    elif min_weight > 1.0:
+        raise ValueError(
+            "min_weight argument is greater than 1. "
+            "min_weight must be between 0 and 1."
+        )
+
+    return min_weight
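For reference, the two helpers added to `utils.py` can be exercised directly. The snippet below is an illustrative sketch only: the values are made up, and these are private functions rather than a public usage pattern, so it simply mirrors the behavior defined in the patch.

    import numpy as np
    import xarray as xr

    from xcdat.utils import _get_masked_weights, _validate_min_weight

    # None is normalized to 0.0 (no threshold); out-of-range values raise ValueError.
    assert _validate_min_weight(None) == 0.0
    assert _validate_min_weight(0.5) == 0.5

    dv = xr.DataArray([1.0, np.nan], dims="lon")
    weights = xr.DataArray([0.4, 0.6], dims="lon")

    # The cell with missing data receives zero weight: [0.4, 0.0].
    masked = _get_masked_weights(dv, weights)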