Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve time discretization #257

Merged
merged 27 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
15d19e1
improve time discretization:
dbrakenhoff Aug 28, 2023
78b620c
update STO pkg
dbrakenhoff Aug 28, 2023
ec31953
update recharge pkg for new ds time discretization
dbrakenhoff Aug 28, 2023
f239cb2
update get_tdis_perioddata
dbrakenhoff Aug 28, 2023
36cebae
update tests for new time discretization
dbrakenhoff Aug 28, 2023
06c4e3c
black
dbrakenhoff Aug 28, 2023
c5ffc29
codacy
dbrakenhoff Aug 28, 2023
6082bd4
process @OnnoEbbens comments
dbrakenhoff Aug 30, 2023
21dd905
process @OnnoEbbens comments
dbrakenhoff Aug 30, 2023
505a985
remove commented code
dbrakenhoff Aug 31, 2023
361c509
pin pandas version < 2.1.0
dbrakenhoff Aug 31, 2023
9ff2a84
add pin to ci not RTD...
dbrakenhoff Aug 31, 2023
5bc7da3
process comments @rubencalje
dbrakenhoff Aug 31, 2023
54160a4
Add perlen and default value for start
rubencalje Aug 31, 2023
f5c5da2
minor docstring update
rubencalje Aug 31, 2023
3bfc28e
Allow time to be a single value as well
rubencalje Aug 31, 2023
d32c937
remove default value of start
rubencalje Aug 31, 2023
3322919
Fix tests
rubencalje Aug 31, 2023
d620084
Update notebooks
rubencalje Aug 31, 2023
fb17b73
Make sure time is converted to an iterable a bit earlier
rubencalje Aug 31, 2023
d796995
Add knmi bugfix
rubencalje Sep 1, 2023
f270efd
Fix new warning in pandas 2.1.0
rubencalje Sep 1, 2023
f7159c9
Fix other problems in notebooks
rubencalje Sep 1, 2023
cfa3d3a
Fix last notebook bugs
rubencalje Sep 1, 2023
e4d009f
Remove start_date_time check in modpath
rubencalje Sep 1, 2023
652196e
codacy + json error nb11
dbrakenhoff Sep 4, 2023
658b4ef
update log message
dbrakenhoff Sep 4, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 212 additions & 2 deletions nlmod/dims/time.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import datetime as dt
import logging
import warnings

import numpy as np
import pandas as pd
import xarray as xr
from xarray import IndexVariable

logger = logging.getLogger(__name__)


