Support auto-setting AWS credentials for storage options #78

Merged

Changes from all commits
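With this change, dask-deltatable fills in AWS credentials from boto3's default provider chain whenever the path is an `s3://` URI and the caller has not supplied credentials. A minimal usage sketch, assuming a hypothetical bucket `my-bucket` and ambient credentials (environment variables, shared credentials file, or instance profile):

```python
import dask_deltatable as ddt

# No credentials passed: maybe_set_aws_credentials resolves them via boto3
# and injects them into both the fsspec storage_options and the deltalake
# storage options before the table is opened.
ddf = ddt.read_deltalake("s3://my-bucket/delta-table")

# Explicitly supplied credentials are detected and never overwritten.
ddf = ddt.read_deltalake(
    "s3://my-bucket/delta-table",
    delta_storage_options={"access_key": "...", "secret_key": "..."},
)
```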
19 changes: 14 additions & 5 deletions dask_deltatable/core.py
@@ -17,8 +17,8 @@
 from packaging.version import Version
 from pyarrow import dataset as pa_ds
 
+from . import utils
 from .types import Filters
-from .utils import get_partition_filters
 
 if Version(pa.__version__) >= Version("10.0.0"):
     filters_to_expression = pq.filters_to_expression
@@ -44,7 +44,9 @@ def _get_pq_files(dt: DeltaTable, filter: Filters = None) -> list[str]:
     list[str]
         List of files matching optional filter.
     """
-    partition_filters = get_partition_filters(dt.metadata().partition_columns, filter)
+    partition_filters = utils.get_partition_filters(
+        dt.metadata().partition_columns, filter
+    )
     if not partition_filters:
         # can't filter
         return sorted(dt.file_uris())
@@ -94,6 +96,9 @@ def _read_from_filesystem(
"""
Reads the list of parquet files in parallel
"""
storage_options = utils.maybe_set_aws_credentials(path, storage_options) # type: ignore
delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options) # type: ignore

fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
dt = DeltaTable(
table_uri=path, version=version, storage_options=delta_storage_options
@@ -116,12 +121,14 @@ def _read_from_filesystem(
     if columns:
         meta = meta[columns]
 
+    kws = dict(meta=meta, label="read-delta-table")
+    if not dd._dask_expr_enabled():
+        # Setting token not supported in dask-expr
+        kws["token"] = tokenize(path, fs_token, **kwargs)  # type: ignore
     return dd.from_map(
         partial(_read_delta_partition, fs=fs, columns=columns, schema=schema, **kwargs),
         pq_files,
-        meta=meta,
-        label="read-delta-table",
-        token=tokenize(path, fs_token, **kwargs),
+        **kws,
     )
 
 
@@ -270,6 +277,8 @@ def read_deltalake(
     else:
         if path is None:
             raise ValueError("Please Provide Delta Table path")
+
+        delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options)  # type: ignore
         resultdf = _read_from_filesystem(
             path=path,
             version=version,
70 changes: 69 additions & 1 deletion dask_deltatable/utils.py
@@ -1,10 +1,78 @@
 from __future__ import annotations
 
-from typing import cast
+from typing import Any, cast
 
 from .types import Filter, Filters
 
 
+def get_bucket_region(path: str):
+    import boto3
+
+    if not path.startswith("s3://"):
+        raise ValueError(f"'{path}' is not an S3 path")
+    bucket = path.replace("s3://", "").split("/")[0]
+    resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
+    # Buckets in region 'us-east-1' return a LocationConstraint of None.
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
+    return resp["LocationConstraint"] or "us-east-1"
+
+
+def maybe_set_aws_credentials(path: Any, options: dict[str, Any]) -> dict[str, Any]:
+    """
+    Set AWS credentials in ``options`` if ``path`` is an ``s3://`` URI and no
+    AWS-specific keys are already present.
+
+    Parameters
+    ----------
+    path : Any
+        If it's a string starting with 's3://', the bucket region is determined
+        and AWS credentials are set if they should be.
+    options : dict[str, Any]
+        Any kwargs supplied to objects such as ``S3FileSystem`` that accept AWS
+        credentials. A copy is made and returned if modified.
+
+    Returns
+    -------
+    dict
+        Either the original options if not modified, or a copied and updated
+        options dict with AWS credentials inserted.
+    """
+
+    is_s3_path = getattr(path, "startswith", lambda _: False)("s3://")
+    if not is_s3_path:
+        return options
+
+    # Avoid overwriting already provided credentials
+    keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key")
+    if not any(k in (options or ()) for k in keys):
+        # defers importing boto3 upfront, xref _read_from_catalog
+        import boto3
+
+        session = boto3.session.Session()
+        credentials = session.get_credentials()
+        if credentials is None:
+            return options
+        region = get_bucket_region(path)
+
+        options = (options or {}).copy()
+        options.update(
+            # Uppercase keys are used by the delta-specific API; lowercase keys by S3FileSystem
+            dict(
+                # TODO: w/o this, we need to configure a LockClient, which seems to require dynamodb.
+                AWS_S3_ALLOW_UNSAFE_RENAME="true",
+                AWS_SECRET_ACCESS_KEY=credentials.secret_key,
+                AWS_ACCESS_KEY_ID=credentials.access_key,
+                AWS_SESSION_TOKEN=credentials.token,
+                AWS_REGION=region,
+                secret_key=credentials.secret_key,
+                access_key=credentials.access_key,
+                token=credentials.token,
+                region=region,
+            )
+        )
+    return options
+
+
 def get_partition_filters(
     partition_columns: list[str], filters: Filters
 ) -> list[list[Filter]] | None:
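A quick sketch of how the two new helpers behave, assuming boto3 can resolve credentials from the environment; bucket and table names here are hypothetical:

```python
from dask_deltatable import utils

# Non-S3 (or non-string) paths pass through unchanged.
opts = utils.maybe_set_aws_credentials("/local/table", {})
assert opts == {}

# For an s3:// path with no credentials in the options, the boto3 session
# credentials and the bucket region are injected under both the uppercase
# keys (deltalake) and the lowercase keys (S3FileSystem).
opts = utils.maybe_set_aws_credentials("s3://my-bucket/table", None)
print(opts["AWS_REGION"], opts["region"])

# User-supplied credentials win: the input options are returned unmodified.
user_opts = {"access_key": "foo", "secret_key": "bar"}
assert utils.maybe_set_aws_credentials("s3://my-bucket/table", user_opts) == user_opts
```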
13 changes: 12 additions & 1 deletion dask_deltatable/write.py
@@ -15,8 +15,8 @@
 from dask.dataframe.core import Scalar
 from dask.highlevelgraph import HighLevelGraph
 from deltalake import DeltaTable
+
+try:
+    from deltalake.writer import MAX_SUPPORTED_WRITER_VERSION  # type: ignore
+except ImportError:
+    from deltalake.writer import (
+        MAX_SUPPORTED_PYARROW_WRITER_VERSION as MAX_SUPPORTED_WRITER_VERSION,
+    )
+
 from deltalake.writer import (
-    MAX_SUPPORTED_WRITER_VERSION,
     PYARROW_MAJOR_VERSION,
     AddAction,
     DeltaJSONEncoder,
@@ -30,6 +37,7 @@
 )
 from toolz.itertoolz import pluck
 
+from . import utils
 from ._schema import pyarrow_to_deltalake, validate_compatible
 
 
@@ -123,6 +131,7 @@ def to_deltalake(
     -------
     dask.Scalar
     """
+    storage_options = utils.maybe_set_aws_credentials(table_or_uri, storage_options)  # type: ignore
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
 
     # We need to write against the latest table version
@@ -136,6 +145,7 @@
         storage_options = table._storage_options or {}
         storage_options.update(storage_options or {})
 
+    storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))
 
     if isinstance(partition_by, str):
@@ -253,6 +263,7 @@ def _commit(
     schema = validate_compatible(schemas)
     assert schema
     if table is None:
+        storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
         write_deltalake_pyarrow(
             table_uri,
             schema,
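The write path gets the same treatment in three places: before resolving the table, before building the filesystem handler, and before the pyarrow commit. A minimal sketch, again assuming ambient boto3 credentials and a hypothetical bucket:

```python
import dask.dataframe as dd
import pandas as pd
import dask_deltatable as ddt

ddf = dd.from_pandas(pd.DataFrame({"col1": range(10)}), npartitions=2)

# With no storage_options given, credentials are resolved via boto3 and
# injected before DeltaStorageHandler and the pyarrow filesystem are built.
ddt.to_deltalake("s3://my-bucket/delta-table", ddf)
```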
14 changes: 13 additions & 1 deletion tests/test_acceptance.py
@@ -14,6 +14,7 @@

 import os
 import shutil
+import unittest.mock as mock
 from urllib.request import urlretrieve
 
 import dask.dataframe as dd
@@ -42,6 +43,15 @@ def download_data():
     assert os.path.exists(DATA_DIR)
 
 
+@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
+def test_reader_check_aws_credentials(maybe_set_aws_credentials):
+    # The full functionality of maybe_set_aws_credentials is tested in test_utils;
+    # here we only need to ensure it's called when reading with a str path.
+    maybe_set_aws_credentials.return_value = dict()
+    ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
+    maybe_set_aws_credentials.assert_called()
+
+
 def test_reader_all_primitive_types():
     actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
     expected_ddf = dd.read_parquet(
@@ -50,7 +60,9 @@ def test_reader_all_primitive_types():
     # Dask and delta go through different parquet parsers which read the
     # timestamp differently. This is likely a bug in arrow but the delta result
     # is "more correct".
-    expected_ddf["timestamp"] = expected_ddf["timestamp"].astype("datetime64[us]")
+    expected_ddf["timestamp"] = (
+        expected_ddf["timestamp"].astype("datetime64[us]").dt.tz_localize("UTC")
+    )
     assert_eq(actual_ddf, expected_ddf)


85 changes: 84 additions & 1 deletion tests/test_utils.py
@@ -1,8 +1,15 @@
 from __future__ import annotations
 
+import pathlib
+import unittest.mock as mock
+
 import pytest
 
-from dask_deltatable.utils import get_partition_filters
+from dask_deltatable.utils import (
+    get_bucket_region,
+    get_partition_filters,
+    maybe_set_aws_credentials,
+)
 
 
 @pytest.mark.parametrize(
@@ -31,3 +38,79 @@ def test_partition_filters(cols, filters, expected):
     # make sure it works with additional level of wrapping
     res = get_partition_filters(cols, filters)
     assert res == expected
+
+
+@mock.patch("dask_deltatable.utils.get_bucket_region")
+@pytest.mark.parametrize(
+    "options",
+    (
+        None,
+        dict(),
+        dict(AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar"),
+        dict(access_key="foo", secret_key="bar"),
+    ),
+)
+@pytest.mark.parametrize("path", ("s3://path", "/another/path", pathlib.Path(".")))
+def test_maybe_set_aws_credentials(
+    mocked_get_bucket_region,
+    options,
+    path,
+):
+    pytest.importorskip("boto3")
+
+    mocked_get_bucket_region.return_value = "foo-region"
+
+    mock_creds = mock.MagicMock()
+    type(mock_creds).token = mock.PropertyMock(return_value="token")
+    type(mock_creds).access_key = mock.PropertyMock(return_value="access-key")
+    type(mock_creds).secret_key = mock.PropertyMock(return_value="secret-key")
+
+    def mock_get_credentials():
+        return mock_creds
+
+    with mock.patch("boto3.session.Session") as mocked_session:
+        session = mocked_session.return_value
+        session.get_credentials.side_effect = mock_get_credentials
+
+        opts = maybe_set_aws_credentials(path, options)
+
+        if options and not any(k in options for k in ("AWS_ACCESS_KEY_ID", "access_key")):
+            assert opts["AWS_ACCESS_KEY_ID"] == "access-key"
+            assert opts["AWS_SECRET_ACCESS_KEY"] == "secret-key"
+            assert opts["AWS_SESSION_TOKEN"] == "token"
+            assert opts["AWS_REGION"] == "foo-region"
+
+            assert opts["access_key"] == "access-key"
+            assert opts["secret_key"] == "secret-key"
+            assert opts["token"] == "token"
+            assert opts["region"] == "foo-region"
+
+        # Input options are not altered if credentials were supplied by the user
+        elif options:
+            assert options == opts
+
+
+@pytest.mark.parametrize("location", (None, "region-foo"))
+@pytest.mark.parametrize(
+    "path,bucket",
+    (("s3://foo/bar", "foo"), ("s3://fizzbuzz", "fizzbuzz"), ("/not/s3", None)),
+)
+def test_get_bucket_region(location, path, bucket):
+    pytest.importorskip("boto3")
+
+    with mock.patch("boto3.client") as mock_client:
+        mock_client = mock_client.return_value
+        mock_client.get_bucket_location.return_value = {"LocationConstraint": location}
+
+        if not path.startswith("s3://"):
+            with pytest.raises(ValueError, match="is not an S3 path"):
+                get_bucket_region(path)
+            return
+
+        region = get_bucket_region(path)
+
+        # AWS returns None if the bucket is located in us-east-1.
+        location = location if location else "us-east-1"
+        assert region == location
+
+        mock_client.get_bucket_location.assert_has_calls([mock.call(Bucket=bucket)])
13 changes: 13 additions & 0 deletions tests/test_write.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import unittest.mock as mock
 
 import dask.dataframe as dd
 import pandas as pd
@@ -61,6 +62,18 @@ def test_roundtrip(tmpdir, with_index, freq, partition_freq):
     assert_eq(ddf_read, ddf_dask)
 
 
+@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
+def test_writer_check_aws_credentials(maybe_set_aws_credentials, tmpdir):
+    # The full functionality of maybe_set_aws_credentials is tested in test_utils;
+    # here we only need to ensure it's called when writing with a str path.
+    maybe_set_aws_credentials.return_value = dict()
+
+    df = pd.DataFrame({"col1": range(10)})
+    ddf = dd.from_pandas(df, npartitions=2)
+    to_deltalake(str(tmpdir), ddf)
+    maybe_set_aws_credentials.assert_called()
+
+
 @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
 def test_datetime(tmpdir, unit):
     """Ensure we can write datetime with different resolutions,