Skip to content

Commit

Permalink
Fix tests after type changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Askir committed Oct 24, 2024
1 parent ecacaf1 commit 9151c56
Show file tree
Hide file tree
Showing 15 changed files with 98 additions and 111 deletions.
4 changes: 2 additions & 2 deletions tests/async_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


@pytest.mark.asyncio
@pytest.mark.parametrize("schema", ["tschema", None])
@pytest.mark.parametrize("schema", ["temp", None])
async def test_vector(service_url: str, schema: str) -> None:
vec = Async(service_url, "data_table", 2, schema_name=schema)
await vec.drop_table()
Expand Down Expand Up @@ -338,7 +338,7 @@ async def search_date(start_date: datetime | str | None, end_date: datetime | st
rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
assert len(rec) == expected
# using predicates
predicates: list[tuple[str, str, str|datetime]] = []
predicates: list[tuple[str, str, str | datetime]] = []
if start_date is not None:
predicates.append(("__uuid_timestamp", ">=", start_date))
if end_date is not None:
Expand Down
12 changes: 11 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import os

import psycopg2
import pytest
from dotenv import find_dotenv, load_dotenv


@pytest.fixture
@pytest.fixture(scope="module")
def service_url() -> str:
_ = load_dotenv(find_dotenv(), override=True)
return os.environ["TIMESCALE_SERVICE_URL"]


@pytest.fixture(scope="module", autouse=True)
def create_temp_schema(service_url: str) -> None:
conn = psycopg2.connect(service_url)
with conn.cursor() as cursor:
cursor.execute("CREATE SCHEMA IF NOT EXISTS temp;")
conn.commit()
conn.close()
6 changes: 3 additions & 3 deletions tests/sync_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)


@pytest.mark.parametrize("schema", ["tschema", None])
@pytest.mark.parametrize("schema", ["temp", None])
def test_sync_client(service_url: str, schema: str) -> None:
vec = Sync(service_url, "data_table", 2, schema_name=schema)
vec.create_tables()
Expand Down Expand Up @@ -234,7 +234,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
assert len(rec) == expected

