diff --git a/nbs/00_vector.ipynb b/nbs/00_vector.ipynb index 5d1e3fb..b6de196 100644 --- a/nbs/00_vector.ipynb +++ b/nbs/00_vector.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 98, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -146,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 103, "metadata": {}, "outputs": [], "source": [ @@ -275,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -290,7 +290,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -388,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -411,7 +411,9 @@ " \"!=\": \"<>\",\n", " }\n", "\n", - " def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):\n", + " PredicateValue = Union[str, int, float]\n", + "\n", + " def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'):\n", " \"\"\"\n", " Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).\n", "\n", @@ -425,9 +427,14 @@ " if operator not in self.logical_operators: \n", " raise ValueError(f\"invalid operator: {operator}\")\n", " self.operator = operator\n", - " self.clauses = list(clauses)\n", + " if isinstance(clauses[0], str):\n", + " if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):\n", + " raise ValueError(f\"Invalid clause format: {clauses}\")\n", + " self.clauses = [(clauses[0], clauses[1], clauses[2])]\n", + " else:\n", + " self.clauses = list(clauses)\n", "\n", - " def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):\n", + " def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]):\n", " \"\"\"\n", " Add a clause to the predicates object.\n", "\n", @@ -436,7 +443,12 @@ " clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]\n", " Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).\n", " \"\"\"\n", - " self.clauses.extend(list(clause))\n", + " if isinstance(clause[0], str):\n", + " if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):\n", + " raise ValueError(f\"Invalid clause format: {clause}\")\n", + " self.clauses.append((clause[0], clause[1], clause[2]))\n", + " else:\n", + " self.clauses.extend(list(clause))\n", " \n", " def __and__(self, other):\n", " new_predicates = Predicates(self, other, operator='AND')\n", @@ -522,7 +534,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -836,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 108, "metadata": {}, "outputs": [ { @@ -864,7 +876,7 @@ "Generates a query to create the tables, indexes, and extensions needed to store the vector data." ] }, - "execution_count": null, + "execution_count": 108, "metadata": {}, "output_type": "execute_result" } @@ -883,7 +895,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ @@ -1143,7 +1155,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 110, "metadata": {}, "outputs": [ { @@ -1171,7 +1183,7 @@ "Creates necessary tables." ] }, - "execution_count": null, + "execution_count": 110, "metadata": {}, "output_type": "execute_result" } @@ -1182,7 +1194,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 111, "metadata": {}, "outputs": [ { @@ -1210,7 +1222,7 @@ "Creates necessary tables." ] }, - "execution_count": null, + "execution_count": 111, "metadata": {}, "output_type": "execute_result" } @@ -1221,21 +1233,9 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 112, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n", - "Returns \n", - "-------- in \n", - "Retrieves similar records using a similarity query.\n", - "...\n", - " else: warn(msg)\n" - ] - }, { "data": { "text/markdown": [ @@ -1285,7 +1285,7 @@ "| **Returns** | **List: List of similar records.** | | |" ] }, - "execution_count": null, + "execution_count": 112, "metadata": {}, "output_type": "execute_result" } @@ -1296,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 117, "metadata": {}, "outputs": [], "source": [ @@ -1317,7 +1317,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ @@ -1393,19 +1393,21 @@ "assert len(rec) == 1\n", "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"==\", \"val2\")))\n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\", \"==\", \"val2\"))\n", + "assert len(rec) == 1\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100))\n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 10)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 10))\n", "assert len(rec) == 0\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10))\n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10.0)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10.0))\n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<=\", 11.3)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<=\", 11.3))\n", "assert len(rec) == 1\n", - "rec = await vec.search(limit=4, predicates=Predicates((\"key_11\", \">=\", 11.29999)))\n", + "rec = await vec.search(limit=4, predicates=Predicates(\"key_11\", \">=\", 11.29999))\n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<\", 11.299999)))\n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<\", 11.299999))\n", "assert len(rec) == 0\n", "\n", "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(*[(\"key\", \"val2\"), (\"key_10\", \"<\", 100)]))\n", @@ -1414,9 +1416,9 @@ "assert len(rec) == 1\n", "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_2\", \"val_2\"), operator='OR'))\n", "assert len(rec) == 2\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) & (Predicates((\"key\", \"val2\")) | Predicates((\"key_2\", \"val_2\")))) \n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) & (Predicates(\"key\",\"==\", \"val2\",) | Predicates(\"key_2\", \"==\", \"val_2\"))) \n", "assert len(rec) == 1\n", - "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) and (Predicates((\"key\", \"val2\")) or Predicates((\"key_2\", \"val_2\")))) \n", + "rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) and (Predicates(\"key\",\"==\", \"val2\") or Predicates(\"key_2\",\"==\", \"val_2\"))) \n", "assert len(rec) == 1\n", "rec = await vec.search(limit=4, predicates=~Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100)))\n", "assert len(rec) == 4\n", @@ -2193,7 +2195,7 @@ "assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n", "assert rec[0][\"distance\"] == 0.0009438353921149556\n", "\n", - "rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n", + "rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\",\"==\", \"val2\"))\n", "assert len(rec) == 1\n", "\n", "rec = vec.search([1.0, 2.0], limit=4, filter=[\n", diff --git a/timescale_vector/client.py b/timescale_vector/client.py index 5789a5a..bcbaeb8 100644 --- a/timescale_vector/client.py +++ b/timescale_vector/client.py @@ -307,7 +307,9 @@ class Predicates: "!=": "<>", } - def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'): + PredicateValue = Union[str, int, float] + + def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'): """ Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~). @@ -321,9 +323,14 @@ def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str if operator not in self.logical_operators: raise ValueError(f"invalid operator: {operator}") self.operator = operator - self.clauses = list(clauses) + if isinstance(clauses[0], str): + if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)): + raise ValueError("Invalid clause format: {clauses}") + self.clauses = [(clauses[0], clauses[1], clauses[2])] + else: + self.clauses = list(clauses) - def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]): + def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]): """ Add a clause to the predicates object. @@ -332,7 +339,12 @@ def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, st clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str] Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value). """ - self.clauses.extend(list(clause)) + if isinstance(clause[0], str): + if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)): + raise ValueError("Invalid clause format: {clauses}") + self.clauses.append((clause[0], clause[1], clause[2])) + else: + self.clauses.extend(list(clause)) def __and__(self, other): new_predicates = Predicates(self, other, operator='AND')