diff --git a/dask_deltatable/core.py b/dask_deltatable/core.py
index 4966d52..5b60476 100644
--- a/dask_deltatable/core.py
+++ b/dask_deltatable/core.py
@@ -8,6 +8,7 @@
 import dask.dataframe as dd
 import pyarrow as pa
 import pyarrow.parquet as pq
+from dask.base import tokenize
 from dask.dataframe.io.parquet.arrow import ArrowDatasetEngine
 from dask.dataframe.utils import make_meta
 from deltalake import DataCatalog, DeltaTable
@@ -15,8 +16,8 @@
 from packaging.version import Version
 from pyarrow import dataset as pa_ds
 
+from . import utils
 from .types import Filters
-from .utils import get_partition_filters
 
 if Version(pa.__version__) >= Version("10.0.0"):
     filters_to_expression = pq.filters_to_expression
@@ -42,7 +43,9 @@ def _get_pq_files(dt: DeltaTable, filter: Filters = None) -> list[str]:
     list[str]
         List of files matching optional filter.
     """
-    partition_filters = get_partition_filters(dt.metadata().partition_columns, filter)
+    partition_filters = utils.get_partition_filters(
+        dt.metadata().partition_columns, filter
+    )
     if not partition_filters:
         # can't filter
         return sorted(dt.file_uris())
@@ -92,6 +95,9 @@ def _read_from_filesystem(
     """
     Reads the list of parquet files in parallel
     """
+    storage_options = utils.maybe_set_aws_credentials(path, storage_options)  # type: ignore
+    delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options)  # type: ignore
+
     fs, fs_token, _ = get_fs_token_paths(path, storage_options=storage_options)
     dt = DeltaTable(
         table_uri=path, version=version, storage_options=delta_storage_options
     )
@@ -114,14 +120,17 @@ def _read_from_filesystem(
     if columns:
         meta = meta[columns]
 
+    if not dd._dask_expr_enabled():
+        # Setting token not supported in dask-expr
+        kwargs["token"] = tokenize(path, fs_token, **kwargs)  # type: ignore
     return dd.from_map(
         _read_delta_partition,
         pq_files,
-        meta=meta,
-        label="read-delta-table",
         fs=fs,
         columns=columns,
         schema=schema,
+        meta=meta,
+        label="read-delta-table",
         **kwargs,
     )
 
@@ -271,6 +280,8 @@ def read_deltalake(
     else:
         if path is None:
             raise ValueError("Please Provide Delta Table path")
+
+        delta_storage_options = utils.maybe_set_aws_credentials(path, delta_storage_options)  # type: ignore
         resultdf = _read_from_filesystem(
             path=path,
             version=version,
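Note on the dd.from_map change above: when the legacy (non-dask-expr) backend is active, a token is derived with dask.base.tokenize from the path, filesystem token, and kwargs so that repeated reads of the same table produce stable task names. A minimal sketch of that determinism, outside the patch and with placeholder inputs ("s3://bucket/table", "fs-token", version=...):

from dask.base import tokenize

# tokenize hashes its inputs deterministically: identical inputs yield the
# same token (and therefore stable task names), different inputs do not.
t1 = tokenize("s3://bucket/table", "fs-token", version=3)
t2 = tokenize("s3://bucket/table", "fs-token", version=3)
assert t1 == t2
assert t1 != tokenize("s3://bucket/table", "fs-token", version=4)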
diff --git a/dask_deltatable/utils.py b/dask_deltatable/utils.py
index 3901f63..dabb6b4 100644
--- a/dask_deltatable/utils.py
+++ b/dask_deltatable/utils.py
@@ -1,10 +1,78 @@
 from __future__ import annotations
 
-from typing import cast
+from typing import Any, cast
 
 from .types import Filter, Filters
 
 
+def get_bucket_region(path: str):
+    import boto3
+
+    if not path.startswith("s3://"):
+        raise ValueError(f"'{path}' is not an S3 path")
+    bucket = path.replace("s3://", "").split("/")[0]
+    resp = boto3.client("s3").get_bucket_location(Bucket=bucket)
+    # Buckets in region 'us-east-1' return a None LocationConstraint.
+    # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3/client/get_bucket_location.html#S3.Client.get_bucket_location
+    return resp["LocationConstraint"] or "us-east-1"
+
+
+def maybe_set_aws_credentials(path: Any, options: dict[str, Any]) -> dict[str, Any]:
+    """
+    Set AWS credentials in ``options`` if no AWS-specific keys are already
+    present and ``path`` is an s3:// URI.
+
+    Parameters
+    ----------
+    path : Any
+        If it is a string starting with 's3://', the bucket region is determined
+        and AWS credentials may be set.
+    options : dict[str, Any]
+        Options/kwargs to be supplied to consumers such as S3FileSystem that
+        accept AWS credentials. A copy is made and returned if modified.
+
+    Returns
+    -------
+    dict
+        Either the original options if not modified, or a copied and updated options
+        with AWS credentials inserted.
+    """
+
+    is_s3_path = getattr(path, "startswith", lambda _: False)("s3://")
+    if not is_s3_path:
+        return options
+
+    # Avoid overwriting already provided credentials
+    keys = ("AWS_ACCESS_KEY", "AWS_SECRET_ACCESS_KEY", "access_key", "secret_key")
+    if not any(k in (options or ()) for k in keys):
+        # Defer importing boto3 so it stays optional, xref _read_from_catalog
+        import boto3
+
+        session = boto3.session.Session()
+        credentials = session.get_credentials()
+        if credentials is None:
+            return options
+        region = get_bucket_region(path)
+
+        options = (options or {}).copy()
+        options.update(
+            # Capitalized keys are used by the Delta-specific API, lowercase keys by S3FileSystem
+            dict(
+                # TODO: w/o this, we need to configure a LockClient, which seems to require DynamoDB.
+                AWS_S3_ALLOW_UNSAFE_RENAME="true",
+                AWS_SECRET_ACCESS_KEY=credentials.secret_key,
+                AWS_ACCESS_KEY_ID=credentials.access_key,
+                AWS_SESSION_TOKEN=credentials.token,
+                AWS_REGION=region,
+                secret_key=credentials.secret_key,
+                access_key=credentials.access_key,
+                token=credentials.token,
+                region=region,
+            )
+        )
+    return options
+
+
 def get_partition_filters(
     partition_columns: list[str], filters: Filters
 ) -> list[list[Filter]] | None:
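A minimal usage sketch of the helper added above, using placeholder paths and options and exercising only the code paths that never touch AWS. For an s3:// path without credentials in options, the helper instead consults boto3's default session and fills in both the capitalized keys used by the Delta API and the lowercase keys used by S3FileSystem.

from dask_deltatable.utils import maybe_set_aws_credentials

# Non-S3 paths are returned untouched.
assert maybe_set_aws_credentials("/local/table", {"anon": True}) == {"anon": True}

# User-supplied credentials are respected and not overwritten.
user_opts = {"access_key": "abc", "secret_key": "xyz"}
assert maybe_set_aws_credentials("s3://my-bucket/table", user_opts) is user_opts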
diff --git a/dask_deltatable/write.py b/dask_deltatable/write.py
index cbed6cd..7c92885 100644
--- a/dask_deltatable/write.py
+++ b/dask_deltatable/write.py
@@ -15,8 +15,15 @@
 from dask.dataframe.core import Scalar
 from dask.highlevelgraph import HighLevelGraph
 from deltalake import DeltaTable
+
+try:
+    from deltalake.writer import MAX_SUPPORTED_PYARROW_WRITER_VERSION
+except ImportError:
+    from deltalake.writer import (  # type: ignore
+        MAX_SUPPORTED_WRITER_VERSION as MAX_SUPPORTED_PYARROW_WRITER_VERSION,
+    )
+
 from deltalake.writer import (
-    MAX_SUPPORTED_PYARROW_WRITER_VERSION,
     PYARROW_MAJOR_VERSION,
     AddAction,
     DeltaJSONEncoder,
@@ -30,6 +37,7 @@
 )
 from toolz.itertoolz import pluck
 
+from . import utils
 from ._schema import pyarrow_to_deltalake, validate_compatible
 
 
@@ -123,6 +131,7 @@ def to_deltalake(
     -------
     dask.Scalar
     """
+    storage_options = utils.maybe_set_aws_credentials(table_or_uri, storage_options)  # type: ignore
     table, table_uri = try_get_table_and_table_uri(table_or_uri, storage_options)
 
     # We need to write against the latest table version
@@ -136,6 +145,7 @@ def to_deltalake(
         storage_options = table._storage_options or {}
         storage_options.update(storage_options or {})
 
+    storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
     filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri, storage_options))
 
     if isinstance(partition_by, str):
@@ -253,6 +263,7 @@ def _commit(
     schema = validate_compatible(schemas)
     assert schema
     if table is None:
+        storage_options = utils.maybe_set_aws_credentials(table_uri, storage_options)
         write_deltalake_pyarrow(
             table_uri,
             schema,
diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py
index ca95744..86f0b17 100644
--- a/tests/test_acceptance.py
+++ b/tests/test_acceptance.py
@@ -14,6 +14,7 @@
 import os
 import shutil
+import unittest.mock as mock
 from urllib.request import urlretrieve
 
 import dask.dataframe as dd
@@ -42,6 +43,15 @@ def download_data():
     assert os.path.exists(DATA_DIR)
 
 
+@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
+def test_reader_check_aws_credentials(maybe_set_aws_credentials):
+    # The full functionality of maybe_set_aws_credentials is tested in test_utils;
+    # here we only need to ensure it is called when reading with a str path.
+    maybe_set_aws_credentials.return_value = dict()
+    ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
+    maybe_set_aws_credentials.assert_called()
+
+
 def test_reader_all_primitive_types():
     actual_ddf = ddt.read_deltalake(f"{DATA_DIR}/all_primitive_types/delta")
     expected_ddf = dd.read_parquet(
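The tests below stub out boto3 with unittest.mock. As a standalone sketch of the us-east-1 fallback in get_bucket_region (assuming boto3 is installed; "some-bucket" is a placeholder name):

import unittest.mock as mock

from dask_deltatable.utils import get_bucket_region

# get_bucket_location reports LocationConstraint=None for buckets in us-east-1,
# so get_bucket_region falls back to "us-east-1" explicitly.
with mock.patch("boto3.client") as client:
    client.return_value.get_bucket_location.return_value = {"LocationConstraint": None}
    assert get_bucket_region("s3://some-bucket/key") == "us-east-1"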
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d8b49dd..4743f23 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,8 +1,15 @@
 from __future__ import annotations
 
+import pathlib
+import unittest.mock as mock
+
 import pytest
 
-from dask_deltatable.utils import get_partition_filters
+from dask_deltatable.utils import (
+    get_bucket_region,
+    get_partition_filters,
+    maybe_set_aws_credentials,
+)
 
 
 @pytest.mark.parametrize(
@@ -31,3 +38,79 @@ def test_partition_filters(cols, filters, expected):
     # make sure it works with additional level of wrapping
     res = get_partition_filters(cols, filters)
     assert res == expected
+
+
+@mock.patch("dask_deltatable.utils.get_bucket_region")
+@pytest.mark.parametrize(
+    "options",
+    (
+        None,
+        dict(),
+        dict(AWS_ACCESS_KEY_ID="foo", AWS_SECRET_ACCESS_KEY="bar"),
+        dict(access_key="foo", secret_key="bar"),
+    ),
+)
+@pytest.mark.parametrize("path", ("s3://path", "/another/path", pathlib.Path(".")))
+def test_maybe_set_aws_credentials(
+    mocked_get_bucket_region,
+    options,
+    path,
+):
+    pytest.importorskip("boto3")
+
+    mocked_get_bucket_region.return_value = "foo-region"
+
+    mock_creds = mock.MagicMock()
+    type(mock_creds).token = mock.PropertyMock(return_value="token")
+    type(mock_creds).access_key = mock.PropertyMock(return_value="access-key")
+    type(mock_creds).secret_key = mock.PropertyMock(return_value="secret-key")
+
+    def mock_get_credentials():
+        return mock_creds
+
+    with mock.patch("boto3.session.Session") as mocked_session:
+        session = mocked_session.return_value
+        session.get_credentials.side_effect = mock_get_credentials
+
+        opts = maybe_set_aws_credentials(path, options)
+
+        if options and not any(k in options for k in ("AWS_ACCESS_KEY_ID", "access_key")):
+            assert opts["AWS_ACCESS_KEY_ID"] == "access-key"
+            assert opts["AWS_SECRET_ACCESS_KEY"] == "secret-key"
+            assert opts["AWS_SESSION_TOKEN"] == "token"
+            assert opts["AWS_REGION"] == "foo-region"
+
+            assert opts["access_key"] == "access-key"
+            assert opts["secret_key"] == "secret-key"
+            assert opts["token"] == "token"
+            assert opts["region"] == "foo-region"
+
+        # Input options are not altered if credentials were supplied by the user
+        elif options:
+            assert options == opts
+
+
+@pytest.mark.parametrize("location", (None, "region-foo"))
+@pytest.mark.parametrize(
+    "path,bucket",
+    (("s3://foo/bar", "foo"), ("s3://fizzbuzz", "fizzbuzz"), ("/not/s3", None)),
+)
+def test_get_bucket_region(location, path, bucket):
+    pytest.importorskip("boto3")
+
+    with mock.patch("boto3.client") as mock_client:
+        mock_client = mock_client.return_value
+        mock_client.get_bucket_location.return_value = {"LocationConstraint": location}
+
+        if not path.startswith("s3://"):
+            with pytest.raises(ValueError, match="is not an S3 path"):
+                get_bucket_region(path)
+            return
+
+        region = get_bucket_region(path)
+
+        # AWS returns None if the bucket is located in us-east-1
+        location = location if location else "us-east-1"
+        assert region == location
+
+        mock_client.get_bucket_location.assert_has_calls([mock.call(Bucket=bucket)])
diff --git a/tests/test_write.py b/tests/test_write.py
index 8afefbe..3feae43 100644
--- a/tests/test_write.py
+++ b/tests/test_write.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import unittest.mock as mock
 
 import dask.dataframe as dd
 import pandas as pd
@@ -61,6 +62,18 @@ def test_roundtrip(tmpdir, with_index, freq, partition_freq):
     assert_eq(ddf_read, ddf_dask)
 
 
+@mock.patch("dask_deltatable.utils.maybe_set_aws_credentials")
+def test_writer_check_aws_credentials(maybe_set_aws_credentials, tmpdir):
+    # The full functionality of maybe_set_aws_credentials is tested in test_utils;
+    # here we only need to ensure it is called when writing with a str path.
+    maybe_set_aws_credentials.return_value = dict()
+
+    df = pd.DataFrame({"col1": range(10)})
+    ddf = dd.from_pandas(df, npartitions=2)
+    to_deltalake(str(tmpdir), ddf)
+    maybe_set_aws_credentials.assert_called()
+
+
 @pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"])
 def test_datetime(tmpdir, unit):
     """Ensure we can write datetime with different resolutions,