Commit
add new methods to all uploaders
rbiseck3 committed Dec 16, 2024
1 parent 7c0b03f commit 663786c
Showing 19 changed files with 146 additions and 176 deletions.
36 changes: 36 additions & 0 deletions unstructured_ingest/utils/data_prep.py
@@ -1,8 +1,10 @@
import itertools
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Generator, Iterable, Optional, Sequence, TypeVar, cast

import ndjson
import pandas as pd

DATE_FORMATS = ("%Y-%m-%d", "%Y-%m-%dT%H:%M:%S", "%Y-%m-%d+%H:%M:%S", "%Y-%m-%dT%H:%M:%S%z")
@@ -131,3 +133,37 @@ def validate_date_args(date: Optional[str] = None) -> bool:
f"The argument {date} does not satisfy the format:"
f" YYYY-MM-DD or YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD+HH:MM:SS or YYYY-MM-DDTHH:MM:SS±HHMM",
)


def get_data(path: Path) -> list[dict]:
with path.open() as f:
if path.suffix == ".json":
return json.load(f)
elif path.suffix == ".ndjson":
return ndjson.load(f)
elif path.suffix == ".csv":
df = pd.read_csv(path)
return df.to_dict(orient="records")
elif path.suffix == ".parquet":
df = pd.read_parquet(path)
return df.to_dict(orient="records")
else:
raise ValueError(f"Unsupported file type: {path}")


def get_data_df(path: Path) -> pd.DataFrame:
with path.open() as f:
if path.suffix == ".json":
data = json.load(f)
return pd.DataFrame(data=data)
elif path.suffix == ".ndjson":
data = ndjson.load(f)
return pd.DataFrame(data=data)
elif path.suffix == ".csv":
df = pd.read_csv(path)
return df
elif path.suffix == ".parquet":
df = pd.read_parquet(path)
return df
else:
raise ValueError(f"Unsupported file type: {path}")
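
The two helpers above give every uploader a single way to read staged output: get_data returns a list of dicts and get_data_df returns a pandas DataFrame, for .json, .ndjson, .csv, or .parquet input. A minimal usage sketch (the elements.ndjson path is hypothetical, standing in for the output of an earlier pipeline step):

from pathlib import Path

from unstructured_ingest.utils.data_prep import get_data, get_data_df

# Hypothetical staged output; any .json/.ndjson/.csv/.parquet file works the same way.
staged_path = Path("elements.ndjson")

records = get_data(path=staged_path)   # list[dict], one dict per element
df = get_data_df(path=staged_path)     # same content as a pandas DataFrame
print(len(records), list(df.columns))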
13 changes: 11 additions & 2 deletions unstructured_ingest/v2/interfaces/uploader.py
@@ -5,6 +5,7 @@

from pydantic import BaseModel

from unstructured_ingest.utils.data_prep import get_data
from unstructured_ingest.v2.interfaces.connector import BaseConnector
from unstructured_ingest.v2.interfaces.file_data import FileData
from unstructured_ingest.v2.interfaces.process import BaseProcess
@@ -38,7 +39,15 @@ def run_batch(self, contents: list[UploadContent], **kwargs: Any) -> None:
raise NotImplementedError()

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
raise NotImplementedError()
data = get_data(path=path)
self.run_data(data=data, file_data=file_data, **kwargs)

async def run_async(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
return self.run(contents=[UploadContent(path=path, file_data=file_data)], **kwargs)
data = get_data(path=path)
await self.run_data_async(data=data, file_data=file_data, **kwargs)

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
raise NotImplementedError()

async def run_data_async(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
return self.run_data(data=data, file_data=file_data, **kwargs)
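
With the base interface now reading the staged file itself (via get_data) and delegating to run_data / run_data_async, a destination connector only has to implement the data-based method. A rough sketch of the pattern, assuming the base class exported from this module is the Uploader ABC; MyUploader and its write loop are illustrative only, not part of this commit:

from typing import Any

from unstructured_ingest.v2.interfaces.file_data import FileData
from unstructured_ingest.v2.interfaces.uploader import Uploader


class MyUploader(Uploader):  # hypothetical subclass for illustration
    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
        # run() and run_async() are inherited: they call get_data(path) and pass
        # the resulting list of dicts here, so only the write logic is needed.
        for record in data:
            ...  # write each element dict to the destination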
15 changes: 8 additions & 7 deletions unstructured_ingest/v2/processes/connectors/astradb.py
@@ -1,7 +1,6 @@
import copy
import csv
import hashlib
import json
import sys
from dataclasses import dataclass, field
from pathlib import Path
@@ -17,7 +16,7 @@
SourceConnectionError,
SourceConnectionNetworkError,
)
from unstructured_ingest.utils.data_prep import batch_generator
from unstructured_ingest.utils.data_prep import batch_generator, get_data
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.utils.string_and_date_utils import truncate_string_bytes
from unstructured_ingest.v2.constants import RECORD_ID_LABEL
@@ -363,11 +362,9 @@ def delete_by_record_id(self, collection: "AstraDBCollection", file_data: FileDa
f"deleted {delete_resp.deleted_count} records from collection {collection.name}"
)

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
with path.open("r") as file:
elements_dict = json.load(file)
def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
logger.info(
f"writing {len(elements_dict)} objects to destination "
f"writing {len(data)} objects to destination "
f"collection {self.upload_config.collection_name}"
)

@@ -376,9 +373,13 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:

self.delete_by_record_id(collection=collection, file_data=file_data)

for chunk in batch_generator(elements_dict, astra_db_batch_size):
for chunk in batch_generator(data, astra_db_batch_size):
collection.insert_many(chunk)

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
data = get_data(path=path)
self.run_data(data=data, file_data=file_data, **kwargs)


astra_db_source_entry = SourceRegistryEntry(
indexer=AstraDBIndexer,
@@ -1,7 +1,6 @@
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Generator

from pydantic import Field, Secret
@@ -249,9 +248,7 @@ def precheck(self) -> None:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
with path.open("r") as file:
elements_dict = json.load(file)
def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
logger.info(
f"writing document batches to destination"
f" endpoint at {str(self.connection_config.endpoint)}"
@@ -266,7 +263,7 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:

batch_size = self.upload_config.batch_size
with self.connection_config.get_search_client() as search_client:
for chunk in batch_generator(elements_dict, batch_size):
for chunk in batch_generator(data, batch_size):
self.write_dict(elements_dict=chunk, search_client=search_client) # noqa: E203


11 changes: 3 additions & 8 deletions unstructured_ingest/v2/processes/connectors/chroma.py
@@ -1,7 +1,5 @@
import json
from dataclasses import dataclass, field
from datetime import date, datetime
from pathlib import Path
from typing import TYPE_CHECKING, Annotated, Any, Optional

from dateutil import parser
@@ -171,19 +169,16 @@ def prepare_chroma_list(chunk: tuple[dict[str, Any]]) -> dict[str, list[Any]]:
)
return chroma_dict

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
with path.open("r") as file:
elements_dict = json.load(file)

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
logger.info(
f"writing {len(elements_dict)} objects to destination "
f"writing {len(data)} objects to destination "
f"collection {self.connection_config.collection_name} "
f"at {self.connection_config.host}",
)
client = self.connection_config.get_client()

collection = client.get_or_create_collection(name=self.connection_config.collection_name)
for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
for chunk in batch_generator(data, self.upload_config.batch_size):
self.upsert_batch(collection, self.prepare_chroma_list(chunk))


9 changes: 3 additions & 6 deletions unstructured_ingest/v2/processes/connectors/couchbase.py
@@ -1,5 +1,4 @@
import hashlib
import json
import sys
import time
from contextlib import contextmanager
@@ -124,11 +123,9 @@ def precheck(self) -> None:
logger.error(f"Failed to validate connection {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
with path.open("r") as file:
elements_dict = json.load(file)
def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
logger.info(
f"writing {len(elements_dict)} objects to destination "
f"writing {len(data)} objects to destination "
f"bucket, {self.connection_config.bucket} "
f"at {self.connection_config.connection_string}",
)
@@ -137,7 +134,7 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
scope = bucket.scope(self.connection_config.scope)
collection = scope.collection(self.connection_config.collection)

for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
for chunk in batch_generator(data, self.upload_config.batch_size):
collection.upsert_multi(
{doc_id: doc for doc in chunk for doc_id, doc in doc.items()}
)
42 changes: 10 additions & 32 deletions unstructured_ingest/v2/processes/connectors/delta_table.py
@@ -11,6 +11,7 @@
from pydantic import Field, Secret

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.utils.data_prep import get_data_df
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.utils.table import convert_to_pandas_dataframe
from unstructured_ingest.v2.interfaces import (
@@ -137,38 +138,7 @@ def precheck(self):
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

def process_csv(self, csv_paths: list[Path]) -> pd.DataFrame:
logger.debug(f"uploading content from {len(csv_paths)} csv files")
df = pd.concat((pd.read_csv(path) for path in csv_paths), ignore_index=True)
return df

def process_json(self, json_paths: list[Path]) -> pd.DataFrame:
logger.debug(f"uploading content from {len(json_paths)} json files")
all_records = []
for p in json_paths:
with open(p) as json_file:
all_records.extend(json.load(json_file))

return pd.DataFrame(data=all_records)

def process_parquet(self, parquet_paths: list[Path]) -> pd.DataFrame:
logger.debug(f"uploading content from {len(parquet_paths)} parquet files")
df = pd.concat((pd.read_parquet(path) for path in parquet_paths), ignore_index=True)
return df

def read_dataframe(self, path: Path) -> pd.DataFrame:
if path.suffix == ".csv":
return self.process_csv(csv_paths=[path])
elif path.suffix == ".json":
return self.process_json(json_paths=[path])
elif path.suffix == ".parquet":
return self.process_parquet(parquet_paths=[path])
else:
raise ValueError(f"Unsupported file type, must be parquet, json or csv file: {path}")

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:

df = self.read_dataframe(path)
def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
updated_upload_path = os.path.join(
self.connection_config.table_uri, file_data.source_identifiers.relative_path
)
@@ -203,6 +173,14 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
logger.error(f"Exception occurred in write_deltalake: {error_message}")
raise RuntimeError(f"Error in write_deltalake: {error_message}")

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
df = pd.DataFrame(data=data)
self.upload_dataframe(df=df, file_data=file_data)

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
df = get_data_df(path)
self.upload_dataframe(df=df, file_data=file_data)


delta_table_destination_entry = DestinationRegistryEntry(
connection_config=DeltaTableConnectionConfig,
16 changes: 9 additions & 7 deletions unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py
@@ -1,4 +1,3 @@
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
@@ -8,6 +7,7 @@
from pydantic import Field, Secret

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.utils.data_prep import get_data_df
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.v2.interfaces import (
AccessConfig,
@@ -101,19 +101,21 @@ def precheck(self) -> None:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

def upload_contents(self, path: Path) -> None:
with path.open() as f:
data = json.load(f)
df_elements = pd.DataFrame(data=data)
logger.debug(f"uploading {len(df_elements)} entries to {self.connection_config.database} ")
def upload_dataframe(self, df: pd.DataFrame) -> None:
logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")

with self.connection_config.get_client() as conn:
conn.query(
f"INSERT INTO {self.connection_config.db_schema}.{self.connection_config.table} BY NAME SELECT * FROM df_elements" # noqa: E501
)

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
df = pd.DataFrame(data=data)
self.upload_dataframe(df=df)

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
self.upload_contents(path=path)
df = get_data_df(path)
self.upload_dataframe(df=df)


duckdb_destination_entry = DestinationRegistryEntry(
17 changes: 9 additions & 8 deletions unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py
@@ -1,4 +1,3 @@
import json
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
@@ -9,6 +8,7 @@

from unstructured_ingest.__version__ import __version__ as unstructured_io_ingest_version
from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.utils.data_prep import get_data_df
from unstructured_ingest.utils.dep_check import requires_dependencies
from unstructured_ingest.v2.interfaces import (
AccessConfig,
@@ -100,20 +100,21 @@ def precheck(self) -> None:
logger.error(f"failed to validate connection: {e}", exc_info=True)
raise DestinationConnectionError(f"failed to validate connection: {e}")

def upload_contents(self, path: Path) -> None:
with path.open() as f:
data = json.load(f)

df_elements = pd.DataFrame(data=data)
logger.debug(f"uploading {len(df_elements)} entries to {self.connection_config.database} ")
def upload_dataframe(self, df: pd.DataFrame) -> None:
logger.debug(f"uploading {len(df)} entries to {self.connection_config.database} ")

with self.connection_config.get_client() as conn:
conn.query(
f"INSERT INTO {self.connection_config.db_schema}.{self.connection_config.table} BY NAME SELECT * FROM df_elements" # noqa: E501
)

def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
df = pd.DataFrame(data=data)
self.upload_dataframe(df=df)

def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
self.upload_contents(path=path)
df = get_data_df(path)
self.upload_dataframe(df=df)


motherduck_destination_entry = DestinationRegistryEntry(
@@ -1,6 +1,5 @@
import collections
import hashlib
import json
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
@@ -405,16 +404,14 @@ def delete_by_record_id(self, client, file_data: FileData) -> None:
raise WriteError(f"failed to delete records: {failures}")

@requires_dependencies(["elasticsearch"], extras="elasticsearch")
def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None: # type: ignore
def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None: # noqa: E501
from elasticsearch.helpers.errors import BulkIndexError

parallel_bulk = self.load_parallel_bulk()
with path.open("r") as file:
elements_dict = json.load(file)
upload_destination = self.connection_config.hosts or self.connection_config.cloud_id

logger.info(
f"writing {len(elements_dict)} elements via document batches to destination "
f"writing {len(data)} elements via document batches to destination "
f"index named {self.upload_config.index_name} at {upload_destination} with "
f"batch size (in bytes) {self.upload_config.batch_size_bytes} with "
f"{self.upload_config.num_threads} (number of) threads"
@@ -429,7 +426,7 @@ def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None: # type:
f"This may cause issues when uploading."
)
for batch in generator_batching_wbytes(
elements_dict, batch_size_limit_bytes=self.upload_config.batch_size_bytes
data, batch_size_limit_bytes=self.upload_config.batch_size_bytes
):
try:
iterator = parallel_bulk(