From 9151c56dc26cd861ebcf98a45cd6f6754d2e071a Mon Sep 17 00:00:00 2001 From: Jascha Date: Thu, 24 Oct 2024 16:37:38 -0700 Subject: [PATCH] Fix tests after type changes --- tests/async_client_test.py | 4 +- tests/conftest.py | 12 ++++- tests/sync_client_test.py | 6 +-- timescale_vector/client/async_client.py | 25 +++++++---- timescale_vector/client/predicates.py | 20 ++++++--- timescale_vector/typings/asyncpg/__init__.pyi | 15 ++++--- .../typings/asyncpg/connection.pyi | 44 +++---------------- timescale_vector/typings/asyncpg/pool.pyi | 7 +-- .../typings/langchain/docstore/document.pyi | 12 ++--- .../vectorstores/timescalevector.pyi | 41 ++++++++--------- timescale_vector/typings/pgvector.pyi | 2 +- .../typings/psycopg2/__init__.pyi | 5 ++- .../typings/psycopg2/extensions.pyi | 5 +-- timescale_vector/typings/psycopg2/extras.pyi | 2 +- timescale_vector/typings/psycopg2/pool.pyi | 9 +--- 15 files changed, 98 insertions(+), 111 deletions(-) diff --git a/tests/async_client_test.py b/tests/async_client_test.py index 0595bfc..f2e86f8 100644 --- a/tests/async_client_test.py +++ b/tests/async_client_test.py @@ -17,7 +17,7 @@ @pytest.mark.asyncio -@pytest.mark.parametrize("schema", ["tschema", None]) +@pytest.mark.parametrize("schema", ["temp", None]) async def test_vector(service_url: str, schema: str) -> None: vec = Async(service_url, "data_table", 2, schema_name=schema) await vec.drop_table() @@ -338,7 +338,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st rec = await vec.search([1.0, 2.0], limit=4, filter=filter) assert len(rec) == expected # using predicates - predicates: list[tuple[str, str, str|datetime]] = [] + predicates: list[tuple[str, str, str | datetime]] = [] if start_date is not None: predicates.append(("__uuid_timestamp", ">=", start_date)) if end_date is not None: diff --git a/tests/conftest.py b/tests/conftest.py index 1a880a5..cbd771b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,20 @@ import os +import psycopg2 import pytest from dotenv import find_dotenv, load_dotenv -@pytest.fixture +@pytest.fixture(scope="module") def service_url() -> str: _ = load_dotenv(find_dotenv(), override=True) return os.environ["TIMESCALE_SERVICE_URL"] + + +@pytest.fixture(scope="module", autouse=True) +def create_temp_schema(service_url: str) -> None: + conn = psycopg2.connect(service_url) + with conn.cursor() as cursor: + cursor.execute("CREATE SCHEMA IF NOT EXISTS temp;") + conn.commit() + conn.close() diff --git a/tests/sync_client_test.py b/tests/sync_client_test.py index ead160d..840b991 100644 --- a/tests/sync_client_test.py +++ b/tests/sync_client_test.py @@ -20,7 +20,7 @@ ) -@pytest.mark.parametrize("schema", ["tschema", None]) +@pytest.mark.parametrize("schema", ["temp", None]) def test_sync_client(service_url: str, schema: str) -> None: vec = Sync(service_url, "data_table", 2, schema_name=schema) vec.create_tables() @@ -234,7 +234,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No assert len(rec) == expected # using filters - filter: dict[str, str|datetime] = {} + filter: dict[str, str | datetime] = {} if start_date is not None: filter["__start_date"] = start_date if end_date is not None: @@ -250,7 +250,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No rec = vec.search([1.0, 2.0], limit=4, filter=filter) assert len(rec) == expected # using predicates - predicates: list[tuple[str, str, str|datetime]] = [] + predicates: list[tuple[str, str, str | datetime]] = [] if start_date is not None: predicates.append(("__uuid_timestamp", ">=", start_date)) if end_date is not None: diff --git a/timescale_vector/client/async_client.py b/timescale_vector/client/async_client.py index 8f7fce1..760b84e 100644 --- a/timescale_vector/client/async_client.py +++ b/timescale_vector/client/async_client.py @@ -92,14 +92,12 @@ async def connect(self) -> PoolAcquireContext: self.max_db_connections = await self._default_max_db_connections() async def init(conn: Connection) -> None: - await register_vector(conn) + schema = await self._detect_vector_schema(conn) + if schema is None: + raise ValueError("pg_vector extension not found") + await register_vector(conn, schema=schema) # decode to a dict, but accept a string as input in upsert - await conn.set_type_codec( - "jsonb", - encoder=str, - decoder=json.loads, - schema="pg_catalog" - ) + await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog") self.pool = await create_pool( dsn=self.service_url, @@ -127,13 +125,22 @@ async def table_is_empty(self) -> bool: rec = await pool.fetchrow(query) return rec is None - def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]: metadata_is_dict = isinstance(records[0][1], dict) if metadata_is_dict: return list(map(lambda item: Async._convert_record_meta_to_json(item), records)) return records + async def _detect_vector_schema(self, conn: Connection) -> str | None: + query = """ + select n.nspname + from pg_extension x + inner join pg_namespace n on (x.extnamespace = n.oid) + where x.extname = 'vector'; + """ + + return await conn.fetchval(query) + @staticmethod def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]: if not isinstance(item[1], dict): @@ -301,4 +308,4 @@ async def search( return await pool.fetch(query, *params) else: async with await self.connect() as pool: - return await pool.fetch(query, *params) \ No newline at end of file + return await pool.fetch(query, *params) diff --git a/timescale_vector/client/predicates.py b/timescale_vector/client/predicates.py index eb0b790..48fac64 100644 --- a/timescale_vector/client/predicates.py +++ b/timescale_vector/client/predicates.py @@ -1,6 +1,12 @@ import json from datetime import datetime -from typing import Any, Literal, Union +from typing import Any, Literal, Union, get_args, get_origin + + +def get_runtime_types(typ) -> tuple[type, ...]: # type: ignore + """Convert a type with generic parameters to runtime types. + Necessary because Generic types cant be passed to isinstance in python 3.10""" + return tuple(get_origin(t) or t for t in get_args(typ)) # type: ignore class Predicates: @@ -51,7 +57,9 @@ def __init__( raise ValueError(f"invalid operator: {operator}") self.operator: str = operator if isinstance(clauses[0], str): - if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)): + if len(clauses) != 3 or not ( + isinstance(clauses[1], str) and isinstance(clauses[2], get_runtime_types(self.PredicateValue)) + ): raise ValueError(f"Invalid clause format: {clauses}") self.clauses = [clauses] else: @@ -77,11 +85,13 @@ def add_clause( or (field, value). """ if isinstance(clause[0], str): - if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)): + if len(clause) != 3 or not ( + isinstance(clause[1], str) and isinstance(clause[2], get_runtime_types(self.PredicateValue)) + ): raise ValueError(f"Invalid clause format: {clause}") - self.clauses.append(clause) # type: ignore + self.clauses.append(clause) # type: ignore else: - self.clauses.extend(list(clause)) # type: ignore + self.clauses.extend(list(clause)) # type: ignore def __and__(self, other: "Predicates") -> "Predicates": new_predicates = Predicates(self, other, operator="AND") diff --git a/timescale_vector/typings/asyncpg/__init__.pyi b/timescale_vector/typings/asyncpg/__init__.pyi index 0ea13a4..0b16ec0 100644 --- a/timescale_vector/typings/asyncpg/__init__.pyi +++ b/timescale_vector/typings/asyncpg/__init__.pyi @@ -1,8 +1,10 @@ -from typing import Any, Protocol, TypeVar, Sequence -from . import pool, connection +from collections.abc import Sequence +from typing import Any, Protocol, TypeVar + +from . import connection, pool # Core types -T = TypeVar('T') +T = TypeVar("T") class Record(Protocol): def __getitem__(self, key: int | str) -> Any: ... @@ -30,9 +32,8 @@ async def connect( user: str | None = None, password: str | None = None, database: str | None = None, - timeout: int = 60 + timeout: int = 60, ) -> Connection: ... - async def create_pool( dsn: str | None = None, *, @@ -42,5 +43,5 @@ async def create_pool( max_inactive_connection_lifetime: float = 300.0, setup: Any | None = None, init: Any | None = None, - **connect_kwargs: Any -) -> Pool: ... \ No newline at end of file + **connect_kwargs: Any, +) -> Pool: ... diff --git a/timescale_vector/typings/asyncpg/connection.pyi b/timescale_vector/typings/asyncpg/connection.pyi index 5646cb4..8399aa0 100644 --- a/timescale_vector/typings/asyncpg/connection.pyi +++ b/timescale_vector/typings/asyncpg/connection.pyi @@ -6,50 +6,18 @@ from . import Record class Connection: # Transaction management async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ... - async def executemany( - self, - command: str, - args: Sequence[Sequence[Any]], - *, - timeout: float | None = None + self, command: str, args: Sequence[Sequence[Any]], *, timeout: float | None = None ) -> str: ... - - async def fetch( - self, - query: str, - *args: Any, - timeout: float | None = None - ) -> list[Record]: ... - - async def fetchval( - self, - query: str, - *args: Any, - column: int = 0, - timeout: float | None = None - ) -> Any: ... - - async def fetchrow( - self, - query: str, - *args: Any, - timeout: float | None = None - ) -> Record | None: ... - + async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[Record]: ... + async def fetchval(self, query: str, *args: Any, column: int = 0, timeout: float | None = None) -> Any: ... + async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> Record | None: ... async def set_type_codec( - self, - typename: str, - *, - schema: str = "public", - encoder: Any, - decoder: Any, - format: str = "text" + self, typename: str, *, schema: str = "public", encoder: Any, decoder: Any, format: str = "text" ) -> None: ... # Transaction context def transaction(self, *, isolation: str = "read_committed") -> Transaction: ... - async def close(self, *, timeout: float | None = None) -> None: ... class Transaction: @@ -57,4 +25,4 @@ class Transaction: async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... async def start(self) -> None: ... async def commit(self) -> None: ... - async def rollback(self) -> None: ... \ No newline at end of file + async def rollback(self) -> None: ... diff --git a/timescale_vector/typings/asyncpg/pool.pyi b/timescale_vector/typings/asyncpg/pool.pyi index f1f3e0b..4f5fdf3 100644 --- a/timescale_vector/typings/asyncpg/pool.pyi +++ b/timescale_vector/typings/asyncpg/pool.pyi @@ -1,4 +1,5 @@ -from typing import Any, AsyncContextManager +from contextlib import AbstractAsyncContextManager +from typing import Any from . import connection @@ -13,6 +14,6 @@ class Pool: async def __aenter__(self) -> Pool: ... async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... -class PoolAcquireContext(AsyncContextManager['connection.Connection']): +class PoolAcquireContext(AbstractAsyncContextManager["connection.Connection"]): async def __aenter__(self) -> connection.Connection: ... - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... \ No newline at end of file + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... diff --git a/timescale_vector/typings/langchain/docstore/document.pyi b/timescale_vector/typings/langchain/docstore/document.pyi index 6737830..1462263 100644 --- a/timescale_vector/typings/langchain/docstore/document.pyi +++ b/timescale_vector/typings/langchain/docstore/document.pyi @@ -1,4 +1,5 @@ -from typing import Any, TypeVar, Optional +from typing import Any, TypeVar + from typing_extensions import TypedDict class Metadata(TypedDict, total=False): @@ -8,21 +9,20 @@ class Metadata(TypedDict, total=False): category: str published_time: str -T = TypeVar('T') +T = TypeVar("T") class Document: """Documents are the basic unit of text in LangChain.""" + page_content: str metadata: dict[str, Any] def __init__( self, page_content: str, - metadata: Optional[dict[str, Any]] = None, + metadata: dict[str, Any] | None = None, ) -> None: ... - @property def lc_kwargs(self) -> dict[str, Any]: ... - @classmethod - def is_lc_serializable(cls) -> bool: ... \ No newline at end of file + def is_lc_serializable(cls) -> bool: ... diff --git a/timescale_vector/typings/langchain_community/vectorstores/timescalevector.pyi b/timescale_vector/typings/langchain_community/vectorstores/timescalevector.pyi index 12e2061..08f1211 100644 --- a/timescale_vector/typings/langchain_community/vectorstores/timescalevector.pyi +++ b/timescale_vector/typings/langchain_community/vectorstores/timescalevector.pyi @@ -7,31 +7,28 @@ from langchain.schema.embeddings import Embeddings class TimescaleVector: def __init__( - self, - collection_name: str, - service_url: str, - embedding: Embeddings, - time_partition_interval: timedelta | None = None, + self, + collection_name: str, + service_url: str, + embedding: Embeddings, + time_partition_interval: timedelta | None = None, ) -> None: ... - def add_texts( - self, - texts: Sequence[str], - metadatas: list[dict[str, Any]] | None = None, - ids: list[str] | None = None, - **kwargs: Any, + self, + texts: Sequence[str], + metadatas: list[dict[str, Any]] | None = None, + ids: list[str] | None = None, + **kwargs: Any, ) -> list[str]: ... - def delete_by_metadata( - self, - metadata_filter: dict[str, Any] | list[dict[str, Any]], + self, + metadata_filter: dict[str, Any] | list[dict[str, Any]], ) -> None: ... - def similarity_search_with_score( - self, - query: str, - k: int = 4, - filter: dict[str, Any] | list[dict[str, Any]] | None = None, - predicates: Any | None = None, - **kwargs: Any, - ) -> list[tuple[Document, float]]: ... \ No newline at end of file + self, + query: str, + k: int = 4, + filter: dict[str, Any] | list[dict[str, Any]] | None = None, + predicates: Any | None = None, + **kwargs: Any, + ) -> list[tuple[Document, float]]: ... diff --git a/timescale_vector/typings/pgvector.pyi b/timescale_vector/typings/pgvector.pyi index fd0ba09..214158e 100644 --- a/timescale_vector/typings/pgvector.pyi +++ b/timescale_vector/typings/pgvector.pyi @@ -1,3 +1,3 @@ from typing import Any -def register_vector(conn_or_curs: Any) -> None: ... \ No newline at end of file +def register_vector(conn_or_curs: Any) -> None: ... diff --git a/timescale_vector/typings/psycopg2/__init__.pyi b/timescale_vector/typings/psycopg2/__init__.pyi index 7f4fd43..9d37d1c 100644 --- a/timescale_vector/typings/psycopg2/__init__.pyi +++ b/timescale_vector/typings/psycopg2/__init__.pyi @@ -1,6 +1,7 @@ from typing import Any, TypeVar + from psycopg2.extensions import connection -T = TypeVar('T') +T = TypeVar("T") -def connect(dsn: str = "", **kwargs: Any) -> connection: ... \ No newline at end of file +def connect(dsn: str = "", **kwargs: Any) -> connection: ... diff --git a/timescale_vector/typings/psycopg2/extensions.pyi b/timescale_vector/typings/psycopg2/extensions.pyi index 6ed8052..f28d064 100644 --- a/timescale_vector/typings/psycopg2/extensions.pyi +++ b/timescale_vector/typings/psycopg2/extensions.pyi @@ -5,15 +5,14 @@ class cursor(Protocol): def executemany(self, query: str, vars_list: list[Any]) -> Any: ... def fetchone(self) -> tuple[Any, ...] | None: ... def fetchall(self) -> list[tuple[Any, ...]]: ... - def __enter__(self) -> 'cursor': ... + def __enter__(self) -> cursor: ... def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... class connection(Protocol): def cursor(self, cursor_factory: Any | None = None) -> cursor: ... def commit(self) -> None: ... def close(self) -> None: ... - def __enter__(self) -> 'connection': ... + def __enter__(self) -> connection: ... def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... def register_uuid(oids: Any | None = None, conn_or_curs: Any | None = None) -> None: ... - diff --git a/timescale_vector/typings/psycopg2/extras.pyi b/timescale_vector/typings/psycopg2/extras.pyi index e53e3e2..1f39933 100644 --- a/timescale_vector/typings/psycopg2/extras.pyi +++ b/timescale_vector/typings/psycopg2/extras.pyi @@ -4,5 +4,5 @@ from psycopg2.extensions import cursor class DictCursor(cursor, Protocol): def __init__(self) -> None: ... - + def register_uuid(oids: int | None = None, conn_or_curs: cursor | None = None) -> None: ... diff --git a/timescale_vector/typings/psycopg2/pool.pyi b/timescale_vector/typings/psycopg2/pool.pyi index fbb4a77..fe5f3d1 100644 --- a/timescale_vector/typings/psycopg2/pool.pyi +++ b/timescale_vector/typings/psycopg2/pool.pyi @@ -4,14 +4,7 @@ from typing import Any from psycopg2.extensions import connection class SimpleConnectionPool: - def __init__( - self, - minconn: int, - maxconn: int, - dsn: str, - **kwargs: Any - ) -> None: ... - + def __init__(self, minconn: int, maxconn: int, dsn: str, **kwargs: Any) -> None: ... def getconn(self, key: Hashable | None = None) -> connection: ... def putconn(self, conn: connection, key: Hashable | None = None, close: bool = False) -> None: ... def closeall(self) -> None: ...