From 67e084df7faf9a12f3b82fa504600d462aae1ac7 Mon Sep 17 00:00:00 2001
From: Roman Isecke
Date: Fri, 13 Dec 2024 11:58:04 -0500
Subject: [PATCH] update couchbase

---
 .../v2/processes/connectors/couchbase.py      | 93 ++++++++++---------
 .../connectors/elasticsearch/elasticsearch.py | 51 ++++------
 2 files changed, 70 insertions(+), 74 deletions(-)

diff --git a/unstructured_ingest/v2/processes/connectors/couchbase.py b/unstructured_ingest/v2/processes/connectors/couchbase.py
index 1777d0e2..ac1946a6 100644
--- a/unstructured_ingest/v2/processes/connectors/couchbase.py
+++ b/unstructured_ingest/v2/processes/connectors/couchbase.py
@@ -1,13 +1,12 @@
 import hashlib
 import json
-import sys
 import time
 from dataclasses import dataclass, field
 from datetime import timedelta
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Generator, List

-from pydantic import Field, Secret
+from pydantic import BaseModel, Field, Secret

 from unstructured_ingest.error import (
     DestinationConnectionError,
@@ -18,6 +17,8 @@
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
+    BatchFileData,
+    BatchItem,
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
@@ -40,11 +41,20 @@ if TYPE_CHECKING:
     from couchbase.cluster import Cluster
+    from couchbase.collection import Collection

 CONNECTOR_TYPE = "couchbase"
 SERVER_API_VERSION = "1"


+class FileDataMetadata(BaseModel):
+    bucket: str
+
+
+class CouchbaseBatchFileData(BatchFileData):
+    additional_metadata: FileDataMetadata
+
+
 class CouchbaseAccessConfig(AccessConfig):
     password: str = Field(description="The password for the Couchbase server")
@@ -166,7 +176,7 @@ def _get_doc_ids(self) -> List[str]:
         try:
             cluster = self.connection_config.connect_to_couchbase()
             result = cluster.query(query)
-            document_ids = [row["id"] for row in result]
+            document_ids = sorted([row["id"] for row in result])
             return document_ids
         except Exception as e:
             attempts += 1
@@ -174,31 +184,19 @@
             if attempts == max_attempts:
                 raise SourceConnectionError(f"failed to get document ids: {e}")

-    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+    def run(self, **kwargs: Any) -> Generator[CouchbaseBatchFileData, None, None]:
         ids = self._get_doc_ids()
-
-        id_batches = [
-            ids[i * self.index_config.batch_size : (i + 1) * self.index_config.batch_size]
-            for i in range(
-                (len(ids) + self.index_config.batch_size - 1) // self.index_config.batch_size
-            )
-        ]
-        for batch in id_batches:
+        for batch in batch_generator(ids, self.index_config.batch_size):
             # Make sure the hash is always a positive number to create identified
-            identified = str(hash(tuple(batch)) + sys.maxsize + 1)
-            yield FileData(
-                identifier=identified,
+            yield CouchbaseBatchFileData(
                 connector_type=CONNECTOR_TYPE,
-                doc_type="batch",
                 metadata=FileDataSourceMetadata(
                     url=f"{self.connection_config.connection_string}/"
                     f"{self.connection_config.bucket}",
                     date_processed=str(time.time()),
                 ),
-                additional_metadata={
-                    "ids": list(batch),
-                    "bucket": self.connection_config.bucket,
-                },
+                additional_metadata=FileDataMetadata(bucket=self.connection_config.bucket),
+                batch_items=[BatchItem(identifier=b) for b in batch],
             )
@@ -235,7 +233,7 @@ def map_cb_results(self, cb_results: dict) -> str:
         return concatenated_values

     def generate_download_response(
-        self, result: dict, bucket: str, file_data: FileData
+        self, result: dict, bucket: str, file_data: CouchbaseBatchFileData
     ) -> DownloadResponse:
         record_id = result[self.download_config.collection_id]
         filename_id = self.get_identifier(bucket=bucket, record_id=record_id)
@@ -255,44 +253,53 @@ def generate_download_response(
                 exc_info=True,
             )
             raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
-        return DownloadResponse(
-            file_data=FileData(
-                identifier=filename_id,
-                connector_type=CONNECTOR_TYPE,
-                metadata=FileDataSourceMetadata(
-                    version=None,
-                    date_processed=str(time.time()),
-                    record_locator={
-                        "connection_string": self.connection_config.connection_string,
-                        "bucket": bucket,
-                        "scope": self.connection_config.scope,
-                        "collection": self.connection_config.collection,
-                        "document_id": record_id,
-                    },
-                ),
-            ),
-            path=download_path,
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = filename_id
+        cast_file_data.metadata.date_processed = str(time.time())
+        cast_file_data.metadata.record_locator = {
+            "connection_string": self.connection_config.connection_string,
+            "bucket": bucket,
+            "scope": self.connection_config.scope,
+            "collection": self.connection_config.collection,
+            "document_id": record_id,
+        }
+        return super().generate_download_response(
+            file_data=cast_file_data,
+            download_path=download_path,
         )

     def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
-        bucket_name: str = file_data.additional_metadata["bucket"]
-        ids: list[str] = file_data.additional_metadata["ids"]
+        couchbase_file_data = CouchbaseBatchFileData.cast(file_data=file_data)
+        bucket_name: str = couchbase_file_data.additional_metadata.bucket
+        ids: list[str] = [item.identifier for item in couchbase_file_data.batch_items]

         cluster = self.connection_config.connect_to_couchbase()
         bucket = cluster.bucket(bucket_name)
         scope = bucket.scope(self.connection_config.scope)
         collection = scope.collection(self.connection_config.collection)

-        download_resp = self.process_all_doc_ids(ids, collection, bucket_name, file_data)
+        download_resp = self.process_all_doc_ids(ids, collection, bucket_name, couchbase_file_data)
         return list(download_resp)

-    def process_doc_id(self, doc_id, collection, bucket_name, file_data):
+    def process_doc_id(
+        self,
+        doc_id: str,
+        collection: "Collection",
+        bucket_name: str,
+        file_data: CouchbaseBatchFileData,
+    ):
         result = collection.get(doc_id)
         return self.generate_download_response(
             result=result.content_as[dict], bucket=bucket_name, file_data=file_data
         )

-    def process_all_doc_ids(self, ids, collection, bucket_name, file_data):
+    def process_all_doc_ids(
+        self,
+        ids: list[str],
+        collection: "Collection",
+        bucket_name: str,
+        file_data: CouchbaseBatchFileData,
+    ):
         for doc_id in ids:
             yield self.process_doc_id(doc_id, collection, bucket_name, file_data)

diff --git a/unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py b/unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py
index 6712997c..3e5fa177 100644
--- a/unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py
+++ b/unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py
@@ -15,7 +15,11 @@
     SourceConnectionNetworkError,
     WriteError,
 )
-from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
+from unstructured_ingest.utils.data_prep import (
+    batch_generator,
+    flatten_dict,
+    generator_batching_wbytes,
+)
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
@@ -187,19 +191,7 @@ def _get_doc_ids(self) -> set[str]:
     def run(self, **kwargs: Any) -> Generator[ElasticsearchBatchFileData, None, None]:
         all_ids = self._get_doc_ids()
         ids = list(all_ids)
-        id_batches: list[frozenset[str]] = [
-            frozenset(
-                ids[
-                    i
-                    * self.index_config.batch_size : (i + 1)  # noqa
-                    * self.index_config.batch_size
-                ]
-            )
-            for i in range(
-                (len(ids) + self.index_config.batch_size - 1) // self.index_config.batch_size
-            )
-        ]
-        for batch in id_batches:
+        for batch in batch_generator(ids, self.index_config.batch_size):
             # Make sure the hash is always a positive number to create identified
             yield ElasticsearchBatchFileData(
                 connector_type=CONNECTOR_TYPE,
@@ -244,7 +236,7 @@ def map_es_results(self, es_results: dict) -> str:
         return concatenated_values

     def generate_download_response(
-        self, result: dict, index_name: str, file_data: FileData
+        self, result: dict, index_name: str, file_data: ElasticsearchBatchFileData
     ) -> DownloadResponse:
         record_id = result["_id"]
         filename_id = self.get_identifier(index_name=index_name, record_id=record_id)
@@ -264,22 +256,19 @@ def generate_download_response(
                 exc_info=True,
             )
             raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
-        return DownloadResponse(
-            file_data=FileData(
-                identifier=filename_id,
-                connector_type=CONNECTOR_TYPE,
-                source_identifiers=SourceIdentifiers(filename=filename, fullpath=filename),
-                metadata=FileDataSourceMetadata(
-                    version=str(result["_version"]) if "_version" in result else None,
-                    date_processed=str(time()),
-                    record_locator={
-                        "hosts": self.connection_config.hosts,
-                        "index_name": index_name,
-                        "document_id": record_id,
-                    },
-                ),
-            ),
-            path=download_path,
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = filename_id
+        cast_file_data.metadata.date_processed = str(time())
+        cast_file_data.metadata.version = str(result["_version"]) if "_version" in result else None
+        cast_file_data.metadata.record_locator = {
+            "hosts": self.connection_config.hosts,
+            "index_name": index_name,
+            "document_id": record_id,
+        }
+        cast_file_data.source_identifiers = SourceIdentifiers(filename=filename, fullpath=filename)
+        return super().generate_download_response(
+            file_data=cast_file_data,
+            download_path=download_path,
         )

     def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
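
Note: both connectors now delegate batching to batch_generator from unstructured_ingest.utils.data_prep instead of hand-rolled slice arithmetic. That helper is imported but not defined in this diff; as a rough sketch, the contract the rewritten run() methods appear to assume (order-preserving, fixed-size chunks with a short final chunk) is something like the following, though the actual implementation in data_prep.py may differ:

    from typing import Generator, Iterable, TypeVar

    T = TypeVar("T")

    # Illustrative only: the real batch_generator lives in
    # unstructured_ingest/utils/data_prep.py and is not shown in this patch.
    def batch_generator(iterable: Iterable[T], batch_size: int) -> Generator[list[T], None, None]:
        batch: list[T] = []
        for item in iterable:
            batch.append(item)
            if len(batch) == batch_size:
                yield batch  # full chunk
                batch = []
        if batch:
            yield batch  # short final chunk, if any

Relatedly, sorting the Couchbase document ids before batching (the sorted(...) change in _get_doc_ids) keeps batch membership stable across runs; with the hash(tuple(batch)) + sys.maxsize identifier removed, BatchFileData is presumably left to derive a deterministic identifier from the batch items themselves.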