Skip to content

Commit

Permalink
Add dynamically loading table credentials support for unity delta wri…
Browse files Browse the repository at this point in the history
…te (#2)

* add support for dynamically loading table credentials to write delta

* add AWS_SESSION_TOKEN

* handle non-existing tables

* remove support for loading delta tables directly as unity catalog tables
  • Loading branch information
chidifrank authored Sep 6, 2024
1 parent ff9d4c4 commit f88ade5
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 98 deletions.
81 changes: 38 additions & 43 deletions dbt/adapters/duckdb/plugins/unity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pyarrow as pa
from unitycatalog import Unitycatalog
from unitycatalog.types import GenerateTemporaryTableCredentialResponse
from unitycatalog.types.table_create_params import Column

from . import BasePlugin
Expand Down Expand Up @@ -46,6 +47,36 @@ def uc_table_exists(
return table_name in [table.name for table in table_list_request.tables]


def uc_get_storage_credentials(
client: Unitycatalog, catalog_name: str, schema_name: str, table_name: str
) -> dict:
"""Get temporary table credentials for a UC table if they exist."""

# Get the table ID

if not uc_table_exists(client, table_name, schema_name, catalog_name):
return {}

table_response = client.tables.retrieve(full_name=f"{catalog_name}.{schema_name}.{table_name}")

if not table_response.table_id:
return {}

# Get the temporary table credentials
creds: GenerateTemporaryTableCredentialResponse = client.temporary_table_credentials.create(
operation="READ_WRITE", table_id=table_response.table_id
)

if creds.aws_temp_credentials:
return {
"AWS_ACCESS_KEY_ID": creds.aws_temp_credentials.access_key_id,
"AWS_SECRET_ACCESS_KEY": creds.aws_temp_credentials.secret_access_key,
"AWS_SESSION_TOKEN": creds.aws_temp_credentials.session_token,
}

return {}


UCSupportedTypeLiteral = Literal[
"BOOLEAN",
"BYTE",
Expand Down Expand Up @@ -214,48 +245,7 @@ def initialize(self, config: Dict[str, Any]):
self.uc_client: Unitycatalog = Unitycatalog(base_url=catalog_base_url)

def load(self, source_config: SourceConfig):
# Assert that the source_config has a name, schema, and database
assert source_config.identifier is not None, "Name is required for loading data!"
assert source_config.schema is not None, "Schema is required for loading data!"
assert source_config.get("location") is not None, "Location is required for loading data!"

# Get the required variables from the source configuration
table_path = source_config.get("location")
table_name = source_config.identifier
schema_name = source_config.schema

# Get the optional variables from the source configuration
storage_format = source_config.get("format", self.default_format)
storage_options = source_config.get("storage_options", {})
as_of_version = source_config.get("as_of_version", None)
as_of_datetime = source_config.get("as_of_datetime", None)

if storage_format == StorageFormat.DELTA:
from .delta import delta_load

df = delta_load(
table_path=table_path,
storage_options=storage_options,
as_of_version=as_of_version,
as_of_datetime=as_of_datetime,
)
else:
raise NotImplementedError(f"Loading storage format {storage_format} not supported!")

converted_schema = pyarrow_schema_to_columns(schema=df.schema)

# Create he table in the Unitycatalog if it does not exist
create_table_if_not_exists(
uc_client=self.uc_client,
table_name=table_name,
schema_name=schema_name,
catalog_name=self.catalog_name,
storage_location=table_path,
schema=converted_schema,
storage_format=storage_format,
)

return df
raise NotImplementedError("Loading data to Unitycatalog is not supported!")

def store(self, target_config: TargetConfig, df: pa.lib.Table = None):
# Assert that the target_config has a location and relation identifier
Expand Down Expand Up @@ -286,7 +276,7 @@ def store(self, target_config: TargetConfig, df: pa.lib.Table = None):
# Convert the pa schema to columns
converted_schema = pyarrow_schema_to_columns(schema=df.schema)

# Create he table in the Unitycatalog if it does not exist
# Create the table in the Unitycatalog if it does not exist
create_table_if_not_exists(
uc_client=self.uc_client,
table_name=table_name,
Expand All @@ -297,6 +287,11 @@ def store(self, target_config: TargetConfig, df: pa.lib.Table = None):
storage_format=storage_format,
)

# extend the storage options with the temporary table credentials
storage_options = storage_options | uc_get_storage_credentials(
self.uc_client, self.catalog_name, schema_name, table_name
)

if storage_format == StorageFormat.DELTA:
from .delta import delta_write

Expand Down
56 changes: 1 addition & 55 deletions tests/functional/plugins/test_unity.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,10 @@
import tempfile
from pathlib import Path

import pandas as pd
import pytest
from dbt.tests.util import (
run_dbt,
)
from deltalake.writer import write_deltalake

unity_schema_yml = """
version: 2
sources:
- name: default
meta:
plugin: unity
tables:
- name: unity_source_table
description: "A UC table"
meta:
location: "{unity_source_table_location}"
format: DELTA
- name: test
meta:
plugin: unity
tables:
- name: unity_source_table_with_version
description: "A UC table that loads a specific version of the table"
meta:
location: "{unity_source_table_with_version_location}"
format: DELTA
as_of_version: 0
"""

ref1 = """
select 2 as a, 'test' as b
Expand Down Expand Up @@ -63,29 +36,6 @@ def unity_create_table_and_schema_sql(location: str) -> str:

@pytest.mark.skip_profile("buenavista", "file", "memory", "md")
class TestPlugins:
@pytest.fixture(scope="class")
def unity_source_table(self):
with tempfile.TemporaryDirectory() as tmpdir:
table_path = Path(tmpdir) / "unity_source_table"

df = pd.DataFrame({"x": [1, 2, 3]})
write_deltalake(table_path, df, mode="overwrite")

yield table_path

@pytest.fixture(scope="class")
def unity_source_table_with_version(self):
with tempfile.TemporaryDirectory() as tmpdir:
table_path = Path(tmpdir) / "unity_source_table_with_version"

df1 = pd.DataFrame({"x": [1], "y": ["a"]})
write_deltalake(table_path, df1, mode="overwrite")

df2 = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
write_deltalake(table_path, df2, mode="overwrite")

yield table_path

@pytest.fixture(scope="class")
def unity_create_table(self):
td = tempfile.TemporaryDirectory()
Expand Down Expand Up @@ -128,12 +78,8 @@ def profiles_config_update(self, dbt_profile_target):
}

@pytest.fixture(scope="class")
def models(self, unity_create_table, unity_create_table_and_schema, unity_source_table, unity_source_table_with_version):
def models(self, unity_create_table, unity_create_table_and_schema):
return {
"source_schema.yml": unity_schema_yml.format(
unity_source_table_location=unity_source_table,
unity_source_table_with_version_location=unity_source_table_with_version
),
"unity_create_table.sql": unity_create_table_sql(str(unity_create_table)),
"unity_create_table_and_schema.sql": unity_create_table_and_schema_sql(str(unity_create_table_and_schema)),
"ref1.sql": ref1
Expand Down

0 comments on commit f88ade5

Please sign in to comment.