diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py
index b89e6d8..f79c4b8 100644
--- a/dask_deltatable/core.py
+++ b/dask_deltatable/core.py
@@ -18,7 +18,7 @@ from pyarrow import dataset as pa_ds
 
 from .types import Filters
-from .utils import get_partition_filters
+from .utils import get_partition_filters, maybe_set_aws_credentials
 
 if Version(pa.__version__) >= Version("10.0.0"):
     filters_to_expression = pq.filters_to_expression
@@ -94,6 +94,9 @@ def _read_from_filesystem(
     """
     Reads the list of parquet files in parallel
     """
+    storage_options = maybe_set_aws_credentials(path, storage_options)
+    delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)
+
     fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
     dt = DeltaTable(
         table_uri=path, version=version, storage_options=delta_storage_options
@@ -116,12 +119,14 @@ def _read_from_filesystem(
     if columns:
         meta = meta[columns]
 
+    kws = dict(meta=meta, label="read-delta-table")
+    if not dd._dask_expr_enabled():
+        # Setting token not supported in dask-expr
+        kws["token"] = tokenize(path, fs_token, **kwargs)
     return dd.from_map(
         partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
         pq_files,
-        meta=meta,
-        label="read-delta-table",
-        token=tokenize(path, fs_token, **kwargs),
+        **kws,
     )
@@ -270,6 +275,8 @@ def read_deltalake(
     else:
         if path is None:
             raise ValueError("Please Provide Delta Table path")
+
+        delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)
         resultdf = _read_from_filesystem(
             path=path,
             version=version,
diff --git a/dask_deltatable/utils.py b/dask_deltatable/utils.py
index 3901f63..2534e88 100644
--- a/dask_deltatable/utils.py
+++ b/dask_deltatable/utils.py
@@ -1,10 +1,77 @@
 from __future__ import annotations
 
-from typing import cast
+from typing import cast, Any
 
 from .types import Filter, Filters
 
 
+def get_bucket_region(path: str):
+    import boto3
+
+    if not path.startswith("s3://"):
+        raise ValueError(f"'{path}' is not an S3 path")
+    bucket = path.replace("s3://", "").split("/")[0]
+    resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
+    # Buckets in region 'us-east-1' results in None, b/c why not.
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
+    return resp["LocationConstraint"] or "us-east-1"
+
+
+def maybe_set_aws_credentials(path: Any, options: dict) -> dict:
+    """
+    Maybe set AWS credentials into ``options`` if existing AWS specific keys
+    not found in it and path is s3:// format.
+
+    Parameters
+    ----------
+    path : Any
+        If it's a string, we'll check if it starts with 's3://' then determine bucket
+        region if the AWS credentials should be set.
+    options : dict[str, Any]
+        Options, any kwargs to be supplied to things like S3FileSystem or similar
+        that may accept AWS credentials set. A copy is made and returned if modified.
+
+    Returns
+    -------
+    dict
+        Either the original options if not modified, or a copied and updated options
+        with AWS credentials inserted.
+ """ + + is_s3_path = getattr(path, "startswith", lambda _: False)("s3://") + if not is_s3_path: + return + + # Avoid overwriting already provided credentials + keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key") + if not any(k in (options or ()) for k in keys): + + # defers installing boto3 upfront, xref _read_from_catalog + import boto3 + + session = boto3.session.Session() + credentials = session.get_credentials() + region = get_bucket_region(path) + + options = (options or {}).copy() + options.update( + # Capitalized is used in delta specific API and lowercase is for S3FileSystem + dict( + # TODO: w/o this, we need to configure a LockClient which seems to require dynamodb. + AWS_S3_ALLOW_UNSAFE_RENAME="true", + AWS_SECRET_ACCESS_KEY=credentials.secret_key, + AWS_ACCESS_KEY_ID=credentials.access_key, + AWS_SESSION_TOKEN=credentials.token, + AWS_REGION=region, + secret_key=credentials.secret_key, + access_key=credentials.access_key, + token=credentials.token, + region=region, + ) + ) + return options + + def get_partition_filters( partition_columns: list[str], filters: Filters ) -> list[list[Filter]] | None: diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py index 75eca45..1b0fde6 100644 --- a/dask_deltatable/write.py +++ b/dask_deltatable/write.py @@ -15,8 +15,15 @@ from dask.dataframe.core import Scalar from dask.highlevelgraph import HighLevelGraph from deltalake import DeltaTable + +try: + from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION +except ImportError: + from deltalake.writer import ( + MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION, + ) + from deltalake.writer import ( - MAX_SUPPORTED_WRITER_VERSION, PYARROW_MAJOR_VERSION, AddAction, DeltaJSONEncoder, @@ -31,6 +38,7 @@ from toolz.itertoolz import pluck from ._schema import pyarrow_to_deltalake, validate_compatible +from .utils import maybe_set_aws_credentials def to_deltalake( @@ -123,6 +131,7 @@ def to_deltalake( ------- dask.Scalar """ + storage_options = maybe_set_aws_credentials(table_or_uri, storage_options) table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options) # We need to write against the latest table version @@ -136,6 +145,7 @@ def to_deltalake( storage_options = table._storage_options or {} storage_options.update(storage_options or {}) + storage_options = maybe_set_aws_credentials(table_uri, storage_options) filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options)) if isinstance(partition_by, str): @@ -253,6 +263,7 @@ def _commit( schema = validate_compatible(schemas) assert schema if table is None: + storage_options = maybe_set_aws_credentials(table_uri, storage_options) write_deltalake_pyarrow( table_uri, schema,