Skip to content

Commit

Permalink
Improve Predicate creation interface
Browse files Browse the repository at this point in the history
Previously Predicates had to be created from tuples:

Predicates(("key", "==", "val"))

But, as you can see above, the single-predicate case is weird
because of the extra parentheses. This change allows for:

Predicates("key", "==", "val")

That is, for the single-predicate case it accepts 3 arguments instead
of a tuple.
  • Loading branch information
cevian committed Nov 16, 2023
1 parent 3d12b32 commit 63e101d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 52 deletions.
98 changes: 50 additions & 48 deletions nbs/00_vector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,7 +21,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -42,7 +42,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -146,7 +146,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -275,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -290,7 +290,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -388,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -411,7 +411,9 @@
" \"!=\": \"<>\",\n",
" }\n",
"\n",
" def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):\n",
" PredicateValue = Union[str, int, float]\n",
"\n",
" def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'):\n",
" \"\"\"\n",
" Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).\n",
"\n",
Expand All @@ -425,9 +427,14 @@
" if operator not in self.logical_operators: \n",
" raise ValueError(f\"invalid operator: {operator}\")\n",
" self.operator = operator\n",
" self.clauses = list(clauses)\n",
" if isinstance(clauses[0], str):\n",
" if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):\n",
" raise ValueError(f\"Invalid clause format: {clauses}\")\n",
" self.clauses = [(clauses[0], clauses[1], clauses[2])]\n",
" else:\n",
" self.clauses = list(clauses)\n",
"\n",
" def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):\n",
" def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]):\n",
" \"\"\"\n",
" Add a clause to the predicates object.\n",
"\n",
Expand All @@ -436,7 +443,12 @@
" clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]\n",
" Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).\n",
" \"\"\"\n",
" self.clauses.extend(list(clause))\n",
" if isinstance(clause[0], str):\n",
" if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):\n",
" raise ValueError(f\"Invalid clause format: {clause}\")\n",
" self.clauses.append((clause[0], clause[1], clause[2]))\n",
" else:\n",
" self.clauses.extend(list(clause))\n",
" \n",
" def __and__(self, other):\n",
" new_predicates = Predicates(self, other, operator='AND')\n",
Expand Down Expand Up @@ -522,7 +534,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -836,7 +848,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 108,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -864,7 +876,7 @@
"Generates a query to create the tables, indexes, and extensions needed to store the vector data."
]
},
"execution_count": null,
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -883,7 +895,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1143,7 +1155,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 110,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1171,7 +1183,7 @@
"Creates necessary tables."
]
},
"execution_count": null,
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1182,7 +1194,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 111,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1210,7 +1222,7 @@
"Creates necessary tables."
]
},
"execution_count": null,
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1221,21 +1233,9 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 112,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n",
"Returns \n",
"-------- in \n",
"Retrieves similar records using a similarity query.\n",
"...\n",
" else: warn(msg)\n"
]
},
{
"data": {
"text/markdown": [
Expand Down Expand Up @@ -1285,7 +1285,7 @@
"| **Returns** | **List: List of similar records.** | | |"
]
},
"execution_count": null,
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1296,7 +1296,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -1317,7 +1317,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -1393,19 +1393,21 @@
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"==\", \"val2\")))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\", \"==\", \"val2\"))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 10)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 10))\n",
"assert len(rec) == 0\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<=\", 10.0)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<=\", 10.0))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<=\", 11.3)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<=\", 11.3))\n",
"assert len(rec) == 1\n",
"rec = await vec.search(limit=4, predicates=Predicates((\"key_11\", \">=\", 11.29999)))\n",
"rec = await vec.search(limit=4, predicates=Predicates(\"key_11\", \">=\", 11.29999))\n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_11\", \"<\", 11.299999)))\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_11\", \"<\", 11.299999))\n",
"assert len(rec) == 0\n",
"\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(*[(\"key\", \"val2\"), (\"key_10\", \"<\", 100)]))\n",
Expand All @@ -1414,9 +1416,9 @@
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\"), (\"key_2\", \"val_2\"), operator='OR'))\n",
"assert len(rec) == 2\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) & (Predicates((\"key\", \"val2\")) | Predicates((\"key_2\", \"val_2\")))) \n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) & (Predicates(\"key\",\"==\", \"val2\",) | Predicates(\"key_2\", \"==\", \"val_2\"))) \n",
"assert len(rec) == 1\n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key_10\", \"<\", 100)) and (Predicates((\"key\", \"val2\")) or Predicates((\"key_2\", \"val_2\")))) \n",
"rec = await vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key_10\", \"<\", 100) and (Predicates(\"key\",\"==\", \"val2\") or Predicates(\"key_2\",\"==\", \"val_2\"))) \n",
"assert len(rec) == 1\n",
"rec = await vec.search(limit=4, predicates=~Predicates((\"key\", \"val2\"), (\"key_10\", \"<\", 100)))\n",
"assert len(rec) == 4\n",
Expand Down Expand Up @@ -2193,7 +2195,7 @@
"assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556\n",
"assert rec[0][\"distance\"] == 0.0009438353921149556\n",
"\n",
"rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates((\"key\", \"val2\")))\n",
"rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates(\"key\",\"==\", \"val2\"))\n",
"assert len(rec) == 1\n",
"\n",
"rec = vec.search([1.0, 2.0], limit=4, filter=[\n",
Expand Down
20 changes: 16 additions & 4 deletions timescale_vector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ class Predicates:
"!=": "<>",
}

def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str, str]], operator: str = 'AND'):
PredicateValue = Union[str, int, float]

def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue], operator: str = 'AND'):
"""
Predicates class defines predicates on the object metadata. Predicates can be combined using logical operators (&, |, and ~).
Expand All @@ -321,9 +323,14 @@ def __init__(self, *clauses: Union['Predicates', Tuple[str, str], Tuple[str, str
if operator not in self.logical_operators:
raise ValueError(f"invalid operator: {operator}")
self.operator = operator
self.clauses = list(clauses)
if isinstance(clauses[0], str):
if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
raise ValueError("Invalid clause format: {clauses}")
self.clauses = [(clauses[0], clauses[1], clauses[2])]
else:
self.clauses = list(clauses)

def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, str, str]]):
def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tuple[str, str, PredicateValue], str, PredicateValue]):
"""
Add a clause to the predicates object.
Expand All @@ -332,7 +339,12 @@ def add_clause(self, *clause: Union['Predicates', Tuple[str, str], Tuple[str, st
clause: 'Predicates' or Tuple[str, str] or Tuple[str, str, str]
Predicate clause. Can be either another Predicates object or a tuple of the form (field, operator, value) or (field, value).
"""
self.clauses.extend(list(clause))
if isinstance(clause[0], str):
if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
raise ValueError("Invalid clause format: {clauses}")
self.clauses.append((clause[0], clause[1], clause[2]))
else:
self.clauses.extend(list(clause))

def __and__(self, other):
new_predicates = Predicates(self, other, operator='AND')
Expand Down

0 comments on commit 63e101d

Please sign in to comment.