Skip to content

Commit

Permalink
Merge pull request #124 from pastas/pr119_fix_dtype_pasconnector
Browse files Browse the repository at this point in the history
PasConnector returns Series and maintains series dtype
  • Loading branch information
dbrakenhoff authored Jun 26, 2024
2 parents d723f66 + 696d2ca commit 0f5d21f
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 27 deletions.
21 changes: 14 additions & 7 deletions pastastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ def _update_series(
self._validate_input_series(series)
series = self._set_series_name(series, name)
stored = self._get_series(libname, name, progressbar=False)
if self.conn_type == "pas" and (type(series) != type(stored)):
if isinstance(series, pd.DataFrame):
stored = stored.to_frame()
# get union of index
idx_union = stored.index.union(series.index)
# update series with new values
Expand Down Expand Up @@ -1261,7 +1264,7 @@ def _meta_list_to_frame(metalist: list, names: list):
return meta

def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
"""Internal method to parse dictionary describing pastas models.
"""Parse dictionary describing pastas models (internal method).
Parameters
----------
Expand All @@ -1286,7 +1289,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if name not in self.oseries.index:
msg = "oseries '{}' not present in library".format(name)
raise LookupError(msg)
mdict["oseries"]["series"] = self.get_oseries(name)
mdict["oseries"]["series"] = self.get_oseries(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
mdict["oseries"]["settings"]["tmin"] = mdict["oseries"]["series"].index[
Expand All @@ -1306,7 +1309,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[
Expand All @@ -1321,7 +1324,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[
Expand All @@ -1337,7 +1340,7 @@ def _parse_model_dict(self, mdict: dict, update_ts_settings: bool = False):
if "series" not in stress:
name = str(stress["name"])
if name in self.stresses.index:
stress["series"] = self.get_stresses(name)
stress["series"] = self.get_stresses(name).squeeze()
# update tmin/tmax from time series
if update_ts_settings:
stress["settings"]["tmin"] = stress["series"].index[0]
Expand Down Expand Up @@ -1727,23 +1730,27 @@ def _models_to_archive(self, archive, names=None, progressbar=True):
archive.writestr(f"models/{n}.pas", jsondict)

@staticmethod
def _series_from_json(fjson: str):
def _series_from_json(fjson: str, squeeze: bool = True):
"""Load time series from JSON.
Parameters
----------
fjson : str
path to file
squeeze : bool, optional
squeeze time series object to obtain pandas Series
Returns
-------
s : pd.DataFrame
DataFrame containing time series
"""
s = pd.read_json(fjson, orient="columns", precise_float=True)
s = pd.read_json(fjson, orient="columns", precise_float=True, dtype=False)
if not isinstance(s.index, pd.DatetimeIndex):
s.index = pd.to_datetime(s.index, unit="ms")
s = s.sort_index() # needed for some reason ...
if squeeze:
return s.squeeze()
return s

@staticmethod
Expand Down
18 changes: 9 additions & 9 deletions pastastore/yaml_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def __init__(self, pstore):
self.pstore = pstore

def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict:
"""Internal method to parse RechargeModel dictionary.
"""Parse RechargeModel dictionary (internal method).
Note: supports 'nearest' as input to 'prec' and 'evap',
which will automatically select nearest stress with kind="prec" or
Expand All @@ -205,7 +205,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
if isinstance(prec_val, dict):
pnam = prec_val["name"]
p = self.pstore.get_stresses(pnam)
prec_val["series"] = p
prec_val["series"] = p.squeeze()
prec = prec_val
elif prec_val.startswith("nearest"):
if onam is None:
Expand All @@ -221,7 +221,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
"name": pnam,
"settings": "prec",
"metadata": pmeta,
"series": p,
"series": p.squeeze(),
}
elif isinstance(prec_val, str):
pnam = d["prec"]
Expand All @@ -230,7 +230,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
"name": pnam,
"settings": "prec",
"metadata": pmeta,
"series": p,
"series": p.squeeze(),
}
else:
raise NotImplementedError(f"Could not parse prec value: '{prec_val}'")
Expand All @@ -241,7 +241,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
if isinstance(evap_val, dict):
enam = evap_val["name"]
e = self.pstore.get_stresses(enam)
evap_val["series"] = e
evap_val["series"] = e.squeeze()
evap = evap_val
elif evap_val.startswith("nearest"):
if onam is None:
Expand All @@ -257,7 +257,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
"name": enam,
"settings": "evap",
"metadata": emeta,
"series": e,
"series": e.squeeze(),
}
elif isinstance(evap_val, str):
enam = d["evap"]
Expand All @@ -266,7 +266,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
"name": enam,
"settings": "evap",
"metadata": emeta,
"series": e,
"series": e.squeeze(),
}
else:
raise NotImplementedError(f"Could not parse evap value: '{evap_val}'")
Expand Down Expand Up @@ -307,7 +307,7 @@ def _parse_rechargemodel_dict(self, d: Dict, onam: Optional[str] = None) -> Dict
onam = d["oseries"]
if isinstance(onam, str):
o = self.pstore.get_oseries(onam)
d["oseries"] = o
d["oseries"] = o.squeeze()

return d

Expand Down Expand Up @@ -487,7 +487,7 @@ def construct_mldict(self, mlyml: dict, mlnam: str) -> dict:
o, ometa = self.pstore.get_oseries(onam, return_metadata=True)

# create model to obtain default model settings
ml = ps.Model(o, name=mlnam, metadata=ometa)
ml = ps.Model(o.squeeze(), name=mlnam, metadata=ometa)
mldict = ml.to_dict(series=True)

# update with stored model settings
Expand Down
18 changes: 7 additions & 11 deletions tests/test_002_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,15 @@ def test_get_library(conn):
def test_add_get_series(request, conn):
o1 = pd.Series(
index=pd.date_range("2000", periods=10, freq="D"),
data=1.0,
dtype=np.float64,
data=0.0,
)
o1.name = "test_series"
conn.add_oseries(o1, "test_series", metadata=None)
o2 = conn.get_oseries("test_series")
# PasConnector has no logic for preserving Series
if conn.conn_type == "pas":
o2 = o2.squeeze()
try:
assert isinstance(o2, pd.Series)
assert (o1 == o2).all()
assert o1.equals(o2)
assert o1.dtype == o2.dtype
finally:
conn.del_oseries("test_series")

Expand All @@ -46,9 +43,6 @@ def test_add_get_series_wnans(request, conn):
o1.name = "test_series_nans"
conn.add_oseries(o1, "test_series_nans", metadata=None)
o2 = conn.get_oseries("test_series_nans")
# PasConnector has no logic for preserving Series
if conn.conn_type == "pas":
o2 = o2.squeeze()
try:
assert isinstance(o2, pd.Series)
assert o1.equals(o2)
Expand All @@ -65,10 +59,12 @@ def test_add_get_dataframe(request, conn):
o1.index.name = "test_idx"
conn.add_oseries(o1, "test_df", metadata=None)
o2 = conn.get_oseries("test_df")
# little hack as PasConnector does preserve DataFrames after load...
if conn.conn_type == "pas":
o2 = o2.to_frame()
try:
assert isinstance(o2, pd.DataFrame)
# little hack as PasConnector has dtype int after load...
assert o1.equals(o2.astype(float))
assert o1.equals(o2)
finally:
conn.del_oseries("test_df")

Expand Down

0 comments on commit 0f5d21f

Please sign in to comment.