bugfix/make sure all SDK clients get closed (#300)
* make sure all clients get closed

* fix azure ai int test

* fix gitlab precheck

* make sure all uploaders get precheck called in int tests

* fix s3 int test and qdrant precheck

* use non async client for precheck in qdrant uploader

* remove unstructured constraint

* bump changelog
rbiseck3 authored Dec 16, 2024
1 parent 0dffae1 commit 7c0b03f
Showing 24 changed files with 270 additions and 223 deletions.
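The recurring change across the connector diffs below is to turn each connection config's client factory into a context manager so the underlying SDK client is always closed, and to route `precheck()` and the upload/index paths through that same entry point. A minimal sketch of the shape these connectors converge on (the `FakeSDKClient`/`ExampleConnectionConfig` names are illustrative stand-ins, not part of the actual connectors, and the real prechecks vary per connector):

```python
from contextlib import contextmanager
from typing import Generator


class FakeSDKClient:
    """Stand-in for a third-party SDK client that holds network resources."""

    def ping(self) -> None:
        pass

    def close(self) -> None:
        pass


class ExampleConnectionConfig:
    @contextmanager
    def get_client(self) -> Generator[FakeSDKClient, None, None]:
        client = FakeSDKClient()
        try:
            # Hand the live client to the caller.
            yield client
        finally:
            # Always close, even if the caller's block raises.
            client.close()


class ExampleUploader:
    def __init__(self, connection_config: ExampleConnectionConfig):
        self.connection_config = connection_config

    def precheck(self) -> None:
        # Connectivity check goes through the same managed entry point,
        # so precheck no longer leaks a client either.
        with self.connection_config.get_client() as client:
            client.ping()

    def run(self) -> None:
        with self.connection_config.get_client() as client:
            client.ping()  # real connectors upload/query batches here
```

In the diffs that follow, the Azure AI Search config uses the SDK client itself as the context manager (`with SearchClient(...) as client: yield client`), the Couchbase config closes the cluster in a `finally` block, and Chroma keeps a plain `get_client()` factory but moves it onto the connection config.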
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -1,9 +1,12 @@
## 0.3.9-dev1
## 0.3.9-dev2

### Enhancements

* **Support ndjson files in stagers**

### Fixes

* **Make sure any SDK clients that support closing get closed**

## 0.3.8

2 changes: 1 addition & 1 deletion requirements/test.txt
@@ -4,7 +4,7 @@ annotated-types==0.7.0
# via pydantic
cachetools==5.5.0
# via google-auth
certifi==2024.8.30
certifi==2024.12.14
# via requests
cffi==1.17.1
# via cryptography
@@ -156,6 +156,7 @@ async def test_volumes_native_destination(upload_file: Path):
catalog=env_data.catalog,
),
)
uploader.precheck()
if uploader.is_async():
await uploader.run_async(path=upload_file, file_data=file_data)
else:
2 changes: 1 addition & 1 deletion test/integration/connectors/sql/test_postgres.py
@@ -151,7 +151,7 @@ async def test_postgres_destination(upload_file: Path, temp_dir: Path):
access_config=PostgresAccessConfig(password=connect_params["password"]),
)
)

uploader.precheck()
uploader.run(path=staged_path, file_data=mock_file_data)

with staged_path.open("r") as f:
2 changes: 1 addition & 1 deletion test/integration/connectors/sql/test_singlestore.py
@@ -138,7 +138,7 @@ async def test_singlestore_destination(upload_file: Path, temp_dir: Path):
table_name="elements",
),
)

uploader.precheck()
uploader.run(path=staged_path, file_data=mock_file_data)

with staged_path.open("r") as f:
2 changes: 1 addition & 1 deletion test/integration/connectors/sql/test_snowflake.py
@@ -200,7 +200,7 @@ async def test_snowflake_destination(
host=connect_params["host"],
)
)

uploader.precheck()
uploader.run(path=staged_path, file_data=mock_file_data)

with staged_path.open("r") as f:
1 change: 1 addition & 0 deletions test/integration/connectors/sql/test_sqlite.py
@@ -131,6 +131,7 @@ async def test_sqlite_destination(
uploader = SQLiteUploader(
connection_config=SQLiteConnectionConfig(database_path=destination_database_setup)
)
uploader.precheck()
uploader.run(path=staged_path, file_data=mock_file_data)

with staged_path.open("r") as f:
7 changes: 4 additions & 3 deletions test/integration/connectors/test_azure_ai_search.py
@@ -230,12 +230,13 @@ async def test_azure_ai_search_destination(
with staged_filepath.open() as f:
staged_elements = json.load(f)
expected_count = len(staged_elements)
search_client: SearchClient = uploader.connection_config.get_search_client()
validate_count(search_client=search_client, expected_count=expected_count)
with uploader.connection_config.get_search_client() as search_client:
validate_count(search_client=search_client, expected_count=expected_count)

# Rerun and make sure the same documents get updated
uploader.run(path=staged_filepath, file_data=file_data)
validate_count(search_client=search_client, expected_count=expected_count)
with uploader.connection_config.get_search_client() as search_client:
validate_count(search_client=search_client, expected_count=expected_count)


@pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
1 change: 1 addition & 0 deletions test/integration/connectors/test_delta_table.py
@@ -114,6 +114,7 @@ async def test_delta_table_destination_s3(upload_file: Path, temp_dir: Path):
)

try:
uploader.precheck()
if uploader.is_async():
await uploader.run_async(path=new_upload_file, file_data=file_data)
else:
4 changes: 2 additions & 2 deletions test/integration/connectors/test_qdrant.py
@@ -143,7 +143,7 @@ async def test_qdrant_destination_server(upload_file: Path, tmp_path: Path, dock
output_dir=tmp_path,
output_filename=upload_file.name,
)

uploader.precheck()
if uploader.is_async():
await uploader.run_async(path=staged_upload_file, file_data=file_data)
else:
@@ -188,7 +188,7 @@ async def test_qdrant_destination_cloud(upload_file: Path, tmp_path: Path):
output_dir=tmp_path,
output_filename=upload_file.name,
)

uploader.precheck()
if uploader.is_async():
await uploader.run_async(path=staged_upload_file, file_data=file_data)
else:
5 changes: 4 additions & 1 deletion test/integration/connectors/test_s3.py
@@ -165,11 +165,14 @@ async def test_s3_destination(upload_file: Path):
identifier="mock file data",
)
try:
uploader.precheck()
if uploader.is_async():
await uploader.run_async(path=upload_file, file_data=file_data)
else:
uploader.run(path=upload_file, file_data=file_data)
uploaded_files = s3fs.ls(path=destination_path)
uploaded_files = [
Path(file) for file in s3fs.ls(path=destination_path) if Path(file).name != "_empty"
]
assert len(uploaded_files) == 1
finally:
s3fs.rm(path=destination_path, recursive=True)
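The destination integration tests updated above (sqlite, postgres, singlestore, snowflake, azure ai search, delta table, qdrant, s3, and the databricks volumes test) now all call `precheck()` before exercising the uploader. A condensed sketch of that shared pattern, with `uploader`, `upload_file`, and `file_data` standing in for the fixtures each test builds:

```python
from pathlib import Path


async def run_destination_test(uploader, upload_file: Path, file_data) -> None:
    # Fail fast on bad credentials/config before attempting an upload.
    uploader.precheck()
    if uploader.is_async():
        await uploader.run_async(path=upload_file, file_data=file_data)
    else:
        uploader.run(path=upload_file, file_data=file_data)
```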
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
__version__ = "0.3.9-dev1" # pragma: no cover
__version__ = "0.3.9-dev2" # pragma: no cover
19 changes: 12 additions & 7 deletions unstructured_ingest/v2/processes/connectors/azure_ai_search.py
@@ -1,7 +1,8 @@
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Generator

from pydantic import Field, Secret

@@ -49,29 +50,33 @@ class AzureAISearchConnectionConfig(ConnectionConfig):
access_config: Secret[AzureAISearchAccessConfig]

@requires_dependencies(["azure.search", "azure.core"], extras="azure-ai-search")
def get_search_client(self) -> "SearchClient":
@contextmanager
def get_search_client(self) -> Generator["SearchClient", None, None]:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient

return SearchClient(
with SearchClient(
endpoint=self.endpoint,
index_name=self.index,
credential=AzureKeyCredential(
self.access_config.get_secret_value().azure_ai_search_key
),
)
) as client:
yield client

@requires_dependencies(["azure.search", "azure.core"], extras="azure-ai-search")
def get_search_index_client(self) -> "SearchIndexClient":
@contextmanager
def get_search_index_client(self) -> Generator["SearchIndexClient", None, None]:
from azure.core.credentials import AzureKeyCredential
from azure.search.documents.indexes import SearchIndexClient

return SearchIndexClient(
with SearchIndexClient(
endpoint=self.endpoint,
credential=AzureKeyCredential(
self.access_config.get_secret_value().azure_ai_search_key
),
)
) as search_index_client:
yield search_index_client


class AzureAISearchUploadStagerConfig(UploadStagerConfig):
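With `get_search_client()` now a context manager, call sites that previously held onto the returned client consume it inside a `with` block instead, as the test_azure_ai_search change above shows. Roughly, and assuming the config can be built from just the fields visible in this diff (`endpoint`, `index`, and the `azure_ai_search_key` access config; the placeholder values below are not real):

```python
from unstructured_ingest.v2.processes.connectors.azure_ai_search import (
    AzureAISearchAccessConfig,
    AzureAISearchConnectionConfig,
)

connection_config = AzureAISearchConnectionConfig(
    endpoint="https://<service>.search.windows.net",
    index="<index-name>",
    access_config=AzureAISearchAccessConfig(azure_ai_search_key="<key>"),
)

# The client is scoped to the block and closed when it exits.
with connection_config.get_search_client() as search_client:
    print(search_client.get_document_count())
```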
56 changes: 28 additions & 28 deletions unstructured_ingest/v2/processes/connectors/chroma.py
@@ -62,6 +62,32 @@ class ChromaConnectionConfig(ConnectionConfig):
)
connector_type: str = Field(default=CONNECTOR_TYPE, init=False)

@requires_dependencies(["chromadb"], extras="chroma")
def get_client(self) -> "Client":
import chromadb

access_config = self.access_config.get_secret_value()
if path := self.path:
return chromadb.PersistentClient(
path=path,
settings=access_config.settings,
tenant=self.tenant,
database=self.database,
)

elif (host := self.host) and (port := self.port):
return chromadb.HttpClient(
host=host,
port=str(port),
ssl=self.ssl,
headers=access_config.headers,
settings=access_config.settings,
tenant=self.tenant,
database=self.database,
)
else:
raise ValueError("Chroma connector requires either path or host and port to be set.")


class ChromaUploadStagerConfig(UploadStagerConfig):
pass
@@ -107,37 +133,11 @@ class ChromaUploader(Uploader):

def precheck(self) -> None:
try:
self.create_client()
self.connection_config.get_client()
except Exception as e:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

@requires_dependencies(["chromadb"], extras="chroma")
def create_client(self) -> "Client":
import chromadb

access_config = self.connection_config.access_config.get_secret_value()
if self.connection_config.path:
return chromadb.PersistentClient(
path=self.connection_config.path,
settings=access_config.settings,
tenant=self.connection_config.tenant,
database=self.connection_config.database,
)

elif self.connection_config.host and self.connection_config.port:
return chromadb.HttpClient(
host=self.connection_config.host,
port=self.connection_config.port,
ssl=self.connection_config.ssl,
headers=access_config.headers,
settings=access_config.settings,
tenant=self.connection_config.tenant,
database=self.connection_config.database,
)
else:
raise ValueError("Chroma connector requires either path or host and port to be set.")

@DestinationConnectionError.wrap
def upsert_batch(self, collection, batch):

@@ -180,7 +180,7 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
f"collection {self.connection_config.collection_name} "
f"at {self.connection_config.host}",
)
client = self.create_client()
client = self.connection_config.get_client()

collection = client.get_or_create_collection(name=self.connection_config.collection_name)
for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
53 changes: 31 additions & 22 deletions unstructured_ingest/v2/processes/connectors/couchbase.py
@@ -2,6 +2,7 @@
import json
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
@@ -65,17 +66,23 @@ class CouchbaseConnectionConfig(ConnectionConfig):
access_config: Secret[CouchbaseAccessConfig]

@requires_dependencies(["couchbase"], extras="couchbase")
def connect_to_couchbase(self) -> "Cluster":
@contextmanager
def get_client(self) -> Generator["Cluster", None, None]:
from couchbase.auth import PasswordAuthenticator
from couchbase.cluster import Cluster
from couchbase.options import ClusterOptions

auth = PasswordAuthenticator(self.username, self.access_config.get_secret_value().password)
options = ClusterOptions(auth)
options.apply_profile("wan_development")
cluster = Cluster(self.connection_string, options)
cluster.wait_until_ready(timedelta(seconds=5))
return cluster
cluster = None
try:
cluster = Cluster(self.connection_string, options)
cluster.wait_until_ready(timedelta(seconds=5))
yield cluster
finally:
if cluster:
cluster.close()


class CouchbaseUploadStagerConfig(UploadStagerConfig):
@@ -112,7 +119,7 @@ class CouchbaseUploader(Uploader):

def precheck(self) -> None:
try:
self.connection_config.connect_to_couchbase()
self.connection_config.get_client()
except Exception as e:
logger.error(f"Failed to validate connection {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -125,13 +132,15 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
f"bucket, {self.connection_config.bucket} "
f"at {self.connection_config.connection_string}",
)
cluster = self.connection_config.connect_to_couchbase()
bucket = cluster.bucket(self.connection_config.bucket)
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)
with self.connection_config.get_client() as client:
bucket = client.bucket(self.connection_config.bucket)
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)

for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
collection.upsert_multi({doc_id: doc for doc in chunk for doc_id, doc in doc.items()})
for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
collection.upsert_multi(
{doc_id: doc for doc in chunk for doc_id, doc in doc.items()}
)


class CouchbaseIndexerConfig(IndexerConfig):
@@ -146,7 +155,7 @@ class CouchbaseIndexer(Indexer):

def precheck(self) -> None:
try:
self.connection_config.connect_to_couchbase()
self.connection_config.get_client()
except Exception as e:
logger.error(f"Failed to validate connection {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")
@@ -164,10 +173,10 @@ def _get_doc_ids(self) -> List[str]:
attempts = 0
while attempts < max_attempts:
try:
cluster = self.connection_config.connect_to_couchbase()
result = cluster.query(query)
document_ids = [row["id"] for row in result]
return document_ids
with self.connection_config.get_client() as client:
result = client.query(query)
document_ids = [row["id"] for row in result]
return document_ids
except Exception as e:
attempts += 1
time.sleep(3)
@@ -278,13 +287,13 @@ def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
bucket_name: str = file_data.additional_metadata["bucket"]
ids: list[str] = file_data.additional_metadata["ids"]

cluster = self.connection_config.connect_to_couchbase()
bucket = cluster.bucket(bucket_name)
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)
with self.connection_config.get_client() as client:
bucket = client.bucket(bucket_name)
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)

download_resp = self.process_all_doc_ids(ids, collection, bucket_name, file_data)
return list(download_resp)
download_resp = self.process_all_doc_ids(ids, collection, bucket_name, file_data)
return list(download_resp)

def process_doc_id(self, doc_id, collection, bucket_name, file_data):
result = collection.get(doc_id)
2 changes: 1 addition & 1 deletion unstructured_ingest/v2/processes/connectors/delta_table.py
@@ -28,6 +28,7 @@
CONNECTOR_TYPE = "delta_table"


@requires_dependencies(["deltalake"], extras="delta-table")
def write_deltalake_with_error_handling(queue, **kwargs):
from deltalake.writer import write_deltalake

@@ -165,7 +166,6 @@ def read_dataframe(self, path: Path) -> pd.DataFrame:
else:
raise ValueError(f"Unsupported file type, must be parquet, json or csv file: {path}")

@requires_dependencies(["deltalake"], extras="delta-table")
def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:

df = self.read_dataframe(path)