Commit: wip

Joe Reuter committed Aug 22, 2023
1 parent dd170e2 commit 9a9f6fe
Showing 6 changed files with 586 additions and 0 deletions.
@@ -0,0 +1,26 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from typing import Any, Callable, List


class Batcher:
    """Buffers items and hands them to flush_handler in batches of batch_size."""

    def __init__(self, batch_size: int, flush_handler: Callable[[List[Any]], None]):
        self.batch_size = batch_size
        self.buffer: List[Any] = []
        self.flush_handler = flush_handler

    def add(self, item: Any):
        self.buffer.append(item)
        self._flush_if_necessary()

    def flush(self):
        if len(self.buffer) == 0:
            return
        # pass a copy so clearing the buffer below doesn't mutate the handler's batch
        self.flush_handler(list(self.buffer))
        self.buffer.clear()

    def _flush_if_necessary(self):
        if len(self.buffer) >= self.batch_size:
            self.flush()
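
A minimal usage sketch (the print handler and batch size below are illustrative, not part of the commit): the handler fires each time batch_size items accumulate, and flush() drains whatever remains.

batcher = Batcher(batch_size=2, flush_handler=lambda batch: print(f"flushing {batch}"))
batcher.add("a")
batcher.add("b")  # buffer reached batch_size: handler gets ["a", "b"], buffer is cleared
batcher.add("c")
batcher.flush()   # drains the leftover ["c"]; a second flush() would be a no-op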
144 changes: 144 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/destinations/vector_db_based/config.py
@@ -0,0 +1,144 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import json
import re
from typing import List, Literal, Optional, Union

from jsonschema import RefResolver
from pydantic import BaseModel, Field


class ProcessingConfigModel(BaseModel):
    chunk_size: int = Field(
        ...,
        title="Chunk size",
        maximum=8191,
        description="Size of chunks in tokens to store in vector store (make sure it is not too big for the context window of your LLM)",
    )
    chunk_overlap: int = Field(
        title="Chunk overlap",
        description="Size of overlap between chunks in tokens to store in vector store to better capture relevant context",
        default=0,
    )
    text_fields: Optional[List[str]] = Field(
        ...,
        title="Text fields to embed",
        description="List of fields in the record that should be used to calculate the embedding. All other fields are passed along as meta fields. The field list is applied to all streams in the same way and non-existing fields are ignored. If none are defined, all fields are considered text fields. When specifying text fields, you can access nested fields in the record by using dot notation, e.g. `user.name` will access the `name` field in the `user` object. It's also possible to use wildcards to access all fields in an object, e.g. `users.*.name` will access the `name` field in all entries of the `users` array.",
        always_show=True,
        examples=["text", "user.name", "users.*.name"],
    )

    class Config:
        schema_extra = {"group": "processing"}


class OpenAIEmbeddingConfigModel(BaseModel):
    mode: Literal["openai"] = Field("openai", const=True)
    openai_key: str = Field(..., title="OpenAI API key", airbyte_secret=True)

    class Config:
        title = "OpenAI"
        schema_extra = {
            "description": "Use the OpenAI API to embed text. This option uses the text-embedding-ada-002 model with 1536 embedding dimensions."
        }


class FakeEmbeddingConfigModel(BaseModel):
    mode: Literal["fake"] = Field("fake", const=True)

    class Config:
        title = "Fake"
        schema_extra = {
            "description": "Use a fake embedding made out of random vectors with 1536 embedding dimensions. This is useful for testing the data pipeline without incurring any costs."
        }


class PineconeIndexingModel(BaseModel):
    mode: Literal["pinecone"] = Field("pinecone", const=True)
    pinecone_key: str = Field(..., title="Pinecone API key", airbyte_secret=True)
    pinecone_environment: str = Field(..., title="Pinecone environment", description="Pinecone environment to use")
    index: str = Field(..., title="Index", description="Pinecone index to use")

    class Config:
        title = "Pinecone"
        schema_extra = {
            "description": "Pinecone is a popular vector store that can be used to store and retrieve embeddings. It is a managed service and can also be queried from outside of langchain."
        }


class ChromaLocalIndexingModel(BaseModel):
    mode: Literal["chroma_local"] = Field("chroma_local", const=True)
    destination_path: str = Field(
        ...,
        title="Destination Path",
        description="Path to the directory where chroma files will be written. The files will be placed inside that local mount.",
        examples=["/local/my_chroma_db"],
    )
    collection_name: str = Field(
        title="Collection Name",
        description="Name of the collection to use.",
        default="langchain",
    )

    class Config:
        title = "Chroma (local persistence)"
        schema_extra = {
            "description": "Chroma is a popular vector store that can be used to store and retrieve embeddings. It will build its index in memory and persist it to disk by the end of the sync."
        }


class DocArrayHnswSearchIndexingModel(BaseModel):
    mode: Literal["DocArrayHnswSearch"] = Field("DocArrayHnswSearch", const=True)
    destination_path: str = Field(
        ...,
        title="Destination Path",
        description="Path to the directory where hnswlib and meta data files will be written. The files will be placed inside that local mount. All files in the specified destination directory will be deleted on each run.",
        examples=["/local/my_hnswlib_index"],
    )

    class Config:
        title = "DocArrayHnswSearch"
        schema_extra = {
            "description": "DocArrayHnswSearch is a lightweight Document Index implementation provided by Docarray that runs fully locally and is best suited for small- to medium-sized datasets. It stores vectors on disk in hnswlib, and stores all other data in SQLite."
        }


class ConfigModel(BaseModel):
    processing: ProcessingConfigModel
    embedding: Union[OpenAIEmbeddingConfigModel, FakeEmbeddingConfigModel] = Field(
        ..., title="Embedding", description="Embedding configuration", discriminator="mode", group="embedding", type="object"
    )
    indexing: Union[PineconeIndexingModel, DocArrayHnswSearchIndexingModel, ChromaLocalIndexingModel] = Field(
        ..., title="Indexing", description="Indexing configuration", discriminator="mode", group="indexing", type="object"
    )

    class Config:
        title = "Langchain Destination Config"
        schema_extra = {
            "groups": [
                {"id": "processing", "title": "Processing"},
                {"id": "embedding", "title": "Embedding"},
                {"id": "indexing", "title": "Indexing"},
            ]
        }

    @staticmethod
    def resolve_refs(schema: dict) -> dict:
        # config schemas can't contain references, so inline them
        json_schema_ref_resolver = RefResolver.from_schema(schema)
        str_schema = json.dumps(schema)
        for ref_block in re.findall(r'{"\$ref": "#\/definitions\/.+?(?="})"}', str_schema):
            ref = json.loads(ref_block)["$ref"]
            str_schema = str_schema.replace(ref_block, json.dumps(json_schema_ref_resolver.resolve(ref)[1]))
        pyschema: dict = json.loads(str_schema)
        del pyschema["definitions"]
        return pyschema

    @classmethod
    def schema(cls):
        """We're overriding the schema classmethod to enable some post-processing."""
        schema = super().schema()
        schema = cls.resolve_refs(schema)
        return schema
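
A sketch of how a destination might parse a connector config with these models; the concrete values (placeholder keys, environment, and index name) are invented for illustration:

raw_config = {
    "processing": {"chunk_size": 1000, "chunk_overlap": 0, "text_fields": ["text"]},
    "embedding": {"mode": "openai", "openai_key": "sk-..."},  # placeholder secret
    "indexing": {
        "mode": "pinecone",
        "pinecone_key": "...",  # placeholder secret
        "pinecone_environment": "us-west1-gcp",
        "index": "my-index",
    },
}
config = ConfigModel.parse_obj(raw_config)  # pydantic picks the right Union member via the "mode" discriminator
spec = ConfigModel.schema()  # resolve_refs has inlined every $ref block, so "definitions" is gone
assert "definitions" not in spec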
@@ -0,0 +1,120 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import logging
from typing import List, Mapping, Optional, Tuple, Union

import dpath.util
from airbyte_cdk.destinations.vector_db_based.config import ProcessingConfigModel
from airbyte_cdk.models import AirbyteRecordMessage, ConfiguredAirbyteCatalog, ConfiguredAirbyteStream
from airbyte_cdk.models.airbyte_protocol import AirbyteStream, DestinationSyncMode
from dpath.exceptions import PathNotFound
from langchain.document_loaders.base import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utils import stringify_dict

METADATA_STREAM_FIELD = "_airbyte_stream"
METADATA_RECORD_ID_FIELD = "_record_id"


class DocumentProcessor:
    streams: Mapping[str, ConfiguredAirbyteStream]

    def __init__(self, config: ProcessingConfigModel, catalog: ConfiguredAirbyteCatalog, max_metadata_size: Optional[int] = None):
        self.streams = {self._stream_identifier(stream.stream): stream for stream in catalog.streams}
        self.max_metadata_size = max_metadata_size

        self.splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=config.chunk_size, chunk_overlap=config.chunk_overlap
        )
        self.text_fields = config.text_fields
        self.logger = logging.getLogger("airbyte.document_processor")

    def _stream_identifier(self, stream: Union[AirbyteStream, AirbyteRecordMessage]) -> str:
        if isinstance(stream, AirbyteStream):
            return stream.name if stream.namespace is None else f"{stream.namespace}_{stream.name}"
        else:
            return stream.stream if stream.namespace is None else f"{stream.namespace}_{stream.stream}"

    def process(self, record: AirbyteRecordMessage) -> Tuple[List[Document], Optional[str]]:
        """
        Generate documents from a record.
        :param record: AirbyteRecordMessage to process
        :return: Tuple of (list of document chunks, record id to delete if the stream is in dedup mode, to avoid stale documents in the vector store)
        """
        doc = self._generate_document(record)
        if doc is None:
            self.logger.warning(f"Record {str(record.data)[:250]}... does not contain any text fields. Skipping.")
            return [], None
        chunks = self._split_document(doc)
        id_to_delete = doc.metadata[METADATA_RECORD_ID_FIELD] if METADATA_RECORD_ID_FIELD in doc.metadata else None
        return chunks, id_to_delete

    def _generate_document(self, record: AirbyteRecordMessage) -> Optional[Document]:
        relevant_fields = self._extract_relevant_fields(record)
        if len(relevant_fields) == 0:
            return None
        metadata = self._extract_metadata(record)
        text = stringify_dict(relevant_fields)
        return Document(page_content=text, metadata=metadata)

    def _extract_relevant_fields(self, record: AirbyteRecordMessage) -> dict:
        relevant_fields = {}
        if self.text_fields:
            for field in self.text_fields:
                values = dpath.util.values(record.data, field, separator=".")
                if values:
                    relevant_fields[field] = values
        else:
            relevant_fields = record.data
        return relevant_fields

    def _extract_metadata(self, record: AirbyteRecordMessage) -> dict:
        metadata = record.data
        if self.text_fields:
            for field in self.text_fields:
                try:
                    dpath.util.delete(metadata, field, separator=".")
                except PathNotFound:
                    pass  # if the field doesn't exist, do nothing
        metadata = self._truncate_metadata(metadata)
        stream_identifier = self._stream_identifier(record)
        current_stream: ConfiguredAirbyteStream = self.streams[stream_identifier]
        metadata[METADATA_STREAM_FIELD] = stream_identifier
        # if the sync mode is deduping, use the primary key to upsert existing records instead of appending new ones
        if current_stream.primary_key and current_stream.destination_sync_mode == DestinationSyncMode.append_dedup:
            metadata[METADATA_RECORD_ID_FIELD] = self._extract_primary_key(record, current_stream)
        return metadata

    def _extract_primary_key(self, record: AirbyteRecordMessage, stream: ConfiguredAirbyteStream) -> str:
        primary_key = []
        for key in stream.primary_key:
            try:
                primary_key.append(str(dpath.util.get(record.data, key)))
            except KeyError:
                primary_key.append("__not_found__")
        return "_".join(primary_key)

    def _truncate_metadata(self, metadata: dict) -> dict:
        """
        Normalize metadata to ensure it is within the size limit and doesn't contain complex objects.
        """
        result = {}
        current_size = 0

        for key, value in metadata.items():
            if isinstance(value, (str, int, float, bool)):
                # Calculate the size of the key and value
                item_size = len(str(key)) + len(str(value))

                # Check if adding the item exceeds the size limit
                if self.max_metadata_size is None or current_size + item_size <= self.max_metadata_size:
                    result[key] = value
                    current_size += item_size

        return result

    def _split_document(self, doc: Document) -> List[Document]:
        chunks = self.splitter.split_documents([doc])
        return chunks
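
A hedged wiring sketch: the stream name, primary key, and record values below are invented for illustration (a real catalog comes from the configured connection), but the call flow matches the class above.

from airbyte_cdk.models import (
    AirbyteStream,
    ConfiguredAirbyteCatalog,
    ConfiguredAirbyteStream,
    DestinationSyncMode,
    SyncMode,
)

stream = ConfiguredAirbyteStream(
    stream=AirbyteStream(name="users", json_schema={}, supported_sync_modes=[SyncMode.incremental]),
    sync_mode=SyncMode.incremental,
    destination_sync_mode=DestinationSyncMode.append_dedup,
    primary_key=[["id"]],
)
processor = DocumentProcessor(
    config=ProcessingConfigModel(chunk_size=1000, chunk_overlap=0, text_fields=["user.name"]),
    catalog=ConfiguredAirbyteCatalog(streams=[stream]),
)
record = AirbyteRecordMessage(stream="users", data={"id": 1, "user": {"name": "Ada"}}, emitted_at=0)
chunks, id_to_delete = processor.process(record)
# chunks hold the stringified `user.name` values; "id" and "_airbyte_stream" stay in
# chunk.metadata, and id_to_delete is "1" because the stream dedups on its primary key.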
@@ -0,0 +1,82 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

from abc import ABC, abstractmethod
from typing import Optional

from airbyte_cdk.destinations.vector_db_based.config import FakeEmbeddingConfigModel, OpenAIEmbeddingConfigModel
from airbyte_cdk.destinations.vector_db_based.utils import format_exception
from langchain.embeddings.base import Embeddings
from langchain.embeddings.fake import FakeEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings


class Embedder(ABC):
    """Base class for embedders: checks connectivity and exposes langchain embeddings with a fixed dimension count."""

    def __init__(self):
        pass

    @abstractmethod
    def check(self) -> Optional[str]:
        """Return None if the embedder is usable, otherwise a formatted error message."""

    @property
    @abstractmethod
    def langchain_embeddings(self) -> Embeddings:
        pass

    @property
    @abstractmethod
    def embedding_dimensions(self) -> int:
        pass


OPEN_AI_VECTOR_SIZE = 1536


class OpenAIEmbedder(Embedder):
    def __init__(self, config: OpenAIEmbeddingConfigModel):
        super().__init__()
        self.embeddings = OpenAIEmbeddings(openai_api_key=config.openai_key, chunk_size=8191)

    def check(self) -> Optional[str]:
        try:
            self.embeddings.embed_query("test")
        except Exception as e:
            return format_exception(e)
        return None

    @property
    def langchain_embeddings(self) -> Embeddings:
        return self.embeddings

    @property
    def embedding_dimensions(self) -> int:
        # vector size produced by the text-embedding-ada-002 model
        return OPEN_AI_VECTOR_SIZE


COHERE_VECTOR_SIZE = 1024  # not referenced yet in this work-in-progress commit


class FakeEmbedder(Embedder):
    def __init__(self, config: FakeEmbeddingConfigModel):
        super().__init__()
        self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE)

    def check(self) -> Optional[str]:
        try:
            self.embeddings.embed_query("test")
        except Exception as e:
            return format_exception(e)
        return None

    @property
    def langchain_embeddings(self) -> Embeddings:
        return self.embeddings

    @property
    def embedding_dimensions(self) -> int:
        # use the same vector size as the OpenAI embeddings to keep it realistic
        return OPEN_AI_VECTOR_SIZE
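
A minimal usage sketch using the fake embedder (no API key needed), assuming only the classes above: run the connection check first, then embed through the exposed langchain interface.

embedder = FakeEmbedder(FakeEmbeddingConfigModel(mode="fake"))
error = embedder.check()  # returns a formatted error string on failure, None on success
if error is not None:
    raise RuntimeError(error)
vectors = embedder.langchain_embeddings.embed_documents(["hello", "world"])
assert len(vectors) == 2 and len(vectors[0]) == embedder.embedding_dimensions  # 1536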
