Support auto-setting AWS credentials for storage options
milesgranger committed Mar 15, 2024
1 parent 05f7cf1 commit cfbc44f
Showing 3 changed files with 91 additions and 6 deletions.
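In short: for s3:// paths, the reader and writer now look up credentials from the ambient boto3 session plus the bucket's region, and inject them into both storage_options and delta_storage_options whenever none were supplied. A minimal usage sketch (hypothetical bucket and table path; it assumes credentials are resolvable via boto3's default credential chain):

import dask_deltatable as ddt

# No explicit credentials: with this change they are picked up from the
# boto3 default chain (env vars, ~/.aws/credentials, instance profile, ...)
# and forwarded to both the delta-rs table loader and the S3 filesystem.
ddf = ddt.read_deltalake("s3://my-bucket/path/to/delta-table")
print(ddf.head())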
15 changes: 11 additions & 4 deletions dask_deltatable/core.py
@@ -18,7 +18,7 @@
 from pyarrow import dataset as pa_ds

 from .types import Filters
-from .utils import get_partition_filters
+from .utils import get_partition_filters, maybe_set_aws_credentials

 if Version(pa.__version__) >= Version("10.0.0"):
     filters_to_expression = pq.filters_to_expression
@@ -94,6 +94,9 @@ def _read_from_filesystem(
     """
     Reads the list of parquet files in parallel
     """
+    storage_options = maybe_set_aws_credentials(path, storage_options)
+    delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)
+
     fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
     dt = DeltaTable(
         table_uri=path, version=version, storage_options=delta_storage_options
@@ -116,12 +119,14 @@ def _read_from_filesystem(
     if columns:
         meta = meta[columns]

+    kws = dict(meta=meta, label="read-delta-table")
+    if not dd._dask_expr_enabled():
+        # Setting token not supported in dask-expr
+        kws["token"] = tokenize(path, fs_token, **kwargs)
     return dd.from_map(
         partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
         pq_files,
-        meta=meta,
-        label="read-delta-table",
-        token=tokenize(path, fs_token, **kwargs),
+        **kws,
     )


@@ -270,6 +275,8 @@ def read_deltalake(
     else:
         if path is None:
             raise ValueError("Please Provide Delta Table path")
+
+        delta_storage_options = maybe_set_aws_credentials(path, delta_storage_options)
         resultdf = _read_from_filesystem(
             path=path,
             version=version,
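The kws indirection in _read_from_filesystem only forwards a deterministic token to dd.from_map when the legacy dask.dataframe backend is active, since dask-expr's from_map does not accept one. A standalone sketch of the same pattern, with a toy partition loader standing in for _read_delta_partition:

import pandas as pd
import dask.dataframe as dd
from dask.base import tokenize

def load_partition(i: int) -> pd.DataFrame:
    # Toy stand-in for _read_delta_partition: one row per "file".
    return pd.DataFrame({"part": [i]})

meta = pd.DataFrame({"part": pd.Series(dtype="int64")})
kws = dict(meta=meta, label="read-delta-table")
if not dd._dask_expr_enabled():
    # Only the legacy backend accepts an explicit token.
    kws["token"] = tokenize(list(range(4)))
ddf = dd.from_map(load_partition, range(4), **kws)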
69 changes: 68 additions & 1 deletion dask_deltatable/utils.py
@@ -1,10 +1,77 @@
 from __future__ import annotations

-from typing import cast
+from typing import cast, Any

 from .types import Filter, Filters


+def get_bucket_region(path: str):
+    import boto3
+
+    if not path.startswith("s3://"):
+        raise ValueError(f"'{path}' is not an S3 path")
+    bucket = path.replace("s3://", "").split("/")[0]
+    resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
+    # Buckets in region 'us-east-1' result in None, b/c why not.
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
+    return resp["LocationConstraint"] or "us-east-1"
+
+
+def maybe_set_aws_credentials(path: Any, options: dict) -> dict:
+    """
+    Maybe set AWS credentials in ``options`` if no AWS-specific keys are
+    already present and ``path`` is an s3:// URI.
+
+    Parameters
+    ----------
+    path : Any
+        If it's a string, we check whether it starts with 's3://'; if so, the
+        bucket region is determined before the AWS credentials are set.
+    options : dict[str, Any]
+        Options, any kwargs to be supplied to things like S3FileSystem or
+        similar that accept AWS credentials. A copy is made and returned if modified.
+
+    Returns
+    -------
+    dict
+        Either the original options if not modified, or a copy updated with
+        the AWS credentials inserted.
+    """
+
+    is_s3_path = getattr(path, "startswith", lambda _: False)("s3://")
+    if not is_s3_path:
+        return options
+
+    # Avoid overwriting already provided credentials
+    keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key")
+    if not any(k in (options or ()) for k in keys):
+
+        # defers installing boto3 upfront, xref _read_from_catalog
+        import boto3
+
+        session = boto3.session.Session()
+        credentials = session.get_credentials()
+        region = get_bucket_region(path)
+
+        options = (options or {}).copy()
+        options.update(
+            # Capitalized keys are used by the delta-specific API, lowercase ones by S3FileSystem
+            dict(
+                # TODO: w/o this, we need to configure a LockClient which seems to require dynamodb.
+                AWS_S3_ALLOW_UNSAFE_RENAME="true",
+                AWS_SECRET_ACCESS_KEY=credentials.secret_key,
+                AWS_ACCESS_KEY_ID=credentials.access_key,
+                AWS_SESSION_TOKEN=credentials.token,
+                AWS_REGION=region,
+                secret_key=credentials.secret_key,
+                access_key=credentials.access_key,
+                token=credentials.token,
+                region=region,
+            )
+        )
+    return options
+
+
 def get_partition_filters(
     partition_columns: list[str], filters: Filters
 ) -> list[list[Filter]] | None:
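For illustration, a hedged sketch of how the new helper is expected to behave (hypothetical bucket; the last part assumes boto3 can resolve credentials and query the bucket's region):

from dask_deltatable.utils import maybe_set_aws_credentials

# Non-S3 paths and options that already carry credentials pass through untouched.
assert maybe_set_aws_credentials("/local/table", {"anon": True}) == {"anon": True}
assert maybe_set_aws_credentials("s3://my-bucket/tbl", {"access_key": "x"}) == {"access_key": "x"}

# For an s3:// path with no credentials supplied, a copy is returned carrying
# both delta-rs style (upper-case) and s3fs style (lower-case) keys.
opts = maybe_set_aws_credentials("s3://my-bucket/tbl", {})
print(opts["AWS_REGION"], opts["region"], sorted(opts))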
13 changes: 12 additions & 1 deletion dask_deltatable/write.py
@@ -15,8 +15,15 @@
 from dask.dataframe.core import Scalar
 from dask.highlevelgraph import HighLevelGraph
 from deltalake import DeltaTable
+
+try:
+    from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION
+except ImportError:
+    from deltalake.writer import (
+        MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION,
+    )
+
 from deltalake.writer import (
-    MAX_SUPPORTED_WRITER_VERSION,
     PYARROW_MAJOR_VERSION,
     AddAction,
     DeltaJSONEncoder,
@@ -31,6 +38,7 @@
 from toolz.itertoolz import pluck

 from ._schema import pyarrow_to_deltalake, validate_compatible
+from .utils import maybe_set_aws_credentials


 def to_deltalake(
@@ -123,6 +131,7 @@ def to_deltalake(
     -------
     dask.Scalar
     """
+    storage_options = maybe_set_aws_credentials(table_or_uri, storage_options)
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)

     # We need to write against the latest table version
@@ -136,6 +145,7 @@ def to_deltalake(
         storage_options = table._storage_options or {}
         storage_options.update(storage_options or {})

+    storage_options = maybe_set_aws_credentials(table_uri, storage_options)
     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))

     if isinstance(partition_by, str):
@@ -253,6 +263,7 @@ def _commit(
     schema = validate_compatible(schemas)
     assert schema
     if table is None:
+        storage_options = maybe_set_aws_credentials(table_uri, storage_options)
         write_deltalake_pyarrow(
             table_uri,
             schema,
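On the write path, a similar hedged sketch (hypothetical bucket; whether the call computes immediately depends on to_deltalake's compute flag). Note the helper also sets AWS_S3_ALLOW_UNSAFE_RENAME="true", per the TODO above, so no DynamoDB-backed lock client needs to be configured:

import pandas as pd
import dask.dataframe as dd
import dask_deltatable as ddt

df = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)

# storage_options can be omitted entirely; credentials and region are injected
# by maybe_set_aws_credentials before the DeltaStorageHandler is constructed.
result = ddt.to_deltalake("s3://my-bucket/output-table", df)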