# using filters
filter: dict[str, str|datetime] = {}
filter: dict[str, str | datetime] = {}
if start_date is not None:
filter["__start_date"] = start_date
if end_date is not None:
Expand All @@ -250,7 +250,7 @@ def search_date(start_date: datetime | str | None, end_date: datetime | str | No
rec = vec.search([1.0, 2.0], limit=4, filter=filter)
assert len(rec) == expected
# using predicates
predicates: list[tuple[str, str, str|datetime]] = []
predicates: list[tuple[str, str, str | datetime]] = []
if start_date is not None:
predicates.append(("__uuid_timestamp", ">=", start_date))
if end_date is not None:
Expand Down
25 changes: 16 additions & 9 deletions timescale_vector/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,12 @@ async def connect(self) -> PoolAcquireContext:
self.max_db_connections = await self._default_max_db_connections()

async def init(conn: Connection) -> None:
await register_vector(conn)
schema = await self._detect_vector_schema(conn)
if schema is None:
raise ValueError("pg_vector extension not found")
await register_vector(conn, schema=schema)
# decode to a dict, but accept a string as input in upsert
await conn.set_type_codec(
"jsonb",
encoder=str,
decoder=json.loads,
schema="pg_catalog"
)
await conn.set_type_codec("jsonb", encoder=str, decoder=json.loads, schema="pg_catalog")

self.pool = await create_pool(
dsn=self.service_url,
Expand Down Expand Up @@ -127,13 +125,22 @@ async def table_is_empty(self) -> bool:
rec = await pool.fetchrow(query)
return rec is None


def munge_record(self, records: list[tuple[Any, ...]]) -> list[tuple[uuid.UUID, str, str, list[float]]]:
metadata_is_dict = isinstance(records[0][1], dict)
if metadata_is_dict:
return list(map(lambda item: Async._convert_record_meta_to_json(item), records))
return records

async def _detect_vector_schema(self, conn: Connection) -> str | None:
query = """
select n.nspname
from pg_extension x
inner join pg_namespace n on (x.extnamespace = n.oid)
where x.extname = 'vector';
"""

return await conn.fetchval(query)

@staticmethod
def _convert_record_meta_to_json(item: tuple[Any, ...]) -> tuple[uuid.UUID, str, str, list[float]]:
if not isinstance(item[1], dict):
Expand Down Expand Up @@ -301,4 +308,4 @@ async def search(
return await pool.fetch(query, *params)
else:
async with await self.connect() as pool:
return await pool.fetch(query, *params)
return await pool.fetch(query, *params)
20 changes: 15 additions & 5 deletions timescale_vector/client/predicates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import json
from datetime import datetime
from typing import Any, Literal, Union
from typing import Any, Literal, Union, get_args, get_origin


def get_runtime_types(typ) -> tuple[type, ...]: # type: ignore
"""Convert a type with generic parameters to runtime types.
Necessary because Generic types cant be passed to isinstance in python 3.10"""
return tuple(get_origin(t) or t for t in get_args(typ)) # type: ignore


class Predicates:
Expand Down Expand Up @@ -51,7 +57,9 @@ def __init__(
raise ValueError(f"invalid operator: {operator}")
self.operator: str = operator
if isinstance(clauses[0], str):
if len(clauses) != 3 or not (isinstance(clauses[1], str) and isinstance(clauses[2], self.PredicateValue)):
if len(clauses) != 3 or not (
isinstance(clauses[1], str) and isinstance(clauses[2], get_runtime_types(self.PredicateValue))
):
raise ValueError(f"Invalid clause format: {clauses}")
self.clauses = [clauses]
else:
Expand All @@ -77,11 +85,13 @@ def add_clause(
or (field, value).
"""
if isinstance(clause[0], str):
if len(clause) != 3 or not (isinstance(clause[1], str) and isinstance(clause[2], self.PredicateValue)):
if len(clause) != 3 or not (
isinstance(clause[1], str) and isinstance(clause[2], get_runtime_types(self.PredicateValue))
):
raise ValueError(f"Invalid clause format: {clause}")
self.clauses.append(clause) # type: ignore
self.clauses.append(clause) # type: ignore
else:
self.clauses.extend(list(clause)) # type: ignore
self.clauses.extend(list(clause)) # type: ignore

def __and__(self, other: "Predicates") -> "Predicates":
new_predicates = Predicates(self, other, operator="AND")
Expand Down
15 changes: 8 additions & 7 deletions timescale_vector/typings/asyncpg/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Any, Protocol, TypeVar, Sequence
from . import pool, connection
from collections.abc import Sequence
from typing import Any, Protocol, TypeVar

from . import connection, pool

# Core types
T = TypeVar('T')
T = TypeVar("T")

class Record(Protocol):
def __getitem__(self, key: int | str) -> Any: ...
Expand Down Expand Up @@ -30,9 +32,8 @@ async def connect(
user: str | None = None,
password: str | None = None,
database: str | None = None,
timeout: int = 60
timeout: int = 60,
) -> Connection: ...

async def create_pool(
dsn: str | None = None,
*,
Expand All @@ -42,5 +43,5 @@ async def create_pool(
max_inactive_connection_lifetime: float = 300.0,
setup: Any | None = None,
init: Any | None = None,
**connect_kwargs: Any
) -> Pool: ...
**connect_kwargs: Any,
) -> Pool: ...
44 changes: 6 additions & 38 deletions timescale_vector/typings/asyncpg/connection.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -6,55 +6,23 @@ from . import Record
class Connection:
# Transaction management
async def execute(self, query: str, *args: Any, timeout: float | None = None) -> str: ...

async def executemany(
self,
command: str,
args: Sequence[Sequence[Any]],
*,
timeout: float | None = None
self, command: str, args: Sequence[Sequence[Any]], *, timeout: float | None = None
) -> str: ...

async def fetch(
self,
query: str,
*args: Any,
timeout: float | None = None
) -> list[Record]: ...

async def fetchval(
self,
query: str,
*args: Any,
column: int = 0,
timeout: float | None = None
) -> Any: ...

async def fetchrow(
self,
query: str,
*args: Any,
timeout: float | None = None
) -> Record | None: ...

async def fetch(self, query: str, *args: Any, timeout: float | None = None) -> list[Record]: ...
async def fetchval(self, query: str, *args: Any, column: int = 0, timeout: float | None = None) -> Any: ...
async def fetchrow(self, query: str, *args: Any, timeout: float | None = None) -> Record | None: ...
async def set_type_codec(
self,
typename: str,
*,
schema: str = "public",
encoder: Any,
decoder: Any,
format: str = "text"
self, typename: str, *, schema: str = "public", encoder: Any, decoder: Any, format: str = "text"
) -> None: ...

# Transaction context
def transaction(self, *, isolation: str = "read_committed") -> Transaction: ...

async def close(self, *, timeout: float | None = None) -> None: ...

class Transaction:
async def __aenter__(self) -> Transaction: ...
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
async def start(self) -> None: ...
async def commit(self) -> None: ...
async def rollback(self) -> None: ...
async def rollback(self) -> None: ...
7 changes: 4 additions & 3 deletions timescale_vector/typings/asyncpg/pool.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, AsyncContextManager
from contextlib import AbstractAsyncContextManager
from typing import Any

from . import connection

Expand All @@ -13,6 +14,6 @@ class Pool:
async def __aenter__(self) -> Pool: ...
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...

class PoolAcquireContext(AsyncContextManager['connection.Connection']):
class PoolAcquireContext(AbstractAsyncContextManager["connection.Connection"]):
async def __aenter__(self) -> connection.Connection: ...
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...
12 changes: 6 additions & 6 deletions timescale_vector/typings/langchain/docstore/document.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, TypeVar, Optional
from typing import Any, TypeVar

from typing_extensions import TypedDict

class Metadata(TypedDict, total=False):
Expand All @@ -8,21 +9,20 @@ class Metadata(TypedDict, total=False):
category: str
published_time: str

T = TypeVar('T')
T = TypeVar("T")

class Document:
"""Documents are the basic unit of text in LangChain."""

page_content: str
metadata: dict[str, Any]

def __init__(
self,
page_content: str,
metadata: Optional[dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> None: ...

@property
def lc_kwargs(self) -> dict[str, Any]: ...

@classmethod
def is_lc_serializable(cls) -> bool: ...
def is_lc_serializable(cls) -> bool: ...
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,28 @@ from langchain.schema.embeddings import Embeddings

class TimescaleVector:
def __init__(
self,
collection_name: str,
service_url: str,
embedding: Embeddings,
time_partition_interval: timedelta | None = None,
self,
collection_name: str,
service_url: str,
embedding: Embeddings,
time_partition_interval: timedelta | None = None,
) -> None: ...

def add_texts(
self,
texts: Sequence[str],
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
**kwargs: Any,
self,
texts: Sequence[str],
metadatas: list[dict[str, Any]] | None = None,
ids: list[str] | None = None,
**kwargs: Any,
) -> list[str]: ...

def delete_by_metadata(
self,
metadata_filter: dict[str, Any] | list[dict[str, Any]],
self,
metadata_filter: dict[str, Any] | list[dict[str, Any]],
) -> None: ...

def similarity_search_with_score(
self,
query: str,
k: int = 4,
filter: dict[str, Any] | list[dict[str, Any]] | None = None,
predicates: Any | None = None,
**kwargs: Any,
) -> list[tuple[Document, float]]: ...
self,
query: str,
k: int = 4,
filter: dict[str, Any] | list[dict[str, Any]] | None = None,
predicates: Any | None = None,
**kwargs: Any,
) -> list[tuple[Document, float]]: ...
2 changes: 1 addition & 1 deletion timescale_vector/typings/pgvector.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from typing import Any

def register_vector(conn_or_curs: Any) -> None: ...
def register_vector(conn_or_curs: Any) -> None: ...
5 changes: 3 additions & 2 deletions timescale_vector/typings/psycopg2/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, TypeVar

from psycopg2.extensions import connection

T = TypeVar('T')
T = TypeVar("T")

def connect(dsn: str = "", **kwargs: Any) -> connection: ...
def connect(dsn: str = "", **kwargs: Any) -> connection: ...
5 changes: 2 additions & 3 deletions timescale_vector/typings/psycopg2/extensions.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@ class cursor(Protocol):
def executemany(self, query: str, vars_list: list[Any]) -> Any: ...
def fetchone(self) -> tuple[Any, ...] | None: ...
def fetchall(self) -> list[tuple[Any, ...]]: ...
def __enter__(self) -> 'cursor': ...
def __enter__(self) -> cursor: ...
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...

class connection(Protocol):
def cursor(self, cursor_factory: Any | None = None) -> cursor: ...
def commit(self) -> None: ...
def close(self) -> None: ...
def __enter__(self) -> 'connection': ...
def __enter__(self) -> connection: ...
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...

def register_uuid(oids: Any | None = None, conn_or_curs: Any | None = None) -> None: ...

Loading

0 comments on commit 9151c56

Please sign in to comment.