Skip to content

Commit

Permalink
accept dates in string format for UUIDTmeRange
Browse files Browse the repository at this point in the history
  • Loading branch information
cevian committed Nov 14, 2023
1 parent 8897224 commit 8056bcd
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 68 deletions.
74 changes: 41 additions & 33 deletions nbs/00_vector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,46 @@
"source": [
"#| export\n",
"class UUIDTimeRange:\n",
" def __init__(self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):\n",
" \n",
" @staticmethod\n",
" def _parse_datetime(input_datetime: Union[datetime, str]):\n",
" \"\"\"\n",
" Parse a datetime object or string representation of a datetime.\n",
"\n",
" Args:\n",
" input_datetime (datetime or str): Input datetime or string.\n",
"\n",
" Returns:\n",
" datetime: Parsed datetime object.\n",
"\n",
" Raises:\n",
" ValueError: If the input cannot be parsed as a datetime.\n",
" \"\"\"\n",
" if input_datetime is None or input_datetime == \"None\":\n",
" return None\n",
" \n",
" if isinstance(input_datetime, datetime):\n",
" # If input is already a datetime object, return it as is\n",
" return input_datetime\n",
"\n",
" if isinstance(input_datetime, str):\n",
" try:\n",
" # Attempt to parse the input string into a datetime\n",
" return datetime.fromisoformat(input_datetime)\n",
" except ValueError:\n",
" raise ValueError(\"Invalid datetime string format: {}\".format(input_datetime))\n",
"\n",
" raise ValueError(\"Input must be a datetime object or string\")\n",
"\n",
" def __init__(self, start_date: Optional[Union[datetime, str]] = None, end_date: Optional[Union[datetime, str]] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):\n",
" \"\"\"\n",
" A UUIDTimeRange is a time range predicate on the UUID Version 1 timestamps. \n",
" \n",
" Note that naive datetime objects are interpreted as local time on the python client side and converted to UTC before being sent to the database.\n",
" \"\"\"\n",
" start_date = UUIDTimeRange._parse_datetime(start_date)\n",
" end_date = UUIDTimeRange._parse_datetime(end_date)\n",
"\n",
" if start_date is not None and end_date is not None:\n",
" if start_date > end_date:\n",
" raise Exception(\"start_date must be before end_date\")\n",
Expand Down Expand Up @@ -726,36 +760,6 @@
" raise ValueError(\"Unknown filter type: {filter_type}\".format(filter_type=type(filter)))\n",
"\n",
" return (where, params)\n",
" \n",
" def _parse_datetime(self, input_datetime):\n",
" \"\"\"\n",
" Parse a datetime object or string representation of a datetime.\n",
"\n",
" Args:\n",
" input_datetime (datetime or str): Input datetime or string.\n",
"\n",
" Returns:\n",
" datetime: Parsed datetime object.\n",
"\n",
" Raises:\n",
" ValueError: If the input cannot be parsed as a datetime.\n",
" \"\"\"\n",
" if input_datetime is None:\n",
" return None\n",
" \n",
" if isinstance(input_datetime, datetime):\n",
" # If input is already a datetime object, return it as is\n",
" return input_datetime\n",
"\n",
" if isinstance(input_datetime, str):\n",
" try:\n",
" # Attempt to parse the input string into a datetime\n",
" return datetime.fromisoformat(input_datetime)\n",
" except ValueError:\n",
" raise ValueError(\"Invalid datetime string format\")\n",
"\n",
" raise ValueError(\"Input must be a datetime object or string\")\n",
"\n",
"\n",
" def search_query(\n",
" self, \n",
Expand Down Expand Up @@ -785,8 +789,8 @@
" if self.infer_filters:\n",
" if uuid_time_filter is None and isinstance(filter, dict):\n",
" if \"__start_date\" in filter or \"__end_date\" in filter:\n",
" start_date = self._parse_datetime(filter.get(\"__start_date\"))\n",
" end_date = self._parse_datetime(filter.get(\"__end_date\"))\n",
" start_date = UUIDTimeRange._parse_datetime(filter.get(\"__start_date\"))\n",
" end_date = UUIDTimeRange._parse_datetime(filter.get(\"__end_date\"))\n",
" \n",
" uuid_time_filter = UUIDTimeRange(start_date, end_date)\n",
" \n",
Expand Down Expand Up @@ -1506,6 +1510,8 @@
" #using uuid_time_filter\n",
" rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date, end_date))\n",
" assert len(rec) == expected\n",
" rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(str(start_date), str(end_date)))\n",
" assert len(rec) == expected\n",
" \n",
" #using filters\n",
" filter = {}\n",
Expand Down Expand Up @@ -2248,6 +2254,8 @@
" #using uuid_time_filter\n",
" rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(start_date, end_date))\n",
" assert len(rec) == expected\n",
" rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(str(start_date), str(end_date)))\n",
" assert len(rec) == expected\n",
" \n",
" #using filters\n",
" filter = {}\n",
Expand Down
4 changes: 2 additions & 2 deletions timescale_vector/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._get_embedding_index_name': ( 'vector.html#querybuilder._get_embedding_index_name',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._parse_datetime': ( 'vector.html#querybuilder._parse_datetime',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._quote_ident': ( 'vector.html#querybuilder._quote_ident',
'timescale_vector/client.py'),
'timescale_vector.client.QueryBuilder._where_clause_for_filter': ( 'vector.html#querybuilder._where_clause_for_filter',
Expand Down Expand Up @@ -153,6 +151,8 @@
'timescale_vector/client.py'),
'timescale_vector.client.UUIDTimeRange.__str__': ( 'vector.html#uuidtimerange.__str__',
'timescale_vector/client.py'),
'timescale_vector.client.UUIDTimeRange._parse_datetime': ( 'vector.html#uuidtimerange._parse_datetime',
'timescale_vector/client.py'),
'timescale_vector.client.UUIDTimeRange.build_query': ( 'vector.html#uuidtimerange.build_query',
'timescale_vector/client.py'),
'timescale_vector.client.uuid_from_time': ( 'vector.html#uuid_from_time',
Expand Down
70 changes: 37 additions & 33 deletions timescale_vector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,46 @@ def create_index_query(self, table_name_quoted:str, column_name_quoted: str, ind

# %% ../nbs/00_vector.ipynb 11
class UUIDTimeRange:
def __init__(self, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):

@staticmethod
def _parse_datetime(input_datetime: Union[datetime, str]):
"""
Parse a datetime object or string representation of a datetime.
Args:
input_datetime (datetime or str): Input datetime or string.
Returns:
datetime: Parsed datetime object.
Raises:
ValueError: If the input cannot be parsed as a datetime.
"""
if input_datetime is None or input_datetime == "None":
return None

if isinstance(input_datetime, datetime):
# If input is already a datetime object, return it as is
return input_datetime

if isinstance(input_datetime, str):
try:
# Attempt to parse the input string into a datetime
return datetime.fromisoformat(input_datetime)
except ValueError:
raise ValueError("Invalid datetime string format: {}".format(input_datetime))

raise ValueError("Input must be a datetime object or string")

def __init__(self, start_date: Optional[Union[datetime, str]] = None, end_date: Optional[Union[datetime, str]] = None, time_delta: Optional[timedelta] = None, start_inclusive=True, end_inclusive=False):
"""
A UUIDTimeRange is a time range predicate on the UUID Version 1 timestamps.
Note that naive datetime objects are interpreted as local time on the python client side and converted to UTC before being sent to the database.
"""
start_date = UUIDTimeRange._parse_datetime(start_date)
end_date = UUIDTimeRange._parse_datetime(end_date)

if start_date is not None and end_date is not None:
if start_date > end_date:
raise Exception("start_date must be before end_date")
Expand Down Expand Up @@ -615,36 +649,6 @@ def _where_clause_for_filter(self, params: List, filter: Optional[Union[Dict[str
raise ValueError("Unknown filter type: {filter_type}".format(filter_type=type(filter)))

return (where, params)

def _parse_datetime(self, input_datetime):
"""
Parse a datetime object or string representation of a datetime.
Args:
input_datetime (datetime or str): Input datetime or string.
Returns:
datetime: Parsed datetime object.
Raises:
ValueError: If the input cannot be parsed as a datetime.
"""
if input_datetime is None:
return None

if isinstance(input_datetime, datetime):
# If input is already a datetime object, return it as is
return input_datetime

if isinstance(input_datetime, str):
try:
# Attempt to parse the input string into a datetime
return datetime.fromisoformat(input_datetime)
except ValueError:
raise ValueError("Invalid datetime string format")

raise ValueError("Input must be a datetime object or string")


def search_query(
self,
Expand Down Expand Up @@ -674,8 +678,8 @@ def search_query(
if self.infer_filters:
if uuid_time_filter is None and isinstance(filter, dict):
if "__start_date" in filter or "__end_date" in filter:
start_date = self._parse_datetime(filter.get("__start_date"))
end_date = self._parse_datetime(filter.get("__end_date"))
start_date = UUIDTimeRange._parse_datetime(filter.get("__start_date"))
end_date = UUIDTimeRange._parse_datetime(filter.get("__end_date"))

uuid_time_filter = UUIDTimeRange(start_date, end_date)

Expand Down

0 comments on commit 8056bcd

Please sign in to comment.