From 2f52224a9b36d7726c2e52e61ca0e17d6f8bfbb9 Mon Sep 17 00:00:00 2001 From: Matvey Arye Date: Mon, 20 Nov 2023 09:27:16 -0800 Subject: [PATCH] Add query index params These parameters to search() let you control the query-time parameters for index operations. They allow you to adjust the speed/recall tradeoff when querying. If these parameters aren't specified, the default for the index will be used. --- nbs/00_vector.ipynb | 136 ++++++++++++++++++++++++++++-------- timescale_vector/_modidx.py | 17 +++++ timescale_vector/client.py | 59 +++++++++++++--- 3 files changed, 173 insertions(+), 39 deletions(-) diff --git a/nbs/00_vector.ipynb b/nbs/00_vector.ipynb index b6de196..2990175 100644 --- a/nbs/00_vector.ipynb +++ b/nbs/00_vector.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -21,7 +21,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -52,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -73,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -146,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -265,6 +265,42 @@ " .format(index_name=index_name_quoted, table_name=table_name_quoted, column_name=column_name_quoted, with_clause=with_clause)\n" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Query Params" + ] + }, + {
"cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class QueryParams:\n", + " def __init__(self, params: dict[str, Any]) -> None:\n", + " self.params = params\n", + " \n", + " def get_statements(self) -> List[str]:\n", + " return [\"SET LOCAL \" + key + \" = \" + str(value) for key, value in self.params.items()]\n", + "\n", + "class TimescaleVectorIndexParams(QueryParams):\n", + " def __init__(self, search_list_size: int) -> None:\n", + " super().__init__({\"tsv.query_search_list_size\": search_list_size})\n", + "\n", + "class IvfflatIndexParams(QueryParams):\n", + " def __init__(self, probes: int) -> None:\n", + " super().__init__({\"ivfflat.probes\": probes})\n", + "\n", + "class HNSWIndexParams(QueryParams):\n", + " def __init__(self, ef_search: int) -> None:\n", + " super().__init__({\"hnsw.ef_search\": ef_search})" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -275,7 +311,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -290,7 +326,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -388,7 +424,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -534,7 +570,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -848,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -876,7 +912,7 @@ "Generates a query to create the tables, indexes, and extensions needed to store the vector data." 
] }, - "execution_count": 108, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -895,7 +931,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1128,6 +1164,7 @@ " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", " predicates: Optional[Predicates] = None,\n", " uuid_time_filter: Optional[UUIDTimeRange] = None,\n", + " query_params: Optional[QueryParams] = None\n", " ): \n", " \"\"\"\n", " Retrieves similar records using a similarity query.\n", @@ -1149,13 +1186,22 @@ " \"\"\"\n", " (query, params) = self.builder.search_query(\n", " query_embedding, limit, filter, predicates, uuid_time_filter)\n", - " async with await self.connect() as pool:\n", - " return await pool.fetch(query, *params)" + " if query_params is not None:\n", + " async with await self.connect() as pool:\n", + " async with pool.transaction():\n", + " #Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588\n", + " statements = query_params.get_statements()\n", + " for statement in statements:\n", + " await pool.execute(statement)\n", + " return await pool.fetch(query, *params)\n", + " else:\n", + " async with await self.connect() as pool:\n", + " return await pool.fetch(query, *params)" ] }, { "cell_type": "code", - "execution_count": 110, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1183,7 +1229,7 @@ "Creates necessary tables." ] }, - "execution_count": 110, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1194,7 +1240,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1222,7 +1268,7 @@ "Creates necessary tables." 
] }, - "execution_count": 111, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1233,9 +1279,21 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": null, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/cevian/.pyenv/versions/3.11.4/envs/nbdev_env/lib/python3.11/site-packages/fastcore/docscrape.py:225: UserWarning: potentially wrong underline length... \n", + "Returns \n", + "-------- in \n", + "Retrieves similar records using a similarity query.\n", + "...\n", + " else: warn(msg)\n" + ] + }, { "data": { "text/markdown": [ @@ -1248,7 +1306,8 @@ "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", "> ne, predicates:Optional[__main__.Predicates]=None,\n", - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", + "> query_params:Optional[__main__.QueryParams]=None)\n", "\n", "Retrieves similar records using a similarity query.\n", "\n", @@ -1259,6 +1318,7 @@ "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). 
|\n", "| uuid_time_filter | Optional | None | |\n", + "| query_params | Optional | None | |\n", "| **Returns** | **List: List of similar records.** | | |" ], "text/plain": [ @@ -1271,7 +1331,8 @@ "> Async.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", "> ne, predicates:Optional[__main__.Predicates]=None,\n", - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", + "> query_params:Optional[__main__.QueryParams]=None)\n", "\n", "Retrieves similar records using a similarity query.\n", "\n", @@ -1282,10 +1343,11 @@ "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). 
|\n", "| uuid_time_filter | Optional | None | |\n", + "| query_params | Optional | None | |\n", "| **Returns** | **List: List of similar records.** | | |" ] }, - "execution_count": 112, + "execution_count": null, "metadata": {}, "output_type": "execute_result" } @@ -1296,7 +1358,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1317,7 +1379,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1564,6 +1626,10 @@ "assert len(rec) == 0\n", "rec = await vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n", "assert len(rec) == 1\n", + "rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n", + "assert len(rec) == 2\n", + "rec = await vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n", + "assert len(rec) == 2\n", "await vec.drop_table()\n", "await vec.close()" ] @@ -1883,6 +1949,7 @@ " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,\n", " predicates: Optional[Predicates] = None,\n", " uuid_time_filter: Optional[UUIDTimeRange] = None,\n", + " query_params: Optional[QueryParams] = None,\n", " ):\n", " \"\"\"\n", " Retrieves similar records using a similarity query.\n", @@ -1910,6 +1977,11 @@ " (query, params) = self.builder.search_query(\n", " query_embedding_np, limit, filter, predicates, uuid_time_filter)\n", " query, params = self._translate_to_pyformat(query, params)\n", + "\n", + " if query_params is not None:\n", + " prefix = \"; \".join(query_params.get_statements())\n", + " query = f\"{prefix}; {query}\"\n", + " \n", " with self.connect() as conn:\n", " with conn.cursor() as cur:\n", " cur.execute(query, params)\n", @@ -2021,7 +2093,8 @@ "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", "> 
filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", "> e, predicates:Optional[__main__.Predicates]=None,\n", - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", + "> query_params:Optional[__main__.QueryParams]=None)\n", "\n", "Retrieves similar records using a similarity query.\n", "\n", @@ -2032,6 +2105,7 @@ "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). |\n", "| uuid_time_filter | Optional | None | |\n", + "| query_params | Optional | None | |\n", "| **Returns** | **List: List of similar records.** | | |" ], "text/plain": [ @@ -2044,7 +2118,8 @@ "> Sync.search (query_embedding:Optional[List[float]]=None, limit:int=10,\n", "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", "> e, predicates:Optional[__main__.Predicates]=None,\n", - "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None)\n", + "> uuid_time_filter:Optional[__main__.UUIDTimeRange]=None,\n", + "> query_params:Optional[__main__.QueryParams]=None)\n", "\n", "Retrieves similar records using a similarity query.\n", "\n", @@ -2055,6 +2130,7 @@ "| filter | Union | None | A filter for metadata. Should be specified as a key-value object or a list of key-value objects (where any objects in the list are matched). |\n", "| predicates | Optional | None | A Predicates object to filter the results. Predicates support more complex queries than the filter parameter. Predicates can be combined using logical operators (&, \\|, and ~). 
|\n", "| uuid_time_filter | Optional | None | |\n", + "| query_params | Optional | None | |\n", "| **Returns** | **List: List of similar records.** | | |" ] }, @@ -2314,6 +2390,10 @@ "assert len(rec) == 0\n", "rec = vec.search([1.0, 2.0], limit=4, uuid_time_filter=UUIDTimeRange(end_date=specific_datetime+timedelta(seconds=1), time_delta=timedelta(days=7)))\n", "assert len(rec) == 1\n", + "rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(10))\n", + "assert len(rec) == 2\n", + "rec = vec.search([1.0, 2.0], limit=4, query_params=TimescaleVectorIndexParams(100))\n", + "assert len(rec) == 2\n", "vec.drop_table()\n", "vec.close()" ] diff --git a/timescale_vector/_modidx.py b/timescale_vector/_modidx.py index deab21f..6fa0877 100644 --- a/timescale_vector/_modidx.py +++ b/timescale_vector/_modidx.py @@ -47,6 +47,10 @@ 'timescale_vector/client.py'), 'timescale_vector.client.HNSWIndex.create_index_query': ( 'vector.html#hnswindex.create_index_query', 'timescale_vector/client.py'), + 'timescale_vector.client.HNSWIndexParams': ( 'vector.html#hnswindexparams', + 'timescale_vector/client.py'), + 'timescale_vector.client.HNSWIndexParams.__init__': ( 'vector.html#hnswindexparams.__init__', + 'timescale_vector/client.py'), 'timescale_vector.client.IvfflatIndex': ('vector.html#ivfflatindex', 'timescale_vector/client.py'), 'timescale_vector.client.IvfflatIndex.__init__': ( 'vector.html#ivfflatindex.__init__', 'timescale_vector/client.py'), @@ -56,6 +60,10 @@ 'timescale_vector/client.py'), 'timescale_vector.client.IvfflatIndex.get_num_records': ( 'vector.html#ivfflatindex.get_num_records', 'timescale_vector/client.py'), + 'timescale_vector.client.IvfflatIndexParams': ( 'vector.html#ivfflatindexparams', + 'timescale_vector/client.py'), + 'timescale_vector.client.IvfflatIndexParams.__init__': ( 'vector.html#ivfflatindexparams.__init__', + 'timescale_vector/client.py'), 'timescale_vector.client.Predicates': ('vector.html#predicates', 
'timescale_vector/client.py'), 'timescale_vector.client.Predicates.__and__': ( 'vector.html#predicates.__and__', 'timescale_vector/client.py'), @@ -106,6 +114,11 @@ 'timescale_vector/client.py'), 'timescale_vector.client.QueryBuilder.search_query': ( 'vector.html#querybuilder.search_query', 'timescale_vector/client.py'), + 'timescale_vector.client.QueryParams': ('vector.html#queryparams', 'timescale_vector/client.py'), + 'timescale_vector.client.QueryParams.__init__': ( 'vector.html#queryparams.__init__', + 'timescale_vector/client.py'), + 'timescale_vector.client.QueryParams.get_statements': ( 'vector.html#queryparams.get_statements', + 'timescale_vector/client.py'), 'timescale_vector.client.Sync': ('vector.html#sync', 'timescale_vector/client.py'), 'timescale_vector.client.Sync.__init__': ( 'vector.html#sync.__init__', 'timescale_vector/client.py'), @@ -145,6 +158,10 @@ 'timescale_vector/client.py'), 'timescale_vector.client.TimescaleVectorIndex.create_index_query': ( 'vector.html#timescalevectorindex.create_index_query', 'timescale_vector/client.py'), + 'timescale_vector.client.TimescaleVectorIndexParams': ( 'vector.html#timescalevectorindexparams', + 'timescale_vector/client.py'), + 'timescale_vector.client.TimescaleVectorIndexParams.__init__': ( 'vector.html#timescalevectorindexparams.__init__', + 'timescale_vector/client.py'), 'timescale_vector.client.UUIDTimeRange': ( 'vector.html#uuidtimerange', 'timescale_vector/client.py'), 'timescale_vector.client.UUIDTimeRange.__init__': ( 'vector.html#uuidtimerange.__init__', diff --git a/timescale_vector/client.py b/timescale_vector/client.py index bcbaeb8..9308cb3 100644 --- a/timescale_vector/client.py +++ b/timescale_vector/client.py @@ -3,7 +3,8 @@ # %% auto 0 __all__ = ['SEARCH_RESULT_ID_IDX', 'SEARCH_RESULT_METADATA_IDX', 'SEARCH_RESULT_CONTENTS_IDX', 'SEARCH_RESULT_EMBEDDING_IDX', 'SEARCH_RESULT_DISTANCE_IDX', 'uuid_from_time', 'BaseIndex', 'IvfflatIndex', 'HNSWIndex', - 'TimescaleVectorIndex', 'UUIDTimeRange', 
'Predicates', 'QueryBuilder', 'Async', 'Sync'] + 'TimescaleVectorIndex', 'QueryParams', 'TimescaleVectorIndexParams', 'IvfflatIndexParams', 'HNSWIndexParams', + 'UUIDTimeRange', 'Predicates', 'QueryBuilder', 'Async', 'Sync'] # %% ../nbs/00_vector.ipynb 5 import asyncpg @@ -192,13 +193,33 @@ def create_index_query(self, table_name_quoted:str, column_name_quoted: str, ind # %% ../nbs/00_vector.ipynb 10 +class QueryParams: + def __init__(self, params: dict[str, Any]) -> None: + self.params = params + + def get_statements(self) -> List[str]: + return ["SET LOCAL " + key + " = " + str(value) for key, value in self.params.items()] + +class TimescaleVectorIndexParams(QueryParams): + def __init__(self, search_list_size: int) -> None: + super().__init__({"tsv.query_search_list_size": search_list_size}) + +class IvfflatIndexParams(QueryParams): + def __init__(self, probes: int) -> None: + super().__init__({"ivfflat.probes": probes}) + +class HNSWIndexParams(QueryParams): + def __init__(self, ef_search: int) -> None: + super().__init__({"hnsw.ef_search": ef_search}) + +# %% ../nbs/00_vector.ipynb 12 SEARCH_RESULT_ID_IDX = 0 SEARCH_RESULT_METADATA_IDX = 1 SEARCH_RESULT_CONTENTS_IDX = 2 SEARCH_RESULT_EMBEDDING_IDX = 3 SEARCH_RESULT_DISTANCE_IDX = 4 -# %% ../nbs/00_vector.ipynb 11 +# %% ../nbs/00_vector.ipynb 13 class UUIDTimeRange: @staticmethod @@ -289,7 +310,7 @@ def build_query(self, params: List) -> Tuple[str, List]: params.append(self.end_date) return " AND ".join(queries), params -# %% ../nbs/00_vector.ipynb 12 +# %% ../nbs/00_vector.ipynb 14 class Predicates: logical_operators = { "AND": "AND", @@ -325,7 +346,7 @@ def __init__(self, *clauses: Union['Predicates', Tuple[str, PredicateValue], Tup self.operator = operator if isinstance(clauses[0], str): if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)): - raise ValueError("Invalid clause format: {clauses}") + raise ValueError(f"Invalid clause format: {clauses}") 
self.clauses = [(clauses[0], clauses[1], clauses[2])] else: self.clauses = list(clauses) @@ -341,7 +362,7 @@ def add_clause(self, *clause: Union['Predicates', Tuple[str, PredicateValue], Tu """ if isinstance(clause[0], str): if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)): - raise ValueError("Invalid clause format: {clauses}") + raise ValueError(f"Invalid clause format: {clause}") self.clauses.append((clause[0], clause[1], clause[2])) else: self.clauses.extend(list(clause)) @@ -427,7 +448,7 @@ def build_query(self, params: List) -> Tuple[str, List]: where_clause = (" "+self.operator+" ").join(where_conditions) return where_clause, params -# %% ../nbs/00_vector.ipynb 13 +# %% ../nbs/00_vector.ipynb 15 class QueryBuilder: def __init__( self, @@ -734,7 +755,7 @@ def search_query( '''.format(distance=distance, order_by_clause=order_by_clause, where=where, table_name=self._quote_ident(self.table_name), limit=limit) return (query, params) -# %% ../nbs/00_vector.ipynb 16 +# %% ../nbs/00_vector.ipynb 18 class Async(QueryBuilder): def __init__( self, @@ -963,6 +984,7 @@ async def search(self, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, predicates: Optional[Predicates] = None, uuid_time_filter: Optional[UUIDTimeRange] = None, + query_params: Optional[QueryParams] = None ): """ Retrieves similar records using a similarity query. 
@@ -984,10 +1006,19 @@ async def search(self, """ (query, params) = self.builder.search_query( query_embedding, limit, filter, predicates, uuid_time_filter) - async with await self.connect() as pool: - return await pool.fetch(query, *params) + if query_params is not None: + async with await self.connect() as pool: + async with pool.transaction(): + #Looks like there is no way to pipeline this: https://github.com/MagicStack/asyncpg/issues/588 + statements = query_params.get_statements() + for statement in statements: + await pool.execute(statement) + return await pool.fetch(query, *params) + else: + async with await self.connect() as pool: + return await pool.fetch(query, *params) -# %% ../nbs/00_vector.ipynb 24 +# %% ../nbs/00_vector.ipynb 26 import psycopg2.pool from contextlib import contextmanager import psycopg2.extras @@ -995,7 +1026,7 @@ async def search(self, import numpy as np import re -# %% ../nbs/00_vector.ipynb 25 +# %% ../nbs/00_vector.ipynb 27 class Sync: translated_queries: Dict[str, str] = {} @@ -1281,6 +1312,7 @@ def search(self, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, predicates: Optional[Predicates] = None, uuid_time_filter: Optional[UUIDTimeRange] = None, + query_params: Optional[QueryParams] = None, ): """ Retrieves similar records using a similarity query. @@ -1308,6 +1340,11 @@ def search(self, (query, params) = self.builder.search_query( query_embedding_np, limit, filter, predicates, uuid_time_filter) query, params = self._translate_to_pyformat(query, params) + + if query_params is not None: + prefix = "; ".join(query_params.get_statements()) + query = f"{prefix}; {query}" + with self.connect() as conn: with conn.cursor() as cur: cur.execute(query, params)