Skip to content

Commit

Permalink
Merge pull request #148 from pastas/improve_search
Browse files Browse the repository at this point in the history
allow searching in all libraries
  • Loading branch information
dbrakenhoff authored Oct 22, 2024
2 parents b89d3a5 + 0926d03 commit 0877510
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions pastastore/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ def search(
case_sensitive: bool = True,
sort=True,
):
"""Search for names of time series or models starting with `s`.
"""Search for names of time series or models containing string `s`.
Parameters
----------
Expand All @@ -1515,30 +1515,45 @@ def search(
list of names that match search result
"""
if libname == "models":
lib_names = self.model_names
lib_names = {"models": self.model_names}
elif libname == "stresses":
lib_names = self.stresses_names
lib_names = {"stresses": self.stresses_names}
elif libname == "oseries":
lib_names = self.oseries_names
lib_names = {"oseries": self.oseries_names}
elif libname is None:
lib_names = {
"oseries": self.oseries_names,
"stresses": self.stresses_names,
"models": self.model_names,
}
else:
raise ValueError("Provide valid libname: 'models', 'stresses' or 'oseries'")

if isinstance(s, str):
if case_sensitive:
matches = [n for n in lib_names if s in n]
else:
matches = [n for n in lib_names if s.lower() in n.lower()]
if isinstance(s, list):
m = np.array([])
for sub in s:
result = {}
for lib, names in lib_names.items():
if isinstance(s, str):
if case_sensitive:
m = np.append(m, [n for n in lib_names if sub in n])
matches = [n for n in names if s in n]
else:
m = np.append(m, [n for n in lib_names if sub.lower() in n.lower()])
matches = list(np.unique(m))
if sort:
matches.sort()
return matches
matches = [n for n in names if s.lower() in n.lower()]
elif isinstance(s, list):
m = np.array([])
for sub in s:
if case_sensitive:
m = np.append(m, [n for n in names if sub in n])
else:
m = np.append(m, [n for n in names if sub.lower() in n.lower()])
matches = list(np.unique(m))
else:
raise TypeError("s must be str or list of str!")
if sort:
matches.sort()
result[lib] = matches

if len(result) == 1:
return result[lib]
else:
return result

def get_model_timeseries_names(
self,
Expand Down

0 comments on commit 0877510

Please sign in to comment.