-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Joe Reuter
committed
Aug 22, 2023
1 parent
dd170e2
commit 9a9f6fe
Showing
6 changed files
with
586 additions
and
0 deletions.
There are no files selected for viewing
26 changes: 26 additions & 0 deletions
26
airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/batcher.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from typing import Any, Callable, List | ||
|
||
|
||
class Batcher: | ||
def __init__(self, batch_size: int, flush_handler: Callable[[List[Any]], None]): | ||
self.batch_size = batch_size | ||
self.buffer = [] | ||
self.flush_handler = flush_handler | ||
|
||
def add(self, item: Any): | ||
self.buffer.append(item) | ||
self._flush_if_necessary() | ||
|
||
def flush(self): | ||
if len(self.buffer) == 0: | ||
return | ||
self.flush_handler(list(self.buffer)) | ||
self.buffer.clear() | ||
|
||
def _flush_if_necessary(self): | ||
if len(self.buffer) >= self.batch_size: | ||
self.flush() |
144 changes: 144 additions & 0 deletions
144
airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
import json | ||
import re | ||
from typing import List, Literal, Optional, Union | ||
|
||
from jsonschema import RefResolver | ||
from pydantic import BaseModel, Field | ||
|
||
|
||
class ProcessingConfigModel(BaseModel): | ||
chunk_size: int = Field( | ||
..., | ||
title="Chunk size", | ||
maximum=8191, | ||
description="Size of chunks in tokens to store in vector store (make sure it is not too big for the context if your LLM)", | ||
) | ||
chunk_overlap: int = Field( | ||
title="Chunk overlap", | ||
description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context", | ||
default=0, | ||
) | ||
text_fields: Optional[List[str]] = Field( | ||
..., | ||
title="Text fields to embed", | ||
description="List of fields in the record that should be used to calculate the embedding. All other fields are passed along as meta fields. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access all `names` fields in all entries of the `users` array.", | ||
always_show=True, | ||
examples=["text", "user.name", "users.*.name"], | ||
) | ||
|
||
class Config: | ||
schema_extra = {"group": "processing"} | ||
|
||
|
||
class OpenAIEmbeddingConfigModel(BaseModel): | ||
mode: Literal["openai"] = Field("openai", const=True) | ||
openai_key: str = Field(..., title="OpenAI API key", airbyte_secret=True) | ||
|
||
class Config: | ||
title = "OpenAI" | ||
schema_extra = { | ||
"description": "Use the OpenAI API to embed text. This option is using the text-embedding-ada-002 model with 1536 embedding dimensions." | ||
} | ||
|
||
|
||
class FakeEmbeddingConfigModel(BaseModel): | ||
mode: Literal["fake"] = Field("fake", const=True) | ||
|
||
class Config: | ||
title = "Fake" | ||
schema_extra = { | ||
"description": "Use a fake embedding made out of random vectors with 1536 embedding dimensions. This is useful for testing the data pipeline without incurring any costs." | ||
} | ||
|
||
|
||
class PineconeIndexingModel(BaseModel): | ||
mode: Literal["pinecone"] = Field("pinecone", const=True) | ||
pinecone_key: str = Field(..., title="Pinecone API key", airbyte_secret=True) | ||
pinecone_environment: str = Field(..., title="Pinecone environment", description="Pinecone environment to use") | ||
index: str = Field(..., title="Index", description="Pinecone index to use") | ||
|
||
class Config: | ||
title = "Pinecone" | ||
schema_extra = { | ||
"description": "Pinecone is a popular vector store that can be used to store and retrieve embeddings. It is a managed service and can also be queried from outside of langchain." | ||
} | ||
|
||
|
||
class ChromaLocalIndexingModel(BaseModel): | ||
mode: Literal["chroma_local"] = Field("chroma_local", const=True) | ||
destination_path: str = Field( | ||
..., | ||
title="Destination Path", | ||
description="Path to the directory where chroma files will be written. The files will be placed inside that local mount.", | ||
examples=["/local/my_chroma_db"], | ||
) | ||
collection_name: str = Field( | ||
title="Collection Name", | ||
description="Name of the collection to use.", | ||
default="langchain", | ||
) | ||
|
||
class Config: | ||
title = "Chroma (local persistance)" | ||
schema_extra = { | ||
"description": "Chroma is a popular vector store that can be used to store and retrieve embeddings. It will build its index in memory and persist it to disk by the end of the sync." | ||
} | ||
|
||
|
||
class DocArrayHnswSearchIndexingModel(BaseModel): | ||
mode: Literal["DocArrayHnswSearch"] = Field("DocArrayHnswSearch", const=True) | ||
destination_path: str = Field( | ||
..., | ||
title="Destination Path", | ||
description="Path to the directory where hnswlib and meta data files will be written. The files will be placed inside that local mount. All files in the specified destination directory will be deleted on each run.", | ||
examples=["/local/my_hnswlib_index"], | ||
) | ||
|
||
class Config: | ||
title = "DocArrayHnswSearch" | ||
schema_extra = { | ||
"description": "DocArrayHnswSearch is a lightweight Document Index implementation provided by Docarray that runs fully locally and is best suited for small- to medium-sized datasets. It stores vectors on disk in hnswlib, and stores all other data in SQLite." | ||
} | ||
|
||
|
||
class ConfigModel(BaseModel): | ||
processing: ProcessingConfigModel | ||
embedding: Union[OpenAIEmbeddingConfigModel, FakeEmbeddingConfigModel] = Field( | ||
..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object" | ||
) | ||
indexing: Union[PineconeIndexingModel, DocArrayHnswSearchIndexingModel, ChromaLocalIndexingModel] = Field( | ||
..., title="Indexing", description="Indexing configuration", discriminator="mode", group="indexing", type="object" | ||
) | ||
|
||
class Config: | ||
title = "Langchain Destination Config" | ||
schema_extra = { | ||
"groups": [ | ||
{"id": "processing", "title": "Processing"}, | ||
{"id": "embedding", "title": "Embedding"}, | ||
{"id": "indexing", "title": "Indexing"}, | ||
] | ||
} | ||
|
||
@staticmethod | ||
def resolve_refs(schema: dict) -> dict: | ||
# config schemas can't contain references, so inline them | ||
json_schema_ref_resolver = RefResolver.from_schema(schema) | ||
str_schema = json.dumps(schema) | ||
for ref_block in re.findall(r'{"\$ref": "#\/definitions\/.+?(?="})"}', str_schema): | ||
ref = json.loads(ref_block)["$ref"] | ||
str_schema = str_schema.replace(ref_block, json.dumps(json_schema_ref_resolver.resolve(ref)[1])) | ||
pyschema: dict = json.loads(str_schema) | ||
del pyschema["definitions"] | ||
return pyschema | ||
|
||
@classmethod | ||
def schema(cls): | ||
"""we're overriding the schema classmethod to enable some post-processing""" | ||
schema = super().schema() | ||
schema = cls.resolve_refs(schema) | ||
return schema |
120 changes: 120 additions & 0 deletions
120
airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/document_processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
import logging | ||
from typing import List, Mapping, Optional, Tuple, Union | ||
|
||
import dpath.util | ||
from airbyte_cdk.models import AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream | ||
from airbyte_cdk.models.airbyte_protocol import AirbyteStream, DestinationSyncMode | ||
from destination_langchain.config import ProcessingConfigModel | ||
from dpath.exceptions import PathNotFound | ||
from langchain.document_loaders.base import Document | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain.utils import stringify_dict | ||
|
||
METADATA_STREAM_FIELD = "_airbyte_stream" | ||
METADATA_RECORD_ID_FIELD = "_record_id" | ||
|
||
|
||
class DocumentProcessor: | ||
streams: Mapping[str, ConfiguredAirbyteStream] | ||
|
||
def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog, max_metadata_size: Optional[int] = None): | ||
self.streams = {self._stream_identifier(stream.stream): stream for stream in catalog.streams} | ||
self.max_metadata_size = max_metadata_size | ||
|
||
self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( | ||
chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap | ||
) | ||
self.text_fields = config.text_fields | ||
self.logger = logging.getLogger("airbyte.document_processor") | ||
|
||
def _stream_identifier(self, stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str: | ||
if isinstance(stream, AirbyteStream): | ||
return stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}" | ||
else: | ||
return stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}" | ||
|
||
def process(self, record: AirbyteRecordMessage) -> Tuple[List[Document], Optional[str]]: | ||
""" | ||
Generate documents from records. | ||
:param records: List of AirbyteRecordMessages | ||
:return: Tuple of (List of document chunks, record id to delete if a stream is in dedup mode to avoid stale documents in the vector store) | ||
""" | ||
doc = self._generate_document(record) | ||
if doc is None: | ||
self.logger.warning(f"Record {str(record.data)[:250]}... does not contain any text fields. Skipping.") | ||
return [], None | ||
chunks = self._split_document(doc) | ||
id_to_delete = doc.metadata[METADATA_RECORD_ID_FIELD] if METADATA_RECORD_ID_FIELD in doc.metadata else None | ||
return chunks, id_to_delete | ||
|
||
def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]: | ||
relevant_fields = self._extract_relevant_fields(record) | ||
if len(relevant_fields) == 0: | ||
return None | ||
metadata = self._extract_metadata(record) | ||
text = stringify_dict(relevant_fields) | ||
return Document(page_content=text, metadata=metadata) | ||
|
||
def _extract_relevant_fields(self, record: AirbyteRecordMessage) -> dict: | ||
relevant_fields = {} | ||
if self.text_fields: | ||
for field in self.text_fields: | ||
values = dpath.util.values(record.data, field, separator=".") | ||
if values and len(values) > 0: | ||
relevant_fields[field] = values | ||
else: | ||
relevant_fields = record.data | ||
return relevant_fields | ||
|
||
def _extract_metadata(self, record: AirbyteRecordMessage) -> dict: | ||
metadata = record.data | ||
if self.text_fields: | ||
for field in self.text_fields: | ||
try: | ||
dpath.util.delete(metadata, field, separator=".") | ||
except PathNotFound: | ||
pass # if the field doesn't exist, do nothing | ||
metadata = self._truncate_metadata(metadata) | ||
stream_identifier = self._stream_identifier(record) | ||
current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier] | ||
metadata[METADATA_STREAM_FIELD] = stream_identifier | ||
# if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones | ||
if current_stream.primary_key and current_stream.destination_sync_mode == DestinationSyncMode.append_dedup: | ||
metadata[METADATA_RECORD_ID_FIELD] = self._extract_primary_key(record, current_stream) | ||
return metadata | ||
|
||
def _extract_primary_key(self, record: AirbyteRecordMessage, stream: ConfiguredAirbyteStream) -> dict: | ||
primary_key = [] | ||
for key in stream.primary_key: | ||
try: | ||
primary_key.append(str(dpath.util.get(record.data, key))) | ||
except KeyError: | ||
primary_key.append("__not_found__") | ||
return "_".join(primary_key) | ||
|
||
def _truncate_metadata(self, metadata: dict) -> dict: | ||
""" | ||
Normalize metadata to ensure it is within the size limit and doesn't contain complex objects. | ||
""" | ||
result = {} | ||
current_size = 0 | ||
|
||
for key, value in metadata.items(): | ||
if isinstance(value, (str, int, float, bool)): | ||
# Calculate the size of the key and value | ||
item_size = len(str(key)) + len(str(value)) | ||
|
||
# Check if adding the item exceeds the size limit | ||
if self.max_metadata_size is None or current_size + item_size <= self.max_metadata_size: | ||
result[key] = value | ||
current_size += item_size | ||
|
||
return result | ||
|
||
def _split_document(self, doc: Document) -> List[Document]: | ||
chunks = self.splitter.split_documents([doc]) | ||
return chunks |
82 changes: 82 additions & 0 deletions
82
airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
# | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import Optional | ||
|
||
from airbyte_cdk.destinations.vector_db_based.config import FakeEmbeddingConfigModel, OpenAIEmbeddingConfigModel | ||
from airbyte_cdk.destinations.vector_db_based.utils import format_exception | ||
from langchain.embeddings.base import Embeddings | ||
from langchain.embeddings.fake import FakeEmbeddings | ||
from langchain.embeddings.openai import OpenAIEmbeddings | ||
|
||
|
||
class Embedder(ABC): | ||
def __init__(self): | ||
pass | ||
|
||
@abstractmethod | ||
def check(self) -> Optional[str]: | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def langchain_embeddings(self) -> Embeddings: | ||
pass | ||
|
||
@property | ||
@abstractmethod | ||
def embedding_dimensions(self) -> int: | ||
pass | ||
|
||
|
||
OPEN_AI_VECTOR_SIZE = 1536 | ||
|
||
|
||
class OpenAIEmbedder(Embedder): | ||
def __init__(self, config: OpenAIEmbeddingConfigModel): | ||
super().__init__() | ||
self.embeddings = OpenAIEmbeddings(openai_api_key=config.openai_key, chunk_size=8191) | ||
|
||
def check(self) -> Optional[str]: | ||
try: | ||
self.embeddings.embed_query("test") | ||
except Exception as e: | ||
return format_exception(e) | ||
return None | ||
|
||
@property | ||
def langchain_embeddings(self) -> Embeddings: | ||
return self.embeddings | ||
|
||
@property | ||
def embedding_dimensions(self) -> int: | ||
# vector size produced by text-embedding-ada-002 model | ||
return OPEN_AI_VECTOR_SIZE | ||
|
||
|
||
|
||
COHERE_VECTOR_SIZE = 1024 | ||
|
||
|
||
class FakeEmbedder(Embedder): | ||
def __init__(self, config: FakeEmbeddingConfigModel): | ||
super().__init__() | ||
self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE) | ||
|
||
def check(self) -> Optional[str]: | ||
try: | ||
self.embeddings.embed_query("test") | ||
except Exception as e: | ||
return format_exception(e) | ||
return None | ||
|
||
@property | ||
def langchain_embeddings(self) -> Embeddings: | ||
return self.embeddings | ||
|
||
@property | ||
def embedding_dimensions(self) -> int: | ||
# use same vector size as for OpenAI embeddings to keep it realistic | ||
return OPEN_AI_VECTOR_SIZE |
Oops, something went wrong.