feat/add support for databricks personal access token #304

Merged · 3 commits · Dec 18, 2024
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
@@ -109,6 +109,7 @@ jobs:
           KAFKA_API_KEY: ${{ secrets.KAFKA_API_KEY }}
           KAFKA_SECRET: ${{ secrets.KAFKA_SECRET }}
           KAFKA_BOOTSTRAP_SERVER: ${{ secrets.KAFKA_BOOTSTRAP_SERVER }}
+          DATABRICKS_PAT: ${{ secrets.DATABRICKS_PAT }}
         run : |
           source .venv/bin/activate
           make install-test
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,9 @@
+## 0.3.11-dev0
+
+### Enhancements
+
+* **Support Databricks personal access token**
+
 ## 0.3.10
 
 ### Enhancements
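Note: taken together, the changes below add a second authentication path for the Databricks Volumes connectors. A minimal sketch of what that looks like to a caller (the volumes_native import path and all credential values here are illustrative; the config classes and the token field are the ones touched in this diff):

from unstructured_ingest.v2.processes.connectors.databricks.volumes_native import (
    DatabricksNativeVolumesAccessConfig,
    DatabricksNativeVolumesConnectionConfig,
)

# Existing OAuth service-principal auth, unchanged by this PR:
oauth_config = DatabricksNativeVolumesConnectionConfig(
    host="https://example.cloud.databricks.com",  # illustrative host
    access_config=DatabricksNativeVolumesAccessConfig(
        client_id="example-client-id",
        client_secret="example-client-secret",
    ),
)

# New: a personal access token on its own is now sufficient.
pat_config = DatabricksNativeVolumesConnectionConfig(
    host="https://example.cloud.databricks.com",
    access_config=DatabricksNativeVolumesAccessConfig(
        token="dapi-example-token",  # illustrative PAT
    ),
)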
@@ -1,10 +1,10 @@
 import json
 import os
-import tempfile
 import uuid
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
+from unittest import mock
 
 import pytest
 from databricks.sdk import WorkspaceClient
@@ -31,11 +31,15 @@
 
 
 @dataclass
-class EnvData:
+class BaseEnvData:
     host: str
+    catalog: str
+
+
+@dataclass
+class BasicAuthEnvData(BaseEnvData):
     client_id: str
     client_secret: str
-    catalog: str
 
     def get_connection_config(self) -> DatabricksNativeVolumesConnectionConfig:
         return DatabricksNativeVolumesConnectionConfig(
@@ -47,32 +51,52 @@ def get_connection_config(self) -> DatabricksNativeVolumesConnectionConfig:
         )
 
 
-def get_env_data() -> EnvData:
-    return EnvData(
+@dataclass
+class PATEnvData(BaseEnvData):
+    token: str
+
+    def get_connection_config(self) -> DatabricksNativeVolumesConnectionConfig:
+        return DatabricksNativeVolumesConnectionConfig(
+            host=self.host,
+            access_config=DatabricksNativeVolumesAccessConfig(
+                token=self.token,
+            ),
+        )
+
+
+def get_basic_auth_env_data() -> BasicAuthEnvData:
+    return BasicAuthEnvData(
         host=os.environ["DATABRICKS_HOST"],
         client_id=os.environ["DATABRICKS_CLIENT_ID"],
         client_secret=os.environ["DATABRICKS_CLIENT_SECRET"],
         catalog=os.environ["DATABRICKS_CATALOG"],
     )
 
 
+def get_pat_env_data() -> PATEnvData:
+    return PATEnvData(
+        host=os.environ["DATABRICKS_HOST"],
+        catalog=os.environ["DATABRICKS_CATALOG"],
+        token=os.environ["DATABRICKS_PAT"],
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
 @requires_env(
     "DATABRICKS_HOST", "DATABRICKS_CLIENT_ID", "DATABRICKS_CLIENT_SECRET", "DATABRICKS_CATALOG"
 )
-async def test_volumes_native_source():
-    env_data = get_env_data()
-    indexer_config = DatabricksNativeVolumesIndexerConfig(
-        recursive=True,
-        volume="test-platform",
-        volume_path="databricks-volumes-test-input",
-        catalog=env_data.catalog,
-    )
-    connection_config = env_data.get_connection_config()
-    with tempfile.TemporaryDirectory() as tempdir:
-        tempdir_path = Path(tempdir)
-        download_config = DatabricksNativeVolumesDownloaderConfig(download_dir=tempdir_path)
+async def test_volumes_native_source(tmp_path: Path):
+    env_data = get_basic_auth_env_data()
+    with mock.patch.dict(os.environ, clear=True):
+        indexer_config = DatabricksNativeVolumesIndexerConfig(
+            recursive=True,
+            volume="test-platform",
+            volume_path="databricks-volumes-test-input",
+            catalog=env_data.catalog,
+        )
+        connection_config = env_data.get_connection_config()
+        download_config = DatabricksNativeVolumesDownloaderConfig(download_dir=tmp_path)
         indexer = DatabricksNativeVolumesIndexer(
             connection_config=connection_config, index_config=indexer_config
         )
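Note on the clear=True above: the Databricks SDK can authenticate from ambient environment variables (DATABRICKS_HOST, DATABRICKS_TOKEN, and friends), so wrapping the test body in mock.patch.dict(os.environ, clear=True) presumably ensures credentials reach the SDK only through the explicit connection config. A reduced sketch of the pattern:

import os
from unittest import mock

with mock.patch.dict(os.environ, clear=True):
    # os.environ is empty inside this block and restored on exit, so the
    # SDK cannot silently fall back to env-based auth during the test.
    assert "DATABRICKS_HOST" not in os.environ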
@@ -89,12 +113,44 @@ async def test_volumes_native_source():
         )
 
 
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("DATABRICKS_HOST", "DATABRICKS_PAT", "DATABRICKS_CATALOG")
+async def test_volumes_native_source_pat(tmp_path: Path):
+    env_data = get_pat_env_data()
+    with mock.patch.dict(os.environ, clear=True):
+        indexer_config = DatabricksNativeVolumesIndexerConfig(
+            recursive=True,
+            volume="test-platform",
+            volume_path="databricks-volumes-test-input",
+            catalog=env_data.catalog,
+        )
+        connection_config = env_data.get_connection_config()
+        download_config = DatabricksNativeVolumesDownloaderConfig(download_dir=tmp_path)
+        indexer = DatabricksNativeVolumesIndexer(
+            connection_config=connection_config, index_config=indexer_config
+        )
+        downloader = DatabricksNativeVolumesDownloader(
+            connection_config=connection_config, download_config=download_config
+        )
+        await source_connector_validation(
+            indexer=indexer,
+            downloader=downloader,
+            configs=SourceValidationConfigs(
+                test_id="databricks_volumes_native_pat",
+                expected_num_files=1,
+            ),
+        )
+
+
 def _get_volume_path(catalog: str, volume: str, volume_path: str):
     return f"/Volumes/{catalog}/default/{volume}/{volume_path}"
 
 
 @contextmanager
-def databricks_destination_context(env_data: EnvData, volume: str, volume_path) -> WorkspaceClient:
+def databricks_destination_context(
+    env_data: BasicAuthEnvData, volume: str, volume_path
+) -> WorkspaceClient:
     client = WorkspaceClient(
         host=env_data.host, client_id=env_data.client_id, client_secret=env_data.client_secret
     )
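The test_id above ties this test to the new expected-results fixtures added further down in this diff (a directory_structure.json listing fake-memo.pdf plus a per-file metadata JSON), which source_connector_validation checks the indexer and downloader output against.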
@@ -137,7 +193,7 @@ def validate_upload(client: WorkspaceClient, catalog: str, volume: str, volume_p
     "DATABRICKS_HOST", "DATABRICKS_CLIENT_ID", "DATABRICKS_CLIENT_SECRET", "DATABRICKS_CATALOG"
 )
 async def test_volumes_native_destination(upload_file: Path):
-    env_data = get_env_data()
+    env_data = get_basic_auth_env_data()
     volume_path = f"databricks-volumes-test-output-{uuid.uuid4()}"
     file_data = FileData(
         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
@@ -0,0 +1,5 @@
+{
+  "directory_structure": [
+    "fake-memo.pdf"
+  ]
+}
@@ -0,0 +1,26 @@
+{
+  "identifier": "9a6eb650-98d6-5465-8f1d-aa7118eee87e",
+  "connector_type": "databricks_volumes",
+  "source_identifiers": {
+    "filename": "fake-memo.pdf",
+    "fullpath": "/Volumes/utic-dev-tech-fixtures/default/test-platform/databricks-volumes-test-input/fake-memo.pdf",
+    "rel_path": "fake-memo.pdf"
+  },
+  "metadata": {
+    "url": "/Volumes/utic-dev-tech-fixtures/default/test-platform/databricks-volumes-test-input/fake-memo.pdf",
+    "version": null,
+    "record_locator": null,
+    "date_created": null,
+    "date_modified": "1729186569000",
+    "date_processed": null,
+    "permissions_data": null,
+    "filesize_bytes": null
+  },
+  "additional_metadata": {
+    "catalog": "utic-dev-tech-fixtures",
+    "path": "/Volumes/utic-dev-tech-fixtures/default/test-platform/databricks-volumes-test-input/fake-memo.pdf"
+  },
+  "reprocess": false,
+  "local_download_path": "/private/var/folders/n8/rps3wl195pj4p_0vyxqj5jrw0000gn/T/pytest-of-romanisecke/pytest-9/test_volumes_native_source_pat0/fake-memo.pdf",
+  "display_name": null
+}
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
__version__ = "0.3.10" # pragma: no cover
__version__ = "0.3.11-dev0" # pragma: no cover
@@ -14,6 +14,7 @@
 )
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
@@ -52,6 +53,10 @@ def path(self) -> str:
         return path
 
 
+class DatabricksVolumesAccessConfig(AccessConfig):
+    token: Optional[str] = Field(default=None, description="Databricks Personal Access Token")
+
+
 class DatabricksVolumesConnectionConfig(ConnectionConfig, ABC):
     host: Optional[str] = Field(
         default=None,
@@ -3,12 +3,12 @@
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@
 CONNECTOR_TYPE = "databricks_volumes_aws"
 
 
-class DatabricksAWSVolumesAccessConfig(AccessConfig):
+class DatabricksAWSVolumesAccessConfig(DatabricksVolumesAccessConfig):
     account_id: Optional[str] = Field(
         default=None,
         description="The Databricks account ID for the Databricks " "accounts endpoint",
@@ -3,12 +3,12 @@
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@
 CONNECTOR_TYPE = "databricks_volumes_azure"
 
 
-class DatabricksAzureVolumesAccessConfig(AccessConfig):
+class DatabricksAzureVolumesAccessConfig(DatabricksVolumesAccessConfig):
     account_id: Optional[str] = Field(
         default=None,
         description="The Databricks account ID for the Databricks " "accounts endpoint.",
@@ -3,12 +3,12 @@
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@
 CONNECTOR_TYPE = "databricks_volumes_gcp"
 
 
-class DatabricksGoogleVolumesAccessConfig(AccessConfig):
+class DatabricksGoogleVolumesAccessConfig(DatabricksVolumesAccessConfig):
     account_id: Optional[str] = Field(
         default=None,
         description="The Databricks account ID for the Databricks " "accounts endpoint.",
@@ -3,12 +3,12 @@
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import AccessConfig
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.databricks.volumes import (
+    DatabricksVolumesAccessConfig,
     DatabricksVolumesConnectionConfig,
     DatabricksVolumesDownloader,
     DatabricksVolumesDownloaderConfig,
@@ -21,7 +21,7 @@
 CONNECTOR_TYPE = "databricks_volumes"
 
 
-class DatabricksNativeVolumesAccessConfig(AccessConfig):
+class DatabricksNativeVolumesAccessConfig(DatabricksVolumesAccessConfig):
     client_id: Optional[str] = Field(default=None, description="Client ID of the OAuth app.")
     client_secret: Optional[str] = Field(
         default=None, description="Client Secret of the OAuth app."
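The pattern across the four connector modules above is identical: each cloud-specific access config previously extended AccessConfig directly and now extends the shared DatabricksVolumesAccessConfig, so the optional token field is inherited by every variant with no per-connector plumbing. A minimal sketch of the effect (the token value is illustrative):

# The AWS variant did not itself gain a token field, but inherits it:
access_config = DatabricksAWSVolumesAccessConfig(token="dapi-example-token")
assert access_config.token == "dapi-example-token"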