Fix typing, formatting and tests (#31)
* chore: fix ruff and run it in CI

* chore: Claude's attempt at typing

* fix: use mypy instead of pyright and fix all errors

* Fix types with stubs

* Migrate workflow to pyright

* Fix tests after type changes

* Rename action to mypy

* Add testing with vcr.py and a compose file

* Fix tiktoken request
Askir authored Oct 28, 2024
1 parent 108539c commit 7b177d0
Showing 34 changed files with 204,817 additions and 1,668 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/pyright.yaml
```yaml
name: Type Checking

on:
  pull_request:
    branches: [ main ]

jobs:
  pyright:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install uv
        run: pip install uv
      - name: Create venv
        run: uv venv
      - name: Install dependencies
        run: |
          uv sync
      - name: Run Pyright
        run: uv run pyright
```
26 changes: 26 additions & 0 deletions .github/workflows/ruff.yaml
```yaml
name: Ruff Linting and Formatting

on:
  pull_request:
    branches: [ main ]

jobs:
  ruff:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.10'
      - name: Install uv
        run: pip install uv
      - name: Create venv
        run: uv venv
      - name: Install dependencies
        run: |
          uv sync
      - name: Run Ruff linter
        run: uv run ruff check .
      - name: Run Ruff formatter
        run: uv run ruff format . --check
```
32 changes: 27 additions & 5 deletions .github/workflows/test.yaml
```diff
@@ -1,8 +1,30 @@
-name: CI
-on: [workflow_dispatch, pull_request, push]
+name: Tests
+
+on:
+  pull_request:
+    branches: [ main ]
 
 jobs:
-  test:
-    if: false
+  pytest:
     runs-on: ubuntu-latest
-    steps: [uses: fastai/workflows/nbdev-ci@master]
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+      - name: Install uv
+        run: pip install uv
+      - name: Create venv
+        run: uv venv
+      - name: Install dependencies
+        run: |
+          uv sync
+      - name: Start docker-compose
+        run: docker compose up -d
+      - name: Run Test
+        run: uv run pytest
+      - name: Logs
+        run: docker compose logs
+      - name: Stop docker-compose
+        run: docker compose down
```
12 changes: 12 additions & 0 deletions docker-compose.yaml
```yaml
services:
  db:
    image: timescale/timescaledb-ha:pg16
    ports:
      - "5432:5432"
    environment:
      - POSTGRES_PASSWORD=postgres
      - POSTGRES_USER=postgres
      - POSTGRES_DB=postgres
      - TIMESCALEDB_TELEMETRY=off
    volumes:
      - ./data:/var/lib/postgresql/data
```
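This compose service maps straight onto the URL hard-coded in tests/conftest.py below. As a quick sanity check that the stack is reachable — a minimal sketch, assuming `docker compose up -d` has been run and psycopg2 is installed (this snippet is illustrative, not part of the repo):

```python
# Minimal connectivity check for the compose database.
import psycopg2

conn = psycopg2.connect("postgres://postgres:postgres@localhost:5432/postgres")
with conn.cursor() as cursor:
    # Any trivial query proves the container is accepting connections.
    cursor.execute("SELECT version();")
    print(cursor.fetchone())
conn.close()
```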
62 changes: 21 additions & 41 deletions pyproject.toml
```diff
@@ -23,18 +23,6 @@ dependencies = [
     "numpy>=1,<2",
 ]
 
-[project.optional-dependencies]
-dev = [
-    "ruff>=0.6.9",
-    "pyright>=1.1.384",
-    "pytest>=8.3.3",
-    "langchain>=0.3.3",
-    "langchain-openai>=0.2.2",
-    "langchain-community>=0.3.2",
-    "pandas>=2.2.3",
-    "pytest-asyncio>=0.24.0",
-]
-
 [project.urls]
 repository = "https://github.com/timescale/python-vector"
 documentation = "https://timescale.github.io/python-vector"
@@ -51,36 +39,15 @@ addopts = [
     "--import-mode=importlib",
 ]
 
+[tool.mypy]
+strict = true
+ignore_missing_imports = true
+namespace_packages = true
+
 [tool.pyright]
 typeCheckingMode = "strict"
 reportImplicitOverride = true
-exclude = [
-    "**/.bzr",
-    "**/.direnv",
-    "**/.eggs",
-    "**/.git",
-    "**/.git-rewrite",
-    "**/.hg",
-    "**/.ipynb_checkpoints",
-    "**/.mypy_cache",
-    "**/.nox",
-    "**/.pants.d",
-    "**/.pyenv",
-    "**/.pytest_cache",
-    "**/.pytype",
-    "**/.ruff_cache",
-    "**/.svn",
-    "**/.tox",
-    "**/.venv",
-    "**/.vscode",
-    "**/__pypackages__",
-    "**/_build",
-    "**/buck-out",
-    "**/dist",
-    "**/node_modules",
-    "**/site-packages",
-    "**/venv",
-]
+stubPath = "timescale_vector/typings"
 
 [tool.ruff]
 line-length = 120
@@ -137,4 +104,17 @@ select = [
     "W291",
     "PIE",
     "Q"
-]
+]
+
+[tool.uv]
+dev-dependencies = [
+    "ruff>=0.6.9",
+    "pytest>=8.3.3",
+    "langchain>=0.3.3",
+    "langchain-openai>=0.2.2",
+    "langchain-community>=0.3.2",
+    "pandas>=2.2.3",
+    "pytest-asyncio>=0.24.0",
+    "pyright>=1.1.386",
+    "vcrpy>=6.0.2",
+]
```
8 changes: 4 additions & 4 deletions tests/async_client_test.py
```diff
@@ -17,7 +17,7 @@
 
 
 @pytest.mark.asyncio
-@pytest.mark.parametrize("schema", ["tschema", None])
+@pytest.mark.parametrize("schema", ["temp", None])
 async def test_vector(service_url: str, schema: str) -> None:
     vec = Async(service_url, "data_table", 2, schema_name=schema)
     await vec.drop_table()
@@ -306,7 +306,7 @@ async def test_vector(service_url: str, schema: str) -> None:
     assert not await vec.table_is_empty()
 
     # check all the possible ways to specify a date range
-    async def search_date(start_date, end_date, expected):
+    async def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None:
         # using uuid_time_filter
         rec = await vec.search(
             [1.0, 2.0],
@@ -322,7 +322,7 @@ async def search_date(start_date, end_date, expected):
         assert len(rec) == expected
 
         # using filters
-        filter = {}
+        filter: dict[str, str | datetime] = {}
         if start_date is not None:
             filter["__start_date"] = start_date
         if end_date is not None:
@@ -338,7 +338,7 @@ async def search_date(start_date, end_date, expected):
         rec = await vec.search([1.0, 2.0], limit=4, filter=filter)
         assert len(rec) == expected
         # using predicates
-        predicates = []
+        predicates: list[tuple[str, str, str | datetime]] = []
         if start_date is not None:
             predicates.append(("__uuid_timestamp", ">=", start_date))
         if end_date is not None:
```
27 changes: 23 additions & 4 deletions tests/conftest.py
```diff
@@ -1,10 +1,29 @@
 import os
 
+import psycopg2
 import pytest
-from dotenv import find_dotenv, load_dotenv
+
+# from dotenv import find_dotenv, load_dotenv
 
 
-@pytest.fixture
-def service_url() -> str:
-    _ = load_dotenv(find_dotenv(), override=True)
+@pytest.fixture(scope="module")
+def setup_env_variables() -> None:
+    os.environ.clear()
+    os.environ["TIMESCALE_SERVICE_URL"] = "postgres://postgres:postgres@localhost:5432/postgres"
+    os.environ["OPENAI_API_KEY"] = "fake key"
+
+
+@pytest.fixture(scope="module")
+def service_url(setup_env_variables: None) -> str:  # noqa: ARG001
+    # _ = load_dotenv(find_dotenv(), override=True)
     return os.environ["TIMESCALE_SERVICE_URL"]
+
+
+@pytest.fixture(scope="module", autouse=True)
+def setup_db(service_url: str) -> None:
+    conn = psycopg2.connect(service_url)
+    with conn.cursor() as cursor:
+        cursor.execute("CREATE EXTENSION IF NOT EXISTS ai CASCADE;")
+        cursor.execute("CREATE SCHEMA IF NOT EXISTS temp;")
+    conn.commit()
+    conn.close()
```
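With this setup, a test only has to declare `service_url` as a parameter: requesting it pulls in `setup_env_variables`, and the autouse `setup_db` fixture has already created the `ai` extension and the `temp` schema. A minimal sketch of a consuming test (hypothetical test name, mirroring the pattern the real test files use):

```python
# Hypothetical test showing how the conftest.py fixtures compose.
import psycopg2


def test_database_is_ready(service_url: str) -> None:
    # setup_env_variables and setup_db have already run by this point.
    conn = psycopg2.connect(service_url)
    with conn.cursor() as cursor:
        cursor.execute("SELECT 1;")
        assert cursor.fetchone() == (1,)
    conn.close()
```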
13 changes: 7 additions & 6 deletions tests/pg_vectorizer_test.py
```diff
@@ -1,22 +1,23 @@
 from datetime import timedelta
+from typing import Any
 
 import psycopg2
-import pytest
 from langchain.docstore.document import Document
 from langchain.text_splitter import CharacterTextSplitter
 from langchain_community.vectorstores.timescalevector import TimescaleVector
 from langchain_openai import OpenAIEmbeddings
 
+from tests.utils import http_recorder
 from timescale_vector import client
 from timescale_vector.pgvectorizer import Vectorize
 
 
-def get_document(blog):
+def get_document(blog: dict[str, Any]) -> list[Document]:
     text_splitter = CharacterTextSplitter(
         chunk_size=1000,
         chunk_overlap=200,
     )
-    docs = []
+    docs: list[Document] = []
     for chunk in text_splitter.split_text(blog["contents"]):
         content = f"Author {blog['author']}, title: {blog['title']}, contents:{chunk}"
         metadata = {
@@ -30,7 +31,7 @@ def get_document(blog):
     return docs
 
 
-@pytest.mark.skip(reason="requires OpenAI API key")
+@http_recorder.use_cassette("pg_vectorizer.yaml")
 def test_pg_vectorizer(service_url: str) -> None:
     with psycopg2.connect(service_url) as conn, conn.cursor() as cursor:
         for item in ["blog", "blog_embedding_work_queue", "blog_embedding"]:
@@ -56,7 +57,7 @@ def test_pg_vectorizer(service_url: str) -> None:
             VALUES ('first', 'mat', 'first_post', 'personal', '2021-01-01');
         """)
 
-    def embed_and_write(blog_instances, vectorizer):
+    def embed_and_write(blog_instances: list[Any], vectorizer: Vectorize) -> None:
         TABLE_NAME = vectorizer.table_name_unquoted + "_embedding"
         embedding = OpenAIEmbeddings()
         vector_store = TimescaleVector(
@@ -70,7 +71,7 @@ def embed_and_write(blog_instances, vectorizer):
         metadata_for_delete = [{"blog_id": blog["locked_id"]} for blog in blog_instances]
         vector_store.delete_by_metadata(metadata_for_delete)
 
-        documents = []
+        documents: list[Document] = []
         for blog in blog_instances:
             # skip blogs that are not published yet, or are deleted (will be None because of left join)
             if blog["published_time"] is not None:
```
14 changes: 7 additions & 7 deletions tests/sync_client_test.py
```diff
@@ -20,7 +20,7 @@
 )
 
 
-@pytest.mark.parametrize("schema", ["tschema", None])
+@pytest.mark.parametrize("schema", ["temp", None])
 def test_sync_client(service_url: str, schema: str) -> None:
     vec = Sync(service_url, "data_table", 2, schema_name=schema)
     vec.create_tables()
@@ -136,15 +136,15 @@ def test_sync_client(service_url: str, schema: str) -> None:
 
     rec = vec.search([1.0, 2.0], filter={"key_1": "val_1", "key_2": "val_2"})
     assert rec[0][SEARCH_RESULT_CONTENTS_IDX] == "the brown fox"
-    assert rec[0]["contents"] == "the brown fox"
+    assert rec[0]["contents"] == "the brown fox"  # type: ignore
     assert rec[0][SEARCH_RESULT_METADATA_IDX] == {
         "key_1": "val_1",
         "key_2": "val_2",
     }
-    assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"}
+    assert rec[0]["metadata"] == {"key_1": "val_1", "key_2": "val_2"}  # type: ignore
     assert isinstance(rec[0][SEARCH_RESULT_METADATA_IDX], dict)
     assert rec[0][SEARCH_RESULT_DISTANCE_IDX] == 0.0009438353921149556
-    assert rec[0]["distance"] == 0.0009438353921149556
+    assert rec[0]["distance"] == 0.0009438353921149556  # type: ignore
 
     rec = vec.search([1.0, 2.0], limit=4, predicates=Predicates("key", "==", "val2"))
     assert len(rec) == 1
@@ -218,7 +218,7 @@ def test_sync_client(service_url: str, schema: str) -> None:
         ]
     )
 
-    def search_date(start_date, end_date, expected):
+    def search_date(start_date: datetime | str | None, end_date: datetime | str | None, expected: int) -> None:
         # using uuid_time_filter
         rec = vec.search(
             [1.0, 2.0],
@@ -234,7 +234,7 @@ def search_date(start_date, end_date, expected):
         assert len(rec) == expected
 
         # using filters
-        filter = {}
+        filter: dict[str, str | datetime] = {}
         if start_date is not None:
             filter["__start_date"] = start_date
         if end_date is not None:
@@ -250,7 +250,7 @@ def search_date(start_date, end_date, expected):
         rec = vec.search([1.0, 2.0], limit=4, filter=filter)
         assert len(rec) == expected
         # using predicates
-        predicates = []
+        predicates: list[tuple[str, str, str | datetime]] = []
         if start_date is not None:
             predicates.append(("__uuid_timestamp", ">=", start_date))
         if end_date is not None:
```
25 changes: 25 additions & 0 deletions tests/utils.py
```python
import os
from typing import Any

import vcr

vcr_cassette_path = os.path.join(os.path.dirname(__file__), "vcr_cassettes")


def remove_set_cookie_header(response: dict[str, Any]):
    headers = response["headers"]
    headers_to_remove = ["set-cookie", "Set-Cookie"]

    for header in headers_to_remove:
        if header in headers:
            del headers[header]

    return response


http_recorder = vcr.VCR(
    cassette_library_dir=vcr_cassette_path,
    record_mode="once",
    filter_headers=["authorization", "cookie"],
    before_record_response=remove_set_cookie_header,
)
```
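Tests opt in per cassette, as pg_vectorizer_test.py does above with `@http_recorder.use_cassette("pg_vectorizer.yaml")`. With `record_mode="once"`, vcr.py records real HTTP traffic into `tests/vcr_cassettes/` on the first run and replays it afterwards, which is why the fake `OPENAI_API_KEY` set in conftest.py is sufficient in CI. A minimal sketch with a hypothetical test and cassette name:

```python
# Hypothetical usage example; only tests/utils.py's http_recorder is real.
from tests.utils import http_recorder


@http_recorder.use_cassette("my_feature.yaml")
def test_my_feature() -> None:
    # First run: real HTTP calls are recorded to
    # tests/vcr_cassettes/my_feature.yaml (auth and cookie headers filtered).
    # Subsequent runs: responses are replayed from the cassette.
    ...
```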