Skip to content

Commit

Permalink
Some minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rubencalje committed Jul 25, 2024
1 parent e9d0923 commit 32b65e7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
14 changes: 10 additions & 4 deletions nlmod/dims/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def to_model_ds(
return ds


def extrapolate_ds(ds, mask=None, layer="layer"):
def extrapolate_ds(ds, mask=None, layer="layer", mask_values=None):
"""Fill missing data in layermodel based on nearest interpolation.
Used for ensuring layer model contains data everywhere. Useful for
Expand All @@ -223,6 +223,10 @@ def extrapolate_ds(ds, mask=None, layer="layer"):
variable. The default is None.
layer : str, optional
The name of the layer dimension. The default is 'layer'.
mask_values : np.ndarray, optional
Boolean mask for each cell, with a value of True if its value is used to fill
data in mask. When mask_values is None, it is determined from mask. The default
is None.
Returns
-------
Expand All @@ -234,6 +238,8 @@ def extrapolate_ds(ds, mask=None, layer="layer"):
if not mask.any():
# all of the model cells are is inside the known area
return ds
if mask_values is None:
mask_values = ~mask
if mask.all():
raise (ValueError("The model only contains NaNs"))
if "gridtype" in ds.attrs and ds.gridtype == "vertex":
Expand All @@ -243,7 +249,7 @@ def extrapolate_ds(ds, mask=None, layer="layer"):
else:
x, y = np.meshgrid(ds.x, ds.y)
dims = ("y", "x")
points = np.stack((x[~mask], y[~mask]), axis=1)
points = np.stack((x[mask_values], y[mask_values]), axis=1)
xi = np.stack((x[mask], y[mask]), axis=1)
# geneterate the tree only once, to increase speed
tree = cKDTree(points)
Expand All @@ -254,11 +260,11 @@ def extrapolate_ds(ds, mask=None, layer="layer"):
data = ds[key].data
if ds[key].dims == dims:
if np.isnan(data[mask]).sum() > 0: # do not update if no NaNs
data[mask] = data[~mask][i]
data[mask] = data[mask_values][i]
elif ds[key].dims == (layer,) + dims:
for lay in range(len(ds[layer])):
if np.isnan(data[lay, mask]).sum() > 0: # do not update if no NaNs
data[lay, mask] = data[lay, ~mask][i]
data[lay, mask] = data[lay, mask_values][i]
else:
logger.warning(
f"Data variable '{key}' not extrapolated because "
Expand Down
6 changes: 2 additions & 4 deletions nlmod/dims/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,8 +702,6 @@ def update_ds_from_layer_ds(ds, layer_ds, method="nearest", **kwargs):
Dataset with variables from layer_ds.
"""
if not layer_ds.layer.equals(ds.layer):
# do not change the original Dataset
layer_ds = layer_ds.copy()
# update layers in ds
drop_vars = []
for var in ds.data_vars:
Expand Down Expand Up @@ -1201,7 +1199,7 @@ def gdf_to_data_array_struc(
DESCRIPTION.
"""
warnings.warn(
"The method gdf_to_data_array_struc is deprecated. Please use gdf_to_da instead",
"The method gdf_to_data_array_struc is deprecated. Please use gdf_to_da instead.",
DeprecationWarning,
)

Expand Down Expand Up @@ -1770,7 +1768,7 @@ def get_thickness_from_topbot(top, bot):
or (layer, icell2d).
"""
warnings.warn(
"The method get_thickness_from_topbot is deprecated. Please use nlmod.layers.calculate_thickness instead",
"The method get_thickness_from_topbot is deprecated. Please use nlmod.layers.calculate_thickness instead.",
DeprecationWarning,
)

Expand Down
14 changes: 7 additions & 7 deletions nlmod/dims/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def structured_da_to_ds(da, ds, method="average", nodata=np.nan):
def extent_to_polygon(extent):
logger.warning(
"nlmod.resample.extent_to_polygon is deprecated. "
"Use nlmod.util.extent_to_polygon instead"
"Use nlmod.util.extent_to_polygon instead."
)
from ..util import extent_to_polygon

Expand All @@ -594,7 +594,7 @@ def get_extent_polygon(ds, rotated=True):
"""Get the model extent, as a shapely Polygon."""
logger.warning(
"nlmod.resample.get_extent_polygon is deprecated. "
"Use nlmod.grid.get_extent_polygon instead"
"Use nlmod.grid.get_extent_polygon instead."
)
from .grid import get_extent_polygon

Expand All @@ -605,7 +605,7 @@ def affine_transform_gdf(gdf, affine):
"""Apply an affine transformation to a geopandas GeoDataFrame."""
logger.warning(
"nlmod.resample.affine_transform_gdf is deprecated. "
"Use nlmod.grid.affine_transform_gdf instead"
"Use nlmod.grid.affine_transform_gdf instead."
)
from .grid import affine_transform_gdf

Expand All @@ -615,7 +615,7 @@ def affine_transform_gdf(gdf, affine):
def get_extent(ds, rotated=True):
"""Get the model extent, corrected for angrot if necessary."""
logger.warning(
"nlmod.resample.get_extent is deprecated. Use nlmod.grid.get_extent instead"
"nlmod.resample.get_extent is deprecated. Use nlmod.grid.get_extent instead."
)
from .grid import get_extent

Expand All @@ -626,7 +626,7 @@ def get_affine_mod_to_world(ds):
"""Get the affine-transformation from model to real-world coordinates."""
logger.warning(
"nlmod.resample.get_affine_mod_to_world is deprecated. "
"Use nlmod.grid.get_affine_mod_to_world instead"
"Use nlmod.grid.get_affine_mod_to_world instead."
)
from .grid import get_affine_mod_to_world

Expand All @@ -637,7 +637,7 @@ def get_affine_world_to_mod(ds):
"""Get the affine-transformation from real-world to model coordinates."""
logger.warning(
"nlmod.resample.get_affine_world_to_mod is deprecated. "
"Use nlmod.grid.get_affine_world_to_mod instead"
"Use nlmod.grid.get_affine_world_to_mod instead."
)
from .grid import get_affine_world_to_mod

Expand All @@ -647,7 +647,7 @@ def get_affine_world_to_mod(ds):
def get_affine(ds, sx=None, sy=None):
"""Get the affine-transformation, from pixel to real-world coordinates."""
logger.warning(
"nlmod.resample.get_affine is deprecated. Use nlmod.grid.get_affine instead"
"nlmod.resample.get_affine is deprecated. Use nlmod.grid.get_affine instead."
)
from .grid import get_affine

Expand Down

0 comments on commit 32b65e7

Please sign in to comment.