Commit

update couchbase

rbiseck3 committed Dec 13, 2024
1 parent 7ea2f52 commit 67e084d
Showing 2 changed files with 70 additions and 74 deletions.
93 changes: 50 additions & 43 deletions unstructured_ingest/v2/processes/connectors/couchbase.py
@@ -1,13 +1,12 @@
import hashlib
import json
import sys
import time
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator, List

from pydantic import Field, Secret
from pydantic import BaseModel, Field, Secret

from unstructured_ingest.error import (
DestinationConnectionError,
@@ -18,6 +17,8 @@
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.v2.interfaces import (
AccessConfig,
BatchFileData,
BatchItem,
ConnectionConfig,
Downloader,
DownloaderConfig,
@@ -40,11 +41,20 @@

if TYPE_CHECKING:
from couchbase.cluster import Cluster
from couchbase.collection import Collection

CONNECTOR_TYPE = "couchbase"
SERVER_API_VERSION = "1"


class FileDataMetadata(BaseModel):
bucket: str


class CouchbaseBatchFileData(BatchFileData):
additional_metadata: FileDataMetadata


class CouchbaseAccessConfig(AccessConfig):
password: str = Field(description="The password for the Couchbase server")

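The new FileDataMetadata and CouchbaseBatchFileData models above replace the untyped additional_metadata dict the indexer used to emit: the bucket name now travels in a typed field, and the document ids become BatchItem entries. A minimal sketch of constructing such a batch; the bucket name and URL are invented for illustration, and the field shapes of BatchFileData are assumed from what this diff shows:

    # Sketch only: field requirements are inferred from the diff, not from
    # the library's definitions of BatchFileData/BatchItem.
    batch = CouchbaseBatchFileData(
        connector_type=CONNECTOR_TYPE,
        metadata=FileDataSourceMetadata(url="couchbase://localhost/travel-sample"),
        additional_metadata=FileDataMetadata(bucket="travel-sample"),
        batch_items=[BatchItem(identifier=doc_id) for doc_id in ("doc-1", "doc-2")],
    )

Since no explicit identifier is passed, the BatchFileData base class presumably derives one from the batch items, which is why the indexer's hand-rolled hash (removed below) is no longer needed.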
@@ -166,39 +176,27 @@ def _get_doc_ids(self) -> List[str]:
try:
cluster = self.connection_config.connect_to_couchbase()
result = cluster.query(query)
document_ids = [row["id"] for row in result]
document_ids = sorted([row["id"] for row in result])
return document_ids
except Exception as e:
attempts += 1
time.sleep(3)
if attempts == max_attempts:
raise SourceConnectionError(f"failed to get document ids: {e}")

def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
def run(self, **kwargs: Any) -> Generator[CouchbaseBatchFileData, None, None]:
ids = self._get_doc_ids()

id_batches = [
ids[i * self.index_config.batch_size : (i + 1) * self.index_config.batch_size]
for i in range(
(len(ids) + self.index_config.batch_size - 1) // self.index_config.batch_size
)
]
for batch in id_batches:
for batch in batch_generator(ids, self.index_config.batch_size):
# Make sure the hash is always a positive number to create the identifier
identified = str(hash(tuple(batch)) + sys.maxsize + 1)
yield FileData(
identifier=identified,
yield CouchbaseBatchFileData(
connector_type=CONNECTOR_TYPE,
doc_type="batch",
metadata=FileDataSourceMetadata(
url=f"{self.connection_config.connection_string}/"
f"{self.connection_config.bucket}",
date_processed=str(time.time()),
),
additional_metadata={
"ids": list(batch),
"bucket": self.connection_config.bucket,
},
additional_metadata=FileDataMetadata(bucket=self.connection_config.bucket),
batch_items=[BatchItem(identifier=b) for b in batch],
)

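Both files in this commit swap the hand-rolled ceiling-division slicing for batch_generator from unstructured_ingest.utils.data_prep. Its implementation is not part of this diff; a minimal equivalent, assuming it yields consecutive fixed-size slices exactly like the code it replaces, looks like:

    from typing import Any, Generator

    def batch_generator(items: list[Any], batch_size: int) -> Generator[list[Any], None, None]:
        # Yield consecutive slices of at most batch_size items; the final
        # slice is shorter when len(items) is not a multiple of batch_size.
        for i in range(0, len(items), batch_size):
            yield items[i : i + batch_size]

For example, batch_generator(["a", "b", "c", "d", "e"], 2) yields ["a", "b"], ["c", "d"], ["e"], the same partitioning the removed comprehension produced.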

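The removed identifier scheme, str(hash(tuple(batch)) + sys.maxsize + 1), was also worth retiring: Python salts string hashing per interpreter run (see PYTHONHASHSEED), so the same batch of ids produced a different identifier on every run. Sorting the ids (added in _get_doc_ids above) plus a content hash yields a stable identity. A sketch in that spirit, offered as an assumption since the actual identifier derivation in BatchFileData is not shown in this diff:

    import hashlib

    def batch_identifier(ids: list[str]) -> str:
        # hashlib digests are process-independent, unlike built-in hash(),
        # which is salted per run for str inputs.
        digest = hashlib.sha256("\n".join(sorted(ids)).encode("utf-8"))
        return digest.hexdigest()[:32]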
@@ -235,7 +233,7 @@ def map_cb_results(self, cb_results: dict) -> str:
return concatenated_values

def generate_download_response(
self, result: dict, bucket: str, file_data: FileData
self, result: dict, bucket: str, file_data: CouchbaseBatchFileData
) -> DownloadResponse:
record_id = result[self.download_config.collection_id]
filename_id = self.get_identifier(bucket=bucket, record_id=record_id)
@@ -255,44 +253,53 @@ def generate_download_response(
exc_info=True,
)
raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
return DownloadResponse(
file_data=FileData(
identifier=filename_id,
connector_type=CONNECTOR_TYPE,
metadata=FileDataSourceMetadata(
version=None,
date_processed=str(time.time()),
record_locator={
"connection_string": self.connection_config.connection_string,
"bucket": bucket,
"scope": self.connection_config.scope,
"collection": self.connection_config.collection,
"document_id": record_id,
},
),
),
path=download_path,
cast_file_data = FileData.cast(file_data=file_data)
cast_file_data.identifier = filename_id
cast_file_data.metadata.date_processed = str(time.time())
cast_file_data.metadata.record_locator = {
"connection_string": self.connection_config.connection_string,
"bucket": bucket,
"scope": self.connection_config.scope,
"collection": self.connection_config.collection,
"document_id": record_id,
}
return super().generate_download_response(
file_data=cast_file_data,
download_path=download_path,
)

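generate_download_response now stamps per-record fields onto a cast copy of the incoming batch file data and defers response construction to the base class, instead of assembling a fresh FileData by hand. FileData.cast is not shown in this diff; a plausible sketch of the idiom, assuming it re-validates a subclass instance against the base model (pydantic ignores fields the base model does not declare, such as batch_items):

    from pydantic import BaseModel

    class FileData(BaseModel):
        # Trimmed to two fields for the sketch; the real model also carries
        # metadata, source_identifiers, and more.
        identifier: str
        connector_type: str

        @classmethod
        def cast(cls, file_data: BaseModel) -> "FileData":
            # Re-validate the subclass's fields against the base model;
            # unknown keys such as batch_items are dropped by default.
            return cls.model_validate(file_data.model_dump())

The net effect: one batch-level object flows in, and each downloaded document flows out as a plain per-file FileData carrying its own identifier, date_processed, and record_locator.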
def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
bucket_name: str = file_data.additional_metadata["bucket"]
ids: list[str] = file_data.additional_metadata["ids"]
couchbase_file_data = CouchbaseBatchFileData.cast(file_data=file_data)
bucket_name: str = couchbase_file_data.additional_metadata.bucket
ids: list[str] = [item.identifier for item in couchbase_file_data.batch_items]

cluster = self.connection_config.connect_to_couchbase()
bucket = cluster.bucket(bucket_name)
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)

download_resp = self.process_all_doc_ids(ids, collection, bucket_name, file_data)
download_resp = self.process_all_doc_ids(ids, collection, bucket_name, couchbase_file_data)
return list(download_resp)

def process_doc_id(self, doc_id, collection, bucket_name, file_data):
def process_doc_id(
self,
doc_id: str,
collection: "Collection",
bucket_name: str,
file_data: CouchbaseBatchFileData,
):
result = collection.get(doc_id)
return self.generate_download_response(
result=result.content_as[dict], bucket=bucket_name, file_data=file_data
)

def process_all_doc_ids(self, ids, collection, bucket_name, file_data):
def process_all_doc_ids(
self,
ids: list[str],
collection: "Collection",
bucket_name: str,
file_data: CouchbaseBatchFileData,
):
for doc_id in ids:
yield self.process_doc_id(doc_id, collection, bucket_name, file_data)

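The new signatures spell the collection annotation as the string "Collection" because couchbase.collection is imported only under TYPE_CHECKING at the top of the file, which keeps the Couchbase SDK an optional dependency at import time. The pattern in isolation, with an illustrative fetch function:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by type checkers only, never at runtime, so the module
        # imports cleanly even when the couchbase package is absent.
        from couchbase.collection import Collection

    def fetch(collection: "Collection", doc_id: str) -> dict:
        # The quoted annotation avoids needing the import at runtime.
        return collection.get(doc_id).content_as[dict]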
51 changes: 20 additions & 31 deletions unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py
@@ -15,7 +15,11 @@
SourceConnectionNetworkError,
WriteError,
)
from unstructured_ingest.utils.data_prep import flatten_dict, generator_batching_wbytes
from unstructured_ingest.utils.data_prep import (
batch_generator,
flatten_dict,
generator_batching_wbytes,
)
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.v2.constants import RECORD_ID_LABEL
from unstructured_ingest.v2.interfaces import (
@@ -187,19 +191,7 @@ def _get_doc_ids(self) -> set[str]:
def run(self, **kwargs: Any) -> Generator[ElasticsearchBatchFileData, None, None]:
all_ids = self._get_doc_ids()
ids = list(all_ids)
id_batches: list[frozenset[str]] = [
frozenset(
ids[
i
* self.index_config.batch_size : (i + 1) # noqa
* self.index_config.batch_size
]
)
for i in range(
(len(ids) + self.index_config.batch_size - 1) // self.index_config.batch_size
)
]
for batch in id_batches:
for batch in batch_generator(ids, self.index_config.batch_size):
# Make sure the hash is always a positive number to create the identifier
yield ElasticsearchBatchFileData(
connector_type=CONNECTOR_TYPE,
@@ -244,7 +236,7 @@ def map_es_results(self, es_results: dict) -> str:
return concatenated_values

def generate_download_response(
self, result: dict, index_name: str, file_data: FileData
self, result: dict, index_name: str, file_data: ElasticsearchBatchFileData
) -> DownloadResponse:
record_id = result["_id"]
filename_id = self.get_identifier(index_name=index_name, record_id=record_id)
@@ -264,22 +256,19 @@ def generate_download_response(
exc_info=True,
)
raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
return DownloadResponse(
file_data=FileData(
identifier=filename_id,
connector_type=CONNECTOR_TYPE,
source_identifiers=SourceIdentifiers(filename=filename, fullpath=filename),
metadata=FileDataSourceMetadata(
version=str(result["_version"]) if "_version" in result else None,
date_processed=str(time()),
record_locator={
"hosts": self.connection_config.hosts,
"index_name": index_name,
"document_id": record_id,
},
),
),
path=download_path,
cast_file_data = FileData.cast(file_data=file_data)
cast_file_data.identifier = filename_id
cast_file_data.metadata.date_processed = str(time())
cast_file_data.metadata.version = str(result["_version"]) if "_version" in result else None
cast_file_data.metadata.record_locator = {
"hosts": self.connection_config.hosts,
"index_name": index_name,
"document_id": record_id,
}
cast_file_data.source_identifiers = SourceIdentifiers(filename=filename, fullpath=filename)
return super().generate_download_response(
file_data=cast_file_data,
download_path=download_path,
)

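As in the Couchbase connector, the rewritten method ends by delegating to super().generate_download_response. Judging from the removed code, the base implementation presumably just pairs the mutated file_data with the local path the bytes were written to; a sketch of that assumption:

    # Assumed shape of the base Downloader method, inferred from the removed
    # code in this diff; not the library's actual implementation.
    def generate_download_response(
        self, file_data: FileData, download_path: Path
    ) -> DownloadResponse:
        return DownloadResponse(file_data=file_data, path=download_path)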
def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