def set_ds_time(
def set_ds_time_deprecated(
ds,
time=None,
steady_state=False,
Expand Down Expand Up @@ -71,6 +73,13 @@ def set_ds_time(
ds : xarray.Dataset
dataset with time variant model data
"""

warnings.warn(
"this function is deprecated and will eventually be removed, "
"please use nlmod.time.set_ds_time() in the future.",
DeprecationWarning,
)

# checks
if time_units.lower() != "days":
raise NotImplementedError()
Expand Down Expand Up @@ -124,9 +133,192 @@ def set_ds_time(
ds.time.attrs["steady_start"] = int(steady_start)
ds.time.attrs["steady_state"] = int(steady_state)

# add to ds (for new version nlmod)
# add steady, nstp and tsmult to dataset
steady = int(steady_state) * np.ones(len(time_dt), dtype=int)
if steady_start:
steady[0] = 1
ds["steady"] = ("time",), steady

if isinstance(nstp, (int, np.integer)):
nstp = nstp * np.ones(len(time), dtype=int)
ds["nstp"] = ("time",), nstp

if isinstance(tsmult, float):
tsmult = tsmult * np.ones(len(time))
ds["tsmult"] = ("time",), tsmult

return ds


def set_ds_time(
ds,
start,
time=None,
steady=False,
steady_start=True,
time_units="DAYS",
perlen=None,
nstp=1,
tsmult=1.0,
):
"""Set time discretisation for model dataset.

Parameters
----------
ds : xarray.Dataset
model dataset
start : int, float, str or pandas.Timestamp, optional
model start. When start is an integer or float it is interpreted as the number
of days of the first stress-period. When start is a string or pandas Timestamp
it is the start datetime of the simulation.
time : float, int or array-like, optional
float(s) (indicating elapsed time) or timestamp(s) corresponding to the end of
each stress period in the model. When time is a single value, the model will
have only one stress period. When time is None, the stress period lengths have
to be supplied via perlen. The default is None.
steady : arraylike or bool, optional
arraylike indicating which stress periods are steady-state, by default False,
which sets all stress periods to transient with the first period determined by
value of `steady_start`.
steady_start : bool, optional
whether to set the first period to steady-state, default is True, only used
when steady is passed as single boolean.
time_units : str, optional
time units, by default "DAYS"
perlen : float, int or array-like, optional
length of each stress-period. Only used when time is None. When perlen is a
single value, the model will have only one stress period. The default is None.
nstp : int or array-like, optional
number of steps per stress period, stored in ds.attrs, default is 1
tsmult : float, optional
timestep multiplier within stress periods, stored in ds.attrs, default is 1.0

Returns
-------
ds : xarray.Dataset
model dataset with added time coordinate

"""
logger.info(
"This is the new version of set_ds_time()."
" If you're looking for the old behavior,"
"use `nlmod.time.set_ds_time_deprecated()`."
)

if time is None and perlen is None:
raise (Exception("Please specify either time or perlen in set_ds_time"))
elif perlen is not None:
if time is not None:
msg = f"Cannot use both time and perlen. Ignoring perlen: {perlen}"
logger.warning(msg)
else:
if isinstance(perlen, (int, np.integer, float)):
perlen = [perlen]
time = np.cumsum(perlen)

# parse start
if isinstance(start, (int, np.integer, float)):
if isinstance(time[0], (int, np.integer, float)):
raise (Exception("Make sure start or time contains a valid TimeStamp"))
start = time[0] - pd.to_timedelta(start, "D")
elif isinstance(start, str):
start = pd.Timestamp(start)
elif isinstance(start, (pd.Timestamp, np.datetime64)):
pass
else:
raise TypeError("Cannot parse start datetime.")

# convert time to Timestamps
if not hasattr(time, "__iter__"):
time = [time]
if isinstance(time[0], (int, np.integer, float)):
time = pd.Timestamp(start) + pd.to_timedelta(time, time_units)
elif isinstance(time[0], str):
time = pd.to_datetime(time)
elif isinstance(time[0], (pd.Timestamp, np.datetime64, xr.core.variable.Variable)):
pass
else:
raise TypeError("Cannot process 'time' argument. Datatype not understood.")

if time[0] <= start:
msg = (
"The timestamp of the first stress period cannot be before or "
"equal to the model start time! Please modify `time` or `start`!"
)
logger.error(msg)
raise ValueError(msg)

ds = ds.assign_coords(coords={"time": time})

# add steady, nstp and tsmult to dataset
if isinstance(steady, bool):
steady = int(steady) * np.ones(len(time), dtype=int)
if steady_start:
steady[0] = 1
ds["steady"] = ("time",), steady

if isinstance(nstp, (int, np.integer)):
nstp = nstp * np.ones(len(time), dtype=int)
ds["nstp"] = ("time",), nstp

if isinstance(tsmult, float):
tsmult = tsmult * np.ones(len(time))
ds["tsmult"] = ("time",), tsmult

if time_units == "D":
time_units = "DAYS"
ds.time.attrs["time_units"] = time_units
ds.time.attrs["start"] = str(start)

return ds


def ds_time_idx_from_tdis_settings(start, perlen, nstp=1, tsmult=1.0, time_units="D"):
dbrakenhoff marked this conversation as resolved.
Show resolved Hide resolved
"""Get time index from TDIS perioddata: perlen, nstp, tsmult.


Parameters
----------
start : str, pd.Timestamp
start datetime
perlen : array-like
array of period lengths
nstp : int, or array-like optional
number of steps per period, by default 1
tsmult : float or array-like, optional
timestep multiplier per period, by default 1.0
time_units : str, optional
time units, by default "D"

Returns
-------
IndexVariable
time coordinate for xarray data-array or dataset
"""
deltlist = []
for kper, delt in enumerate(perlen):
if not isinstance(nstp, int):
kstpkper = nstp[kper]
else:
kstpkper = nstp

if not isinstance(tsmult, float):
tsm = tsmult[kper]
else:
tsm = tsmult

if tsm > 1.0:
delt0 = delt * (tsm - 1) / (tsm**kstpkper - 1)
delt = delt0 * tsm ** np.arange(kstpkper)
else:
delt = np.ones(kstpkper) * delt / kstpkper
deltlist.append(delt)

dt_arr = np.cumsum(np.concatenate(deltlist))
return ds_time_idx(dt_arr, start_datetime=start, time_units=time_units)


def estimate_nstp(
forcing, perlen=1, tsmult=1.1, nstp_min=1, nstp_max=25, return_dt_arr=False
):
Expand Down Expand Up @@ -214,6 +406,15 @@ def estimate_nstp(


def ds_time_from_model(gwf):
warnings.warn(
"this function was renamed to `ds_time_idx_from_model`. "
"Please use the new function name.",
DeprecationWarning,
)
return ds_time_idx_from_model(gwf)


def ds_time_idx_from_model(gwf):
"""Get time index variable from model (gwf or gwt).

Parameters
Expand All @@ -227,10 +428,19 @@ def ds_time_from_model(gwf):
time coordinate for xarray data-array or dataset
"""

return ds_time_from_modeltime(gwf.modeltime)
return ds_time_idx_from_modeltime(gwf.modeltime)


def ds_time_from_modeltime(modeltime):
warnings.warn(
"this function was renamed to `ds_time_idx_from_model`. "
"Please use the new function name.",
DeprecationWarning,
)
return ds_time_idx_from_modeltime(modeltime)


def ds_time_idx_from_modeltime(modeltime):
"""Get time index variable from modeltime object.

Parameters
Expand Down
13 changes: 5 additions & 8 deletions nlmod/gwf/gwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,18 +578,15 @@ def sto(
"""
logger.info("creating mf6 STO")

if ds.time.steady_state:
if "time" not in ds or ds["steady"].all():
logger.warning("Model is steady-state, no STO package created.")
return None
else:
if ds.time.steady_start:
sts_spd = {0: True}
trn_spd = {1: True}
else:
sts_spd = None
trn_spd = {0: True}
sts_spd = {iper: bool(b) for iper, b in enumerate(ds["steady"])}
trn_spd = {iper: not bool(b) for iper, b in enumerate(ds["steady"])}

sy = _get_value_from_ds_datavar(ds, "sy", sy, default=0.2)
ss = _get_value_from_ds_datavar(ds, "ss", ss, default=0.000001)
ss = _get_value_from_ds_datavar(ds, "ss", ss, default=1e-5)

sto = flopy.mf6.ModflowGwfsto(
gwf,
Expand Down
8 changes: 4 additions & 4 deletions nlmod/gwf/recharge.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def ds_to_rch(gwf, ds, mask=None, pname="rch", **kwargs):
raise ValueError("please remove nan values in recharge data array")

# get stress period data
if ds.time.steady_state:
if ds["steady"].all():
recharge = "recharge"
if "time" in ds["recharge"].dims:
mask = ds["recharge"].isel(time=0) != 0
Expand Down Expand Up @@ -69,7 +69,7 @@ def ds_to_rch(gwf, ds, mask=None, pname="rch", **kwargs):
**kwargs,
)

if ds.time.steady_state:
if ds["steady"].all():
return rch

# create timeseries packages
Expand Down Expand Up @@ -128,7 +128,7 @@ def ds_to_evt(gwf, ds, pname="evt", nseg=1, surface=None, depth=None, **kwargs):
raise ValueError("please remove nan values in evaporation data array")

# get stress period data
if ds.time.steady_state:
if ds["steady"].all():
if "time" in ds["evaporation"].dims:
mask = ds["evaporation"].isel(time=0) != 0
else:
Expand Down Expand Up @@ -163,7 +163,7 @@ def ds_to_evt(gwf, ds, pname="evt", nseg=1, surface=None, depth=None, **kwargs):
**kwargs,
)

if ds.time.steady_state:
if ds["steady"].all():
return evt

# create timeseries packages
Expand Down
1 change: 1 addition & 0 deletions nlmod/plot/plotutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..dims.resample import get_affine_mod_to_world
from ..epsg28992 import EPSG_28992


def get_patches(ds, rotated=False):
"""Get the matplotlib patches for a vertex grid."""
assert "icell2d" in ds.dims
Expand Down
25 changes: 14 additions & 11 deletions nlmod/sim/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def write_and_run(sim, ds, write_ds=True, script_path=None, silent=False):
ds.attrs["model_ran_on"] = dt.datetime.now().strftime("%Y%m%d_%H:%M:%S")


def get_tdis_perioddata(ds):
def get_tdis_perioddata(ds, nstp="nstp", tsmult="tsmult"):
"""Get tdis_perioddata from ds.

Parameters
Expand Down Expand Up @@ -92,15 +92,15 @@ def get_tdis_perioddata(ds):
if len(ds["time"]) > 1:
perlen.extend(np.diff(ds["time"]) / deltat)

if "nstp" in ds:
nstp = ds["nstp"].values
else:
nstp = [ds.time.nstp] * len(perlen)
nstp = util._get_value_from_ds_datavar(ds, "nstp", nstp, return_da=False)

if "tsmult" in ds:
tsmult = ds["tsmult"].values
else:
tsmult = [ds.time.tsmult] * len(perlen)
if isinstance(nstp, (int, np.integer)):
nstp = [nstp] * len(perlen)

tsmult = util._get_value_from_ds_datavar(ds, "tsmult", tsmult, return_da=False)

if isinstance(tsmult, float):
tsmult = [tsmult] * len(perlen)

tdis_perioddata = list(zip(perlen, nstp, tsmult))

Expand Down Expand Up @@ -144,7 +144,7 @@ def sim(ds, exe_name=None):
return sim


def tdis(ds, sim, pname="tdis"):
def tdis(ds, sim, pname="tdis", nstp="nstp", tsmult="tsmult", **kwargs):
"""create tdis package from the model dataset.

Parameters
Expand All @@ -156,6 +156,8 @@ def tdis(ds, sim, pname="tdis"):
simulation object.
pname : str, optional
package name
**kwargs
passed on to flopy.mft.ModflowTdis

Returns
-------
Expand All @@ -166,7 +168,7 @@ def tdis(ds, sim, pname="tdis"):
# start creating model
logger.info("creating mf6 TDIS")

tdis_perioddata = get_tdis_perioddata(ds)
tdis_perioddata = get_tdis_perioddata(ds, nstp=nstp, tsmult=tsmult)

# Create the Flopy temporal discretization object
tdis = flopy.mf6.ModflowTdis(
Expand All @@ -176,6 +178,7 @@ def tdis(ds, sim, pname="tdis"):
nper=len(ds.time),
start_date_time=pd.Timestamp(ds.time.start).isoformat(),
perioddata=tdis_perioddata,
**kwargs,
)

return tdis
Expand Down
Loading
Loading