From 9a9f6fe8e9fddd43937aba20e7063e440b52cf44 Mon Sep 17 00:00:00 2001
From: Joe Reuter
Date: Tue, 22 Aug 2023 15:53:42 +0200
Subject: [PATCH] wip

---
 .../destinations/vector_db_based/batcher.py   |  26 +++
 .../destinations/vector_db_based/config.py    | 144 ++++++++++++
 .../vector_db_based/document_processor.py     | 120 ++++++++++
 .../destinations/vector_db_based/embedder.py  |  82 +++++++
 .../destinations/vector_db_based/indexer.py   | 205 ++++++++++++++++++
 .../destinations/vector_db_based/utils.py     |   9 +
 6 files changed, 586 insertions(+)
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/batcher.py
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py
 create mode 100644 airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/utils.py

diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/batcher.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/batcher.py
new file mode 100644
index 000000000000..ab3ad18c83d8
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/batcher.py
@@ -0,0 +1,26 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+from typing import Any, Callable, List
+
+
+class Batcher:
+    def __init__(self, batch_size: int, flush_handler: Callable[[List[Any]], None]):
+        self.batch_size = batch_size
+        self.buffer: List[Any] = []
+        self.flush_handler = flush_handler
+
+    def add(self, item: Any) -> None:
+        self.buffer.append(item)
+        self._flush_if_necessary()
+
+    def flush(self) -> None:
+        if len(self.buffer) == 0:
+            return
+        self.flush_handler(list(self.buffer))
+        self.buffer.clear()
+
+    def _flush_if_necessary(self) -> None:
+        if len(self.buffer) >= self.batch_size:
+            self.flush()
diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py
new file mode 100644
index 000000000000..04a3e5f7994c
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py
@@ -0,0 +1,144 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import json
+import re
+from typing import List, Literal, Optional, Union
+
+from jsonschema import RefResolver
+from pydantic import BaseModel, Field
+
+
+class ProcessingConfigModel(BaseModel):
+    chunk_size: int = Field(
+        ...,
+        title="Chunk size",
+        maximum=8191,
+        description="Size of chunks in tokens to store in vector store (make sure it is not too big for the context window of your LLM)",
+    )
+    chunk_overlap: int = Field(
+        title="Chunk overlap",
+        description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context",
+        default=0,
+    )
+    text_fields: Optional[List[str]] = Field(
+        ...,
+        title="Text fields to embed",
+        description="List of fields in the record that should be used to calculate the embedding. All other fields are passed along as meta fields. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `name` fields in all entries of the `users` array.",
+        always_show=True,
+        examples=["text", "user.name", "users.*.name"],
+    )
+
+    class Config:
+        schema_extra = {"group": "processing"}
+
+
+class OpenAIEmbeddingConfigModel(BaseModel):
+    mode: Literal["openai"] = Field("openai", const=True)
+    openai_key: str = Field(..., title="OpenAI API key", airbyte_secret=True)
+
+    class Config:
+        title = "OpenAI"
+        schema_extra = {
+            "description": "Use the OpenAI API to embed text. This option uses the text-embedding-ada-002 model with 1536 embedding dimensions."
+        }
+
+
+class FakeEmbeddingConfigModel(BaseModel):
+    mode: Literal["fake"] = Field("fake", const=True)
+
+    class Config:
+        title = "Fake"
+        schema_extra = {
+            "description": "Use a fake embedding made out of random vectors with 1536 embedding dimensions. This is useful for testing the data pipeline without incurring any costs."
+        }
+
+
+class PineconeIndexingModel(BaseModel):
+    mode: Literal["pinecone"] = Field("pinecone", const=True)
+    pinecone_key: str = Field(..., title="Pinecone API key", airbyte_secret=True)
+    pinecone_environment: str = Field(..., title="Pinecone environment", description="Pinecone environment to use")
+    index: str = Field(..., title="Index", description="Pinecone index to use")
+
+    class Config:
+        title = "Pinecone"
+        schema_extra = {
+            "description": "Pinecone is a popular vector store that can be used to store and retrieve embeddings. It is a managed service and can also be queried from outside of langchain."
+        }
+
+
+class ChromaLocalIndexingModel(BaseModel):
+    mode: Literal["chroma_local"] = Field("chroma_local", const=True)
+    destination_path: str = Field(
+        ...,
+        title="Destination Path",
+        description="Path to the directory where chroma files will be written. The files will be placed inside that local mount.",
+        examples=["/local/my_chroma_db"],
+    )
+    collection_name: str = Field(
+        title="Collection Name",
+        description="Name of the collection to use.",
+        default="langchain",
+    )
+
+    class Config:
+        title = "Chroma (local persistence)"
+        schema_extra = {
+            "description": "Chroma is a popular vector store that can be used to store and retrieve embeddings. It will build its index in memory and persist it to disk at the end of the sync."
+        }
+
+
+class DocArrayHnswSearchIndexingModel(BaseModel):
+    mode: Literal["DocArrayHnswSearch"] = Field("DocArrayHnswSearch", const=True)
+    destination_path: str = Field(
+        ...,
+        title="Destination Path",
+        description="Path to the directory where hnswlib and meta data files will be written. The files will be placed inside that local mount. All files in the specified destination directory will be deleted on each run.",
+        examples=["/local/my_hnswlib_index"],
+    )
+
+    class Config:
+        title = "DocArrayHnswSearch"
+        schema_extra = {
+            "description": "DocArrayHnswSearch is a lightweight Document Index implementation provided by DocArray that runs fully locally and is best suited for small- to medium-sized datasets. It stores vectors on disk in hnswlib, and stores all other data in SQLite."
+        }
+
+
+class ConfigModel(BaseModel):
+    processing: ProcessingConfigModel
+    embedding: Union[OpenAIEmbeddingConfigModel, FakeEmbeddingConfigModel] = Field(
+        ..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object"
+    )
+    indexing: Union[PineconeIndexingModel, DocArrayHnswSearchIndexingModel, ChromaLocalIndexingModel] = Field(
+        ..., title="Indexing", description="Indexing configuration", discriminator="mode", group="indexing", type="object"
+    )
+
+    class Config:
+        title = "Langchain Destination Config"
+        schema_extra = {
+            "groups": [
+                {"id": "processing", "title": "Processing"},
+                {"id": "embedding", "title": "Embedding"},
+                {"id": "indexing", "title": "Indexing"},
+            ]
+        }
+
+    @staticmethod
+    def resolve_refs(schema: dict) -> dict:
+        # config schemas can't contain references, so inline them
+        json_schema_ref_resolver = RefResolver.from_schema(schema)
+        str_schema = json.dumps(schema)
+        for ref_block in re.findall(r'{"\$ref": "#\/definitions\/.+?(?="})"}', str_schema):
+            ref = json.loads(ref_block)["$ref"]
+            str_schema = str_schema.replace(ref_block, json.dumps(json_schema_ref_resolver.resolve(ref)[1]))
+        pyschema: dict = json.loads(str_schema)
+        del pyschema["definitions"]
+        return pyschema
+
+    @classmethod
+    def schema(cls):
+        """We're overriding the schema classmethod to enable some post-processing."""
+        schema = super().schema()
+        schema = cls.resolve_refs(schema)
+        return schema
diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py
new file mode 100644
index 000000000000..460a7614bf37
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py
@@ -0,0 +1,120 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import logging
+from typing import List, Mapping, Optional, Tuple, Union
+
+import dpath.util
+from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
+from airbyte_cdk.models import AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
+from airbyte_cdk.models.airbyte_protocol import AirbyteStream, DestinationSyncMode
+from dpath.exceptions import PathNotFound
+from langchain.document_loaders.base import Document
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.utils import stringify_dict
+
+METADATA_STREAM_FIELD = "_airbyte_stream"
+METADATA_RECORD_ID_FIELD = "_record_id"
+
+
+class DocumentProcessor:
+    streams: Mapping[str, ConfiguredAirbyteStream]
+
+    def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog, max_metadata_size: Optional[int] = None):
+        self.streams = {self._stream_identifier(stream.stream): stream for stream in catalog.streams}
+        self.max_metadata_size = max_metadata_size
+
+        self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap
+        )
+        self.text_fields = config.text_fields
+        self.logger = logging.getLogger("airbyte.document_processor")
+
+    def _stream_identifier(self, stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str:
+        if isinstance(stream, AirbyteStream):
+            return stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}"
+        else:
+            return stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}"
+
+    def process(self, record: AirbyteRecordMessage) -> Tuple[List[Document], Optional[str]]:
+        """
+        Generate documents from a record.
+        :param record: AirbyteRecordMessage to process
+        :return: Tuple of (list of document chunks, record id to delete if the stream is in dedup mode, to avoid stale documents in the vector store)
+        """
+        doc = self._generate_document(record)
+        if doc is None:
+            self.logger.warning(f"Record {str(record.data)[:250]}... does not contain any text fields. Skipping.")
+            return [], None
+        chunks = self._split_document(doc)
+        id_to_delete = doc.metadata[METADATA_RECORD_ID_FIELD] if METADATA_RECORD_ID_FIELD in doc.metadata else None
+        return chunks, id_to_delete
+
+    def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]:
+        relevant_fields = self._extract_relevant_fields(record)
+        if len(relevant_fields) == 0:
+            return None
+        metadata = self._extract_metadata(record)
+        text = stringify_dict(relevant_fields)
+        return Document(page_content=text, metadata=metadata)
+
+    def _extract_relevant_fields(self, record: AirbyteRecordMessage) -> dict:
+        relevant_fields = {}
+        if self.text_fields:
+            for field in self.text_fields:
+                values = dpath.util.values(record.data, field, separator=".")
+                if values:
+                    relevant_fields[field] = values
+        else:
+            relevant_fields = record.data
+        return relevant_fields
+
+    def _extract_metadata(self, record: AirbyteRecordMessage) -> dict:
+        metadata = record.data
+        if self.text_fields:
+            for field in self.text_fields:
+                try:
+                    dpath.util.delete(metadata, field, separator=".")
+                except PathNotFound:
+                    pass  # if the field doesn't exist, do nothing
+        metadata = self._truncate_metadata(metadata)
+        stream_identifier = self._stream_identifier(record)
+        current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier]
+        metadata[METADATA_STREAM_FIELD] = stream_identifier
+        # if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones
+        if current_stream.primary_key and current_stream.destination_sync_mode == DestinationSyncMode.append_dedup:
+            metadata[METADATA_RECORD_ID_FIELD] = self._extract_primary_key(record, current_stream)
+        return metadata
+
+    def _extract_primary_key(self, record: AirbyteRecordMessage, stream: ConfiguredAirbyteStream) -> str:
+        primary_key = []
+        for key in stream.primary_key:
+            try:
+                primary_key.append(str(dpath.util.get(record.data, key)))
+            except KeyError:
+                primary_key.append("__not_found__")
+        return "_".join(primary_key)
+
+    def _truncate_metadata(self, metadata: dict) -> dict:
+        """
+        Normalize metadata to ensure it is within the size limit and doesn't contain complex objects.
+        """
+        result = {}
+        current_size = 0
+
+        for key, value in metadata.items():
+            if isinstance(value, (str, int, float, bool)):
+                # Calculate the size of the key and value
+                item_size = len(str(key)) + len(str(value))
+
+                # Check if adding the item exceeds the size limit
+                if self.max_metadata_size is None or current_size + item_size <= self.max_metadata_size:
+                    result[key] = value
+                    current_size += item_size
+
+        return result
+
+    def _split_document(self, doc: Document) -> List[Document]:
+        chunks = self.splitter.split_documents([doc])
+        return chunks
diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py
new file mode 100644
index 000000000000..2cf7250c03a8
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py
@@ -0,0 +1,82 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from airbyte_cdk.destinations.vector_db_based.config import FakeEmbeddingConfigModel, OpenAIEmbeddingConfigModel
+from airbyte_cdk.destinations.vector_db_based.utils import format_exception
+from langchain.embeddings.base import Embeddings
+from langchain.embeddings.fake import FakeEmbeddings
+from langchain.embeddings.openai import OpenAIEmbeddings
+
+
+class Embedder(ABC):
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def check(self) -> Optional[str]:
+        pass
+
+    @property
+    @abstractmethod
+    def langchain_embeddings(self) -> Embeddings:
+        pass
+
+    @property
+    @abstractmethod
+    def embedding_dimensions(self) -> int:
+        pass
+
+
+OPEN_AI_VECTOR_SIZE = 1536
+
+
+class OpenAIEmbedder(Embedder):
+    def __init__(self, config: OpenAIEmbeddingConfigModel):
+        super().__init__()
+        self.embeddings = OpenAIEmbeddings(openai_api_key=config.openai_key, chunk_size=8191)
+
+    def check(self) -> Optional[str]:
+        try:
+            self.embeddings.embed_query("test")
+        except Exception as e:
+            return format_exception(e)
+        return None
+
+    @property
+    def langchain_embeddings(self) -> Embeddings:
+        return self.embeddings
+
+    @property
+    def embedding_dimensions(self) -> int:
+        # vector size produced by text-embedding-ada-002 model
+        return OPEN_AI_VECTOR_SIZE
+
+
+
+COHERE_VECTOR_SIZE = 1024
+
+
+class FakeEmbedder(Embedder):
+    def __init__(self, config: FakeEmbeddingConfigModel):
+        super().__init__()
+        self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE)
+
+    def check(self) -> Optional[str]:
+        try:
+            self.embeddings.embed_query("test")
+        except Exception as e:
+            return format_exception(e)
+        return None
+
+    @property
+    def langchain_embeddings(self) -> Embeddings:
+        return self.embeddings
+
+    @property
+    def embedding_dimensions(self) -> int:
+        # use same vector size as for OpenAI embeddings to keep it realistic
+        return OPEN_AI_VECTOR_SIZE
diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py
new file mode 100644
index 000000000000..6f8fbaeafb0c
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/indexer.py
@@ -0,0 +1,205 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import itertools
+import os
+import uuid
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional
+
+import pinecone
+from airbyte_cdk.destinations.vector_db_based.config import ChromaLocalIndexingModel, DocArrayHnswSearchIndexingModel, PineconeIndexingModel
+from airbyte_cdk.destinations.vector_db_based.document_processor import METADATA_RECORD_ID_FIELD, METADATA_STREAM_FIELD
+from airbyte_cdk.destinations.vector_db_based.embedder import Embedder
+from airbyte_cdk.destinations.vector_db_based.measure_time import measure_time
+from airbyte_cdk.destinations.vector_db_based.utils import format_exception
+from airbyte_cdk.models import ConfiguredAirbyteCatalog
+from airbyte_cdk.models.airbyte_protocol import AirbyteLogMessage, AirbyteMessage, DestinationSyncMode, Level, Type
+from langchain.document_loaders.base import Document
+from langchain.vectorstores import Chroma
+from langchain.vectorstores.docarray import DocArrayHnswSearch
+
+
+class Indexer(ABC):
+    def __init__(self, config: Any, embedder: Embedder):
+        self.config = config
+        self.embedder = embedder
+        pass
+
+    def pre_sync(self, catalog: ConfiguredAirbyteCatalog):
+        pass
+
+    def post_sync(self) -> List[AirbyteMessage]:
+        return []
+
+    @abstractmethod
+    def index(self, document_chunks: List[Document], delete_ids: List[str]):
+        pass
+
+    @abstractmethod
+    def check(self) -> Optional[str]:
+        pass
+
+    @property
+    def max_metadata_size(self) -> Optional[int]:
+        return None
+
+
+def chunks(iterable, batch_size):
+    """A helper function to break an iterable into chunks of size batch_size."""
+    it = iter(iterable)
+    chunk = tuple(itertools.islice(it, batch_size))
+    while chunk:
+        yield chunk
+        chunk = tuple(itertools.islice(it, batch_size))
+
+
+# large enough to speed up processing, small enough to not hit pinecone request limits
+PINECONE_BATCH_SIZE = 40
+
+
+class PineconeIndexer(Indexer):
+    config: PineconeIndexingModel
+
+    def __init__(self, config: PineconeIndexingModel, embedder: Embedder):
+        super().__init__(config, embedder)
+        pinecone.init(api_key=config.pinecone_key, environment=config.pinecone_environment, threaded=True)
+        self.pinecone_index = pinecone.Index(config.index, pool_threads=10)
+        self.embed_fn = measure_time(self.embedder.langchain_embeddings.embed_documents)
+
+    def pre_sync(self, catalog: ConfiguredAirbyteCatalog):
+        index_description = pinecone.describe_index(self.config.index)
+        self._pod_type = index_description.pod_type
+        for stream in catalog.streams:
+            if stream.destination_sync_mode == DestinationSyncMode.overwrite:
+                self._delete_vectors({METADATA_STREAM_FIELD: stream.stream.name})
+
+    def post_sync(self) -> List[AirbyteMessage]:
+        return [AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.WARN, message=self.embed_fn._get_stats()))]
+
+    def _delete_vectors(self, filter):
+        if self._pod_type == "starter":
+            # Starter pods do not support deleting by metadata filter, so query for matching vector IDs and delete them directly
+            top_k = 10000
+            self._delete_by_metadata(filter, top_k)
+        else:
+            self.pinecone_index.delete(filter=filter)
+
+    def _delete_by_metadata(self, filter, top_k):
+        zero_vector = [0.0] * self.embedder.embedding_dimensions
+        query_result = self.pinecone_index.query(vector=zero_vector, filter=filter, top_k=top_k)
+        vector_ids = [doc.id for doc in query_result.matches]
+        if len(vector_ids) > 0:
+            self.pinecone_index.delete(ids=vector_ids)
+
+    def index(self, document_chunks: List[Document], delete_ids: List[str]):
+        if len(delete_ids) > 0:
+            self._delete_vectors({METADATA_RECORD_ID_FIELD: {"$in": delete_ids}})
+        embedding_vectors = self.embed_fn([chunk.page_content for chunk in document_chunks])
+        pinecone_docs = []
+        for i in range(len(document_chunks)):
+            chunk = document_chunks[i]
+            metadata = chunk.metadata
+            metadata["text"] = chunk.page_content
+            pinecone_docs.append((str(uuid.uuid4()), embedding_vectors[i], metadata))
+        async_results = [
+            self.pinecone_index.upsert(vectors=ids_vectors_chunk, async_req=True, show_progress=False)
+            for ids_vectors_chunk in chunks(pinecone_docs, batch_size=PINECONE_BATCH_SIZE)
+        ]
+        # Wait for and retrieve responses (this raises in case of error)
+        [async_result.get() for async_result in async_results]
+
+    def check(self) -> Optional[str]:
+        try:
+            description = pinecone.describe_index(self.config.index)
+            actual_dimension = int(description.dimension)
+            if actual_dimension != self.embedder.embedding_dimensions:
+                return f"Your embedding configuration will produce vectors with dimension {self.embedder.embedding_dimensions:d}, but your index is configured with dimension {actual_dimension:d}. Make sure embedding and indexing configurations match."
+        except Exception as e:
+            return format_exception(e)
+        return None
+
+    @property
+    def max_metadata_size(self) -> int:
+        # leave some space for the text field
+        return 40_960 - 10_000
+
+
+class DocArrayHnswSearchIndexer(Indexer):
+    config: DocArrayHnswSearchIndexingModel
+
+    def __init__(self, config: DocArrayHnswSearchIndexingModel, embedder: Embedder):
+        super().__init__(config, embedder)
+
+    def _init_vectorstore(self):
+        self.vectorstore = DocArrayHnswSearch.from_params(
+            embedding=self.embedder.langchain_embeddings, work_dir=self.config.destination_path, n_dim=self.embedder.embedding_dimensions
+        )
+
+    def pre_sync(self, catalog: ConfiguredAirbyteCatalog):
+        for stream in catalog.streams:
+            if stream.destination_sync_mode != DestinationSyncMode.overwrite:
+                raise Exception(
+                    f"DocArrayHnswSearchIndexer only supports overwrite mode, got {stream.destination_sync_mode} for stream {stream.stream.name}"
+                )
+        for file in os.listdir(self.config.destination_path):
+            os.remove(os.path.join(self.config.destination_path, file))
+        self._init_vectorstore()
+
+    def post_sync(self) -> List[AirbyteMessage]:
+        return [AirbyteMessage(type=Type.LOG, log=AirbyteLogMessage(level=Level.WARN, message=self.index._get_stats()))]
+
+    @measure_time
+    def index(self, document_chunks: List[Document], delete_ids: List[str]):
+        # does not support deleting documents, always full refresh sync
+        self.vectorstore.add_documents(document_chunks)
+
+    def check(self) -> Optional[str]:
+        try:
+            self._init_vectorstore()
+        except Exception as e:
+            return format_exception(e)
+        return None
+
+
+class ChromaLocalIndexer(Indexer):
+    config: ChromaLocalIndexingModel
+
+    def __init__(self, config: ChromaLocalIndexingModel, embedder: Embedder):
+        super().__init__(config, embedder)
+
+    def _init_vectorstore(self):
+        self.vectorstore = Chroma(
+            collection_name=self.config.collection_name,
+            embedding_function=self.embedder.langchain_embeddings,
+            persist_directory=self.config.destination_path,
+        )
+
+    def pre_sync(self, catalog: ConfiguredAirbyteCatalog):
+        self._init_vectorstore()
+        for stream in catalog.streams:
+            if stream.destination_sync_mode == DestinationSyncMode.overwrite:
+                self.vectorstore._collection.delete(where={METADATA_STREAM_FIELD: {"$eq": stream.stream.name}})
+
+    def index(self, document_chunks: List[Document], delete_ids: List[str]):
+        for delete_id in delete_ids:
+            self.vectorstore._collection.delete(where={METADATA_RECORD_ID_FIELD: {"$eq": delete_id}})
+        for chunk in document_chunks:
+            self._normalize_metadata(chunk)
+        self.vectorstore.add_documents(document_chunks)
+
+    def _normalize_metadata(self, document: Document):
+        for key, value in document.metadata.items():
+            # check bool separately because isinstance(True, int) == True
+            if not isinstance(value, (str, float, int)) or isinstance(value, bool):
+                document.metadata[key] = str(value)
+
+    def check(self) -> Optional[str]:
+        try:
+            self._init_vectorstore()
+            # try reading collections to make sure it works
+            self.vectorstore._client.list_collections()
+        except Exception as e:
+            return format_exception(e)
+        return None
diff --git a/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/utils.py b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/utils.py
new file mode 100644
index 000000000000..05644e2e7709
--- /dev/null
+++ b/airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/utils.py
@@ -0,0 +1,9 @@
+#
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+#
+
+import traceback
+
+
+def format_exception(exception: Exception) -> str:
+    return str(exception) + "\n" + "".join(traceback.TracebackException.from_exception(exception).format())
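
Notes for reviewers (illustrative sketches, not part of the diff):

How the pieces are expected to compose inside a destination's write loop. The class names come from this patch; the glue around them (the config object, the input_messages source, the flush closure) is hypothetical:

    processor = DocumentProcessor(config.processing, catalog)
    embedder = OpenAIEmbedder(config.embedding)
    indexer = PineconeIndexer(config.indexing, embedder)

    def flush(batch):
        # each buffered item is the (chunks, id_to_delete) tuple returned by process()
        documents = [chunk for chunks, _ in batch for chunk in chunks]
        delete_ids = [record_id for _, record_id in batch if record_id is not None]
        indexer.index(documents, delete_ids)

    batcher = Batcher(batch_size=PINECONE_BATCH_SIZE, flush_handler=flush)
    indexer.pre_sync(catalog)
    for message in input_messages:  # hypothetical message source
        if message.type == Type.RECORD:
            batcher.add(processor.process(message.record))
    batcher.flush()  # drain the final partial batch
    indexer.post_sync()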
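
The Batcher contract in one example: items are buffered until batch_size is reached, which triggers the handler with a copy of the buffer; a final explicit flush() drains the remainder, and flushing an empty buffer is a no-op:

    flushed = []
    batcher = Batcher(batch_size=2, flush_handler=lambda items: flushed.append(items))
    batcher.add("a")  # buffered
    batcher.add("b")  # reaches batch_size, triggers flush_handler(["a", "b"])
    batcher.add("c")  # buffered
    batcher.flush()   # triggers flush_handler(["c"])
    assert flushed == [["a", "b"], ["c"]]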
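
The dot notation and wildcards documented on ProcessingConfigModel.text_fields are resolved by dpath in _extract_relevant_fields. A minimal sketch of what the two documented path styles yield (the sample record is invented for illustration):

    import dpath.util

    data = {"user": {"name": "Ada"}, "users": [{"name": "Grace"}, {"name": "Joan"}]}
    dpath.util.values(data, "user.name", separator=".")     # ["Ada"]
    dpath.util.values(data, "users.*.name", separator=".")  # ["Grace", "Joan"]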
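
ConfigModel.resolve_refs exists because pydantic's default schema() puts every nested model behind a $ref plus a definitions block, which the connector spec format can't carry. A minimal sketch with throwaway model names:

    from pydantic import BaseModel

    class Inner(BaseModel):
        x: int

    class Outer(BaseModel):
        inner: Inner

    # Default output references the nested model:
    #   {"properties": {"inner": {"$ref": "#/definitions/Inner"}}, "definitions": {"Inner": {...}}}
    print(Outer.schema())
    # ConfigModel.schema() post-processes such output by inlining each $ref
    # and deleting the top-level "definitions" block.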