Skip to content

Commit

Permalink
Make metadata field optional
Browse files Browse the repository at this point in the history
  • Loading branch information
Askir committed Oct 28, 2024
1 parent 849a88f commit 112f28a
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 15 deletions.
19 changes: 17 additions & 2 deletions tests/async_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,13 @@
@pytest.mark.parametrize("schema", ["temp", None])
async def test_vector(service_url: str, schema: str) -> None:
vec = Async(
service_url, "data_table", 2, schema_name=schema, embedding_table_name="data_table", id_column_name="id"
service_url,
"data_table",
2,
schema_name=schema,
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
await vec.drop_table()
await vec.create_tables()
Expand Down Expand Up @@ -258,7 +264,15 @@ async def test_vector(service_url: str, schema: str) -> None:
await vec.drop_table()
await vec.close()

vec = Async(service_url, "data_table", 2, id_type="TEXT", embedding_table_name="data_table", id_column_name="id")
vec = Async(
service_url,
"data_table",
2,
id_type="TEXT",
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
await vec.create_tables()
empty = await vec.table_is_empty()
assert empty
Expand All @@ -278,6 +292,7 @@ async def test_vector(service_url: str, schema: str) -> None:
time_partition_interval=timedelta(seconds=60),
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
await vec.create_tables()
empty = await vec.table_is_empty()
Expand Down
56 changes: 55 additions & 1 deletion tests/compatability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def test_metadata_filtered_search(quickstart: None, service_url: str): # noqa:

@pytest.fixture(scope="function")
def sync_client(service_url: str) -> client.Sync:
    """Return a Sync client bound to the quickstart embeddings view.

    ``metadata_column_name`` is passed explicitly because the client now
    treats the metadata column as optional; without it, metadata would be
    neither selected nor filterable in searches.
    """
    # NOTE(review): removed an unreachable duplicate return statement
    # (stale pre-change line left behind by a bad merge/diff artifact).
    return client.Sync(service_url, "blog_contents_embeddings", 768, metadata_column_name="metadata")


def test_basic_similarity_search(sync_client: client.Sync, quickstart: None): # noqa: ARG001
Expand Down Expand Up @@ -244,3 +244,57 @@ def test_index_operations(sync_client: client.Sync, quickstart: None): # noqa:
assert len(results_with_params) == 3

sync_client.drop_embedding_index()


def test_semantic_search_without_metadata(service_url: str, quickstart: None):  # noqa: ARG001
    """Search must still work when the embeddings view has no metadata column.

    Temporarily replaces the quickstart view with a copy that omits
    ``s.metadata``, runs a search with a client configured without
    ``metadata_column_name``, and restores the original view afterwards.

    The restore runs in a ``finally`` block so that a failing assertion
    cannot leave the shared quickstart view in the stripped-down state and
    break unrelated tests that run later in the session.
    """
    conn = psycopg2.connect(service_url)
    conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT)

    try:
        # Replace the view with one that has no metadata column.
        with conn.cursor() as cursor:
            cursor.execute("DROP VIEW IF EXISTS public.blog_contents_embeddings;")
            cursor.execute("""
            CREATE VIEW public.blog_contents_embeddings AS
            SELECT
                t.embedding_uuid,
                t.chunk_seq,
                t.chunk,
                t.embedding,
                t.id,
                s.title,
                s.authors,
                s.contents
            FROM (public.blog_contents_embeddings_store t
                LEFT JOIN public.blog s ON ((t.id = s.id)));
            """)

        # No metadata_column_name: the client must not reference metadata
        # in its generated queries.
        sync_client = client.Sync(service_url, "blog_contents_embeddings", 768)
        results = sync_client.search(embeddings["artificial intelligence"], limit=3)

        assert len(results) == 3
        assert all(isinstance(r["embedding_uuid"], uuid.UUID) for r in results)
        assert all(isinstance(r["chunk"], str) for r in results)
        assert all(isinstance(r["embedding"], numpy.ndarray) for r in results)
        assert all(isinstance(r["distance"], float) for r in results)

        # Metadata must be absent (or empty) in every returned row.
        assert all("metadata" not in r or not r["metadata"] for r in results)
    finally:
        # Restore the original view even if an assertion above failed.
        with conn.cursor() as cursor:
            cursor.execute("DROP VIEW IF EXISTS public.blog_contents_embeddings;")
            cursor.execute("""
            CREATE VIEW public.blog_contents_embeddings AS
            SELECT
                t.embedding_uuid,
                t.chunk_seq,
                t.chunk,
                t.embedding,
                t.id,
                s.title,
                s.authors,
                s.contents,
                s.metadata
            FROM (public.blog_contents_embeddings_store t
                LEFT JOIN public.blog s ON ((t.id = s.id)));
            """)

        conn.close()
12 changes: 11 additions & 1 deletion tests/sync_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@

@pytest.mark.parametrize("schema", ["temp", None])
def test_sync_client(service_url: str, schema: str) -> None:
vec = Sync(service_url, "data_table", 2, schema_name=schema, embedding_table_name="data_table", id_column_name="id")
vec = Sync(
service_url,
"data_table",
2,
schema_name=schema,
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
vec.drop_table()
vec.create_tables()
empty = vec.table_is_empty()
Expand Down Expand Up @@ -179,6 +187,7 @@ def test_sync_client(service_url: str, schema: str) -> None:
schema_name=schema,
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
vec.create_tables()
assert vec.table_is_empty()
Expand All @@ -197,6 +206,7 @@ def test_sync_client(service_url: str, schema: str) -> None:
schema_name=schema,
embedding_table_name="data_table",
id_column_name="id",
metadata_column_name="metadata",
)
vec.create_tables()
assert vec.table_is_empty()
Expand Down
2 changes: 2 additions & 0 deletions timescale_vector/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
schema_name: str | None = None,
embedding_table_name: str | None = None,
id_column_name: str = "embedding_uuid",
metadata_column_name: str | None = None,
) -> None:
"""
Initializes a async client for storing vector data.
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(
schema_name,
embedding_table_name,
id_column_name,
metadata_column_name,
)
self.service_url: str = service_url
self.pool: Pool | None = None
Expand Down
35 changes: 24 additions & 11 deletions timescale_vector/client/query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
schema_name: str | None,
embedding_table_name: str | None = None,
id_column_name: str = "embedding_uuid",
metadata_column_name: str | None = None, # Added this parameter
) -> None:
"""
Initializes a base Vector object to generate queries for vector clients.
Expand All @@ -46,10 +47,13 @@ def __init__(
Whether to infer start and end times from the special __start_date and __end_date filters.
schema_name
The schema name for the table (optional, uses the database's default schema if not specified).
metadata_column_name
The name of the metadata column (optional, if None metadata will not be queried).
"""
self.view_name: str = table_name
self.embedding_table_name = embedding_table_name or table_name + "_store"
self.id_column_name = id_column_name
self.metadata_column_name = metadata_column_name
self.schema_name: str | None = schema_name
self.num_dimensions: int = num_dimensions
if distance_type == "cosine" or distance_type == "<=>":
Expand Down Expand Up @@ -118,10 +122,13 @@ def get_upsert_query(self) -> str:
)
return (
f"INSERT INTO {self._quoted_table_name(self.embedding_table_name)} "
f"({self._quote_ident(self.id_column_name)}, metadata, chunk, embedding) "
f"({self._quote_ident(self.id_column_name)},{self.metadata_or_empty()} chunk, embedding) "
f"VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING"
)

def metadata_or_empty(self):
    """Return the quoted metadata column name followed by a trailing comma.

    Yields ``""`` when no metadata column is configured, so callers can
    splice the result directly into a column list either way.
    """
    if self.metadata_column_name is None:
        return ""
    return f"{self._quote_ident(self.metadata_column_name)},"

def get_approx_count_query(self) -> str:
"""
Generate a query to find the approximate count of records in the table.
Expand Down Expand Up @@ -187,14 +194,14 @@ def get_create_query(self) -> str:
CREATE TABLE IF NOT EXISTS {self._quoted_table_name(self.embedding_table_name)} (
{self._quote_ident(self.id_column_name)} {self.id_type} PRIMARY KEY,
metadata JSONB,
{self.metadata_column_name} JSONB,
chunk TEXT,
embedding VECTOR({self.num_dimensions})
);
CREATE INDEX IF NOT EXISTS {self._quote_ident(self.view_name + "_meta_idx")}
ON {self._quoted_table_name(self.embedding_table_name)}
USING GIN(metadata jsonb_path_ops);
USING GIN({self.metadata_column_name} jsonb_path_ops);
{hypertable_sql}
"""
Expand All @@ -215,8 +222,10 @@ def delete_all_query(self) -> str:
return f"TRUNCATE {self._quoted_table_name(self.embedding_table_name)};"

def delete_by_ids_query(self, ids: list[uuid.UUID] | list[str]) -> tuple[str, list[Any]]:
query = (f"DELETE FROM {self._quoted_table_name(self.embedding_table_name)} "
f"WHERE {self._quote_ident(self.id_column_name)} = ANY($1::{self.id_type}[]);")
query = (
f"DELETE FROM {self._quoted_table_name(self.embedding_table_name)} "
f"WHERE {self._quote_ident(self.id_column_name)} = ANY($1::{self.id_type}[]);"
)
return (query, [ids])

def delete_by_metadata_query(self, filter_conditions: Filter) -> tuple[str, list[Any]]:
Expand Down Expand Up @@ -292,14 +301,14 @@ def _where_clause_for_filter(
return "TRUE", params

if isinstance(filter, dict):
where = f"metadata @> ${len(params)+1}"
where = f"{self.metadata_column_name} @> ${len(params)+1}"
json_object = json.dumps(filter)
params = params + [json_object]
elif isinstance(filter, list):
any_params: list[str] = []
for _idx, filter_dict in enumerate(filter, start=len(params) + 1):
any_params.append(json.dumps(filter_dict))
where = f"metadata @> ANY(${len(params) + 1}::jsonb[])"
where = f"{self.metadata_column_name} @> ANY(${len(params) + 1}::jsonb[])"
params = params + [any_params]
else:
raise ValueError(f"Unknown filter type: {type(filter)}")
Expand All @@ -322,7 +331,7 @@ def search_query(
"""
params: list[Any] = []
if query_embedding is not None:
distance = f"embedding {self.distance_type} ${len(params)+1}"
distance = f"embedding {self.distance_type} ${len(params) + 1}"
params = params + [query_embedding]
order_by_clause = f"ORDER BY {distance} ASC"
else:
Expand All @@ -346,11 +355,11 @@ def search_query(
del filter["__end_date"]

where_clauses: list[str] = []
if filter is not None:
if filter is not None and self.metadata_column_name is not None:
(where_filter, params) = self._where_clause_for_filter(params, filter)
where_clauses.append(where_filter)

if predicates is not None:
if predicates is not None and self.metadata_column_name is not None:
(where_predicates, params) = predicates.build_query(params)
where_clauses.append(where_predicates)

Expand All @@ -362,7 +371,11 @@ def search_query(

query = f"""
SELECT
{self._quote_ident(self.id_column_name)}, metadata, chunk, embedding, {distance} as distance
{self._quote_ident(self.id_column_name)},
{self.metadata_or_empty()}
chunk,
embedding,
{distance} as distance
FROM
{self._quoted_table_name(self.view_name)}
WHERE
Expand Down
2 changes: 2 additions & 0 deletions timescale_vector/client/sync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(
schema_name: str | None = None,
embedding_table_name: str | None = None,
id_column_name: str = "embedding_uuid",
metadata_column_name: str | None = None,
) -> None:
"""
Initializes a sync client for storing vector data.
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
schema_name,
embedding_table_name,
id_column_name,
metadata_column_name,
)
self.service_url: str = service_url
self.pool: SimpleConnectionPool | None = None
Expand Down

0 comments on commit 112f28a

Please sign in to comment.