From 5158b3e401a0cf1d6bc0f16e60c838978b53063d Mon Sep 17 00:00:00 2001
From: Minura Punchihewa <49385643+MinuraPunchihewa@users.noreply.github.com>
Date: Mon, 21 Oct 2024 20:39:35 +0530
Subject: [PATCH] feat(datasets): Added the Experimental ExternalTableDataset
 for Databricks (#885)

* added the experimental ExternalTableDataset

* fixed lint issues

* added the missing location attr to the docstring

* removed unused code from the tests for ManagedTableDataset

* added tests for ExternalTableDataset

* moved all fixtures to conftest.py

* updated the format for the test for save_overwrite to Parquet

* moved tests to the kedro_datasets_experimental pkg

* updated the release doc

* added the dataset to the documentation

* listed the dependencies for the dataset

---------

Signed-off-by: Minura Punchihewa
---
 kedro-datasets/RELEASE.md                     |  18 ++
 .../api/kedro_datasets_experimental.rst       |   1 +
 .../databricks/__init__.py                    |  12 ++
 .../databricks/external_table_dataset.py      | 179 ++++++++++++++++
 .../tests/databricks/__init__.py              |   0
 .../tests/databricks/conftest.py              | 200 ++++++++++++++++++
 .../databricks/test_external_table_dataset.py |  45 ++++
 kedro-datasets/pyproject.toml                 |   1 +
 kedro-datasets/tests/databricks/conftest.py   | 170 +++++++++++++++
 .../databricks/test_base_table_dataset.py     | 171 +--------------
 .../databricks/test_managed_table_dataset.py  | 167 ---------------
 11 files changed, 627 insertions(+), 337 deletions(-)
 create mode 100644 kedro-datasets/kedro_datasets_experimental/databricks/__init__.py
 create mode 100644 kedro-datasets/kedro_datasets_experimental/databricks/external_table_dataset.py
 create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/databricks/__init__.py
 create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/databricks/conftest.py
 create mode 100644 kedro-datasets/kedro_datasets_experimental/tests/databricks/test_external_table_dataset.py

diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md
index 4d63d8a25..fe94686f6 100755
--- a/kedro-datasets/RELEASE.md
+++ b/kedro-datasets/RELEASE.md
@@ -1,5 +1,23 @@
 # Upcoming Release
 
+## Major features and improvements
+
+- Added the following new **experimental** datasets:
+
+| Type                              | Description                                             | Location                                  |
+| --------------------------------- | ------------------------------------------------------- | ----------------------------------------- |
+| `databricks.ExternalTableDataset` | A dataset for accessing external tables in Databricks.  | `kedro_datasets_experimental.databricks`  |
+
+## Bug fixes and other changes
+
+## Breaking Changes
+
+## Community contributions
+
+Many thanks to the following Kedroids for contributing PRs to this release:
+
+- [Minura Punchihewa](https://github.com/MinuraPunchihewa)
+
 # Release 5.1.0
 
 ## Major features and improvements
diff --git a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
index 219510954..a29e8449c 100644
--- a/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
+++ b/kedro-datasets/docs/source/api/kedro_datasets_experimental.rst
@@ -11,6 +11,7 @@ kedro_datasets_experimental
    :toctree:
    :template: autosummary/class.rst
 
+   databricks.ExternalTableDataset
    langchain.ChatAnthropicDataset
    langchain.ChatCohereDataset
    langchain.ChatOpenAIDataset
diff --git a/kedro-datasets/kedro_datasets_experimental/databricks/__init__.py b/kedro-datasets/kedro_datasets_experimental/databricks/__init__.py
new file mode 100644
index 000000000..ec5a41a6f
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/databricks/__init__.py
@@ -0,0 +1,12 @@
+"""Provides an interface to Unity Catalog External Tables."""
+
+from typing import Any
+
+import lazy_loader as lazy
+
+# https://github.com/pylint-dev/pylint/issues/4300#issuecomment-1043601901
+ExternalTableDataset: Any
+
+__getattr__, __dir__, __all__ = lazy.attach(
+    __name__, submod_attrs={"external_table_dataset": ["ExternalTableDataset"]}
+)
diff --git a/kedro-datasets/kedro_datasets_experimental/databricks/external_table_dataset.py b/kedro-datasets/kedro_datasets_experimental/databricks/external_table_dataset.py
new file mode 100644
index 000000000..a4f4a351b
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/databricks/external_table_dataset.py
@@ -0,0 +1,179 @@
+"""``ExternalTableDataset`` implementation to access external tables
+in Databricks.
+"""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from typing import Any
+
+import pandas as pd
+from kedro.io.core import DatasetError
+from pyspark.sql import DataFrame
+
+from kedro_datasets.databricks._base_table_dataset import BaseTable, BaseTableDataset
+
+logger = logging.getLogger(__name__)
+pd.DataFrame.iteritems = pd.DataFrame.items
+
+
+@dataclass(frozen=True)
+class ExternalTable(BaseTable):
+    """Stores the definition of an external table."""
+
+    def _validate_location(self) -> None:
+        """Validates that a location is provided if the table does not exist.
+
+        Raises:
+            DatasetError: If the table does not exist and no location is provided.
+        """
+        if not self.exists() and not self.location:
+            raise DatasetError(
+                "If the external table does not exist, the `location` parameter must be provided. "
+                "This should be a valid path in an external location that has already been created."
+            )
+
+    def _validate_write_mode(self) -> None:
+        """Validates that the write mode is compatible with the format.
+
+        Raises:
+            DatasetError: If the write mode is not compatible with the format.
+        """
+        super()._validate_write_mode()
+
+        if self.write_mode == "upsert" and self.format != "delta":
+            raise DatasetError(
+                f"Format '{self.format}' is not supported for upserts. "
+                f"Please use the 'delta' format."
+            )
+
+        if self.write_mode == "overwrite" and self.format != "delta" and not self.location:
+            raise DatasetError(
+                f"Format '{self.format}' is only supported for overwrites if a location is provided. "
+                f"Please provide a valid path in an external location."
+            )
+ ) + + +class ExternalTableDataset(BaseTableDataset): + """``ExternalTableDataset`` loads and saves data into external tables in Databricks. + Load and save can be in Spark or Pandas dataframes, specified in dataframe_type. + + Example usage for the + `YAML API `_: + + .. code-block:: yaml + + names_and_ages@spark: + type: databricks.ExternalTableDataset + format: parquet + table: names_and_ages + + names_and_ages@pandas: + type: databricks.ExternalTableDataset + format: parquet + table: names_and_ages + dataframe_type: pandas + + Example usage for the + `Python API `_: + + .. code-block:: pycon + + >>> from kedro_datasets.databricks import ExternalTableDataset + >>> from pyspark.sql import SparkSession + >>> from pyspark.sql.types import IntegerType, Row, StringType, StructField, StructType + >>> import importlib_metadata + >>> + >>> DELTA_VERSION = importlib_metadata.version("delta-spark") + >>> schema = StructType( + ... [StructField("name", StringType(), True), StructField("age", IntegerType(), True)] + ... ) + >>> data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + >>> spark_df = ( + ... SparkSession.builder.config( + ... "spark.jars.packages", f"io.delta:delta-core_2.12:{DELTA_VERSION}" + ... ) + ... .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + ... .config( + ... "spark.sql.catalog.spark_catalog", + ... "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ... ) + ... .getOrCreate() + ... .createDataFrame(data, schema) + ... ) + >>> dataset = ExternalTableDataset( + ... table="names_and_ages", + ... write_mode="overwrite", + ... location="abfss://container@storageaccount.dfs.core.windows.net/depts/cust" + ... ) + >>> dataset.save(spark_df) + >>> reloaded = dataset.load() + >>> assert Row(name="Bob", age=12) in reloaded.take(4) + """ + + def _create_table( # noqa: PLR0913 + self, + table: str, + catalog: str | None, + database: str, + format: str, + write_mode: str | None, + location: str | None, + dataframe_type: str, + primary_key: str | list[str] | None, + json_schema: dict[str, Any] | None, + partition_columns: list[str] | None, + owner_group: str | None + ) -> ExternalTable: + """Creates a new ``ExternalTable`` instance with the provided attributes. + Args: + table: The name of the table. + catalog: The catalog of the table. + database: The database of the table. + format: The format of the table. + write_mode: The write mode for the table. + location: The location of the table. + dataframe_type: The type of dataframe. + primary_key: The primary key of the table. + json_schema: The JSON schema of the table. + partition_columns: The partition columns of the table. + owner_group: The owner group of the table. + Returns: + ``ExternalTable``: The new ``ExternalTable`` instance. + """ + return ExternalTable( + table=table, + catalog=catalog, + database=database, + write_mode=write_mode, + location=location, + dataframe_type=dataframe_type, + json_schema=json_schema, + partition_columns=partition_columns, + owner_group=owner_group, + primary_key=primary_key, + format=format + ) + + def _save_overwrite(self, data: DataFrame) -> None: + """Overwrites the data in the table with the data provided. + Args: + data (DataFrame): The Spark dataframe to overwrite the table with. 
+ """ + writer = data.write.format(self._table.format).mode("overwrite").option( + "overwriteSchema", "true" + ) + + if self._table.partition_columns: + writer.partitionBy( + *self._table.partition_columns if isinstance(self._table.partition_columns, list) else self._table.partition_columns + ) + + if self._table.format == "delta" or (not self._table.exists()): + if self._table.location: + writer.option("path", self._table.location) + + writer.saveAsTable(self._table.full_table_location() or "") + + else: + writer.save(self._table.location) diff --git a/kedro-datasets/kedro_datasets_experimental/tests/databricks/__init__.py b/kedro-datasets/kedro_datasets_experimental/tests/databricks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kedro-datasets/kedro_datasets_experimental/tests/databricks/conftest.py b/kedro-datasets/kedro_datasets_experimental/tests/databricks/conftest.py new file mode 100644 index 000000000..6984faabb --- /dev/null +++ b/kedro-datasets/kedro_datasets_experimental/tests/databricks/conftest.py @@ -0,0 +1,200 @@ +""" +This file contains the fixtures that are reusable by any tests within +this directory. You don't need to import the fixtures as pytest will +discover them automatically. More info here: +https://docs.pytest.org/en/latest/fixture.html +""" +import os + +# importlib_metadata needs backport for python 3.8 and older +import importlib_metadata +import pandas as pd +import pytest +from pyspark.sql import SparkSession +from pyspark.sql.types import IntegerType, StringType, StructField, StructType + +DELTA_VERSION = importlib_metadata.version("delta-spark") + + +@pytest.fixture(scope="class", autouse=True) +def spark_session(): + spark = ( + SparkSession.builder.appName("test") + .config("spark.jars.packages", f"io.delta:delta-core_2.12:{DELTA_VERSION}") + .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") + .config( + "spark.sql.catalog.spark_catalog", + "org.apache.spark.sql.delta.catalog.DeltaCatalog", + ) + .getOrCreate() + ) + spark.sql("create database if not exists test") + yield spark + spark.sql("drop database test cascade;") + + +@pytest.fixture +def sample_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + ] + ) + + data = [("Alex", 32), ("Evan", 23)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def mismatched_upsert_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_spark_df(spark_session: SparkSession): + schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", IntegerType(), True), + StructField("height", IntegerType(), True), + ] + ) + + data = [("Alex", 32, 174), ("Evan", 23, 166)] + + return spark_session.createDataFrame(data, schema) + + +@pytest.fixture +def subset_pandas_df(): + return pd.DataFrame( + {"name": ["Alex", "Evan"], "age": 
+
+
+@pytest.fixture
+def subset_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
+    )
+
+
+@pytest.fixture
+def subset_expected_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def sample_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
+    )
+
+
+@pytest.fixture
+def append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Evan", 23), ("Frank", 13)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+        ("Frank", 13),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def external_location():
+    return os.environ.get("DATABRICKS_EXTERNAL_LOCATION")
diff --git a/kedro-datasets/kedro_datasets_experimental/tests/databricks/test_external_table_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/databricks/test_external_table_dataset.py
new file mode 100644
index 000000000..0e4ef3446
--- /dev/null
+++ b/kedro-datasets/kedro_datasets_experimental/tests/databricks/test_external_table_dataset.py
@@ -0,0 +1,45 @@
+import pytest
+from kedro.io.core import DatasetError
+from pyspark.sql import DataFrame
+
+from kedro_datasets_experimental.databricks.external_table_dataset import (
+    ExternalTableDataset,
+)
+
+
+class TestExternalTableDataset:
+    def test_location_for_non_existing_table(self):
+        with pytest.raises(DatasetError):
+            ExternalTableDataset(table="test")
+
+    def test_invalid_upsert_write_mode(self):
+        with pytest.raises(DatasetError):
+            ExternalTableDataset(table="test", write_mode="upsert", format="parquet")
+
+    def test_invalid_overwrite_write_mode(self):
+        with pytest.raises(DatasetError):
+            ExternalTableDataset(table="test", write_mode="overwrite", format="parquet")
+
+    def test_save_overwrite_without_location(self):
+        with pytest.raises(DatasetError):
+            ExternalTableDataset(table="test", write_mode="overwrite", format="delta")
+
+    def test_save_overwrite(
+        self,
+        sample_spark_df: DataFrame,
+        append_spark_df: DataFrame,
+        external_location: str,
+    ):
+        unity_ds = ExternalTableDataset(
+            database="test",
+            table="test_save",
+            format="parquet",
+            write_mode="overwrite",
+            location=f"{external_location}/test_save_overwrite_external",
+        )
+        unity_ds.save(sample_spark_df)
+        unity_ds.save(append_spark_df)
+
+        overwritten_table = unity_ds.load()
+
+        assert append_spark_df.exceptAll(overwritten_table).count() == 0
diff --git a/kedro-datasets/pyproject.toml b/kedro-datasets/pyproject.toml
index 706d2c80d..c33584112 100644
--- a/kedro-datasets/pyproject.toml
+++ b/kedro-datasets/pyproject.toml
@@ -174,6 +174,7 @@ yaml-yamldataset = ["kedro-datasets[pandas-base]", "PyYAML>=4.2, <7.0"]
 yaml = ["kedro-datasets[yaml-yamldataset]"]
 
 # Experimental Datasets
+databricks-externaltabledataset = ["kedro-datasets[spark-base,pandas-base,delta-base,hdfs-base,s3fs-base]"]
 langchain-chatopenaidataset = ["langchain-openai~=0.1.7"]
 langchain-openaiembeddingsdataset = ["langchain-openai~=0.1.7"]
 langchain-chatanthropicdataset = ["langchain-anthropic~=0.1.13", "langchain-community~=0.2.0"]
diff --git a/kedro-datasets/tests/databricks/conftest.py b/kedro-datasets/tests/databricks/conftest.py
index ccc0c78ad..6984faabb 100644
--- a/kedro-datasets/tests/databricks/conftest.py
+++ b/kedro-datasets/tests/databricks/conftest.py
@@ -4,11 +4,14 @@
 discover them automatically. More info here:
 https://docs.pytest.org/en/latest/fixture.html
 """
+import os
 
 # importlib_metadata needs backport for python 3.8 and older
 import importlib_metadata
+import pandas as pd
 import pytest
 from pyspark.sql import SparkSession
+from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 
 DELTA_VERSION = importlib_metadata.version("delta-spark")
 
@@ -28,3 +31,170 @@ def spark_session():
     spark.sql("create database if not exists test")
     yield spark
     spark.sql("drop database test cascade;")
+
+
+@pytest.fixture
+def sample_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def mismatched_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+            StructField("height", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32, 174), ("Evan", 23, 166)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def subset_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
+    )
+
+
+@pytest.fixture
+def subset_expected_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Alex", 32), ("Evan", 23)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def sample_pandas_df():
+    return pd.DataFrame(
+        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
+    )
+
+
+@pytest.fixture
+def append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [("Evan", 23), ("Frank", 13)]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_append_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+        ("Frank", 13),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
+    schema = StructType(
+        [
+            StructField("name", StringType(), True),
+            StructField("age", IntegerType(), True),
+        ]
+    )
+
+    data = [
+        ("Alex", 31),
+        ("Alex", 32),
+        ("Bob", 12),
+        ("Clarke", 65),
+        ("Dave", 29),
+        ("Evan", 23),
+    ]
+
+    return spark_session.createDataFrame(data, schema)
+
+
+@pytest.fixture
+def external_location():
+    return os.environ.get("DATABRICKS_EXTERNAL_LOCATION")
diff --git a/kedro-datasets/tests/databricks/test_base_table_dataset.py b/kedro-datasets/tests/databricks/test_base_table_dataset.py
index 5cc88e8df..49f2283a1 100644
--- a/kedro-datasets/tests/databricks/test_base_table_dataset.py
+++ b/kedro-datasets/tests/databricks/test_base_table_dataset.py
@@ -1,181 +1,12 @@
-import os
-
 import pandas as pd
 import pytest
 from kedro.io.core import DatasetError, Version, VersionNotFoundError
-from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql import DataFrame
 from pyspark.sql.types import IntegerType, StringType, StructField, StructType
 
 from kedro_datasets.databricks._base_table_dataset import BaseTableDataset
 
 
-@pytest.fixture
-def sample_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def mismatched_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
-    )
-
-
-@pytest.fixture
-def subset_expected_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def sample_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
-    )
-
-
-@pytest.fixture
-def append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Evan", 23), ("Frank", 13)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-        ("Frank", 13),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def external_location():
-    return os.environ.get("DATABRICKS_EXTERNAL_LOCATION")
-
-
 class TestBaseTableDataset:
     def test_full_table(self):
         unity_ds = BaseTableDataset(catalog="test", database="test", table="test")
diff --git a/kedro-datasets/tests/databricks/test_managed_table_dataset.py b/kedro-datasets/tests/databricks/test_managed_table_dataset.py
index c3cc623f4..6c7acb97b 100644
--- a/kedro-datasets/tests/databricks/test_managed_table_dataset.py
+++ b/kedro-datasets/tests/databricks/test_managed_table_dataset.py
@@ -1,173 +1,6 @@
-import pandas as pd
-import pytest
-from pyspark.sql import SparkSession
-from pyspark.sql.types import IntegerType, StringType, StructField, StructType
-
 from kedro_datasets.databricks import ManagedTableDataset
 
 
-@pytest.fixture
-def sample_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def mismatched_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-            StructField("height", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32, 174), ("Evan", 23, 166)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def subset_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Evan"], "age": [32, 23], "height": [174, 166]}
-    )
-
-
-@pytest.fixture
-def subset_expected_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Alex", 32), ("Evan", 23)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def sample_pandas_df():
-    return pd.DataFrame(
-        {"name": ["Alex", "Bob", "Clarke", "Dave"], "age": [31, 12, 65, 29]}
-    )
-
-
-@pytest.fixture
-def append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [("Evan", 23), ("Frank", 13)]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_append_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-        ("Frank", 13),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
-@pytest.fixture
-def expected_upsert_multiple_primary_spark_df(spark_session: SparkSession):
-    schema = StructType(
-        [
-            StructField("name", StringType(), True),
-            StructField("age", IntegerType(), True),
-        ]
-    )
-
-    data = [
-        ("Alex", 31),
-        ("Alex", 32),
-        ("Bob", 12),
-        ("Clarke", 65),
-        ("Dave", 29),
-        ("Evan", 23),
-    ]
-
-    return spark_session.createDataFrame(data, schema)
-
-
 class TestManagedTableDataset:
     def test_describe(self):
        unity_ds = ManagedTableDataset(table="test")
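
Below is a minimal sketch (not part of the patch) of how the validation rules added in `ExternalTable` surface when the dataset is constructed, mirroring the new tests above. The import path follows the tests in this PR; the `abfss://` location is a placeholder for a pre-created external location, and a live Spark/Databricks session is assumed for the metastore existence check.

```python
from kedro.io.core import DatasetError

from kedro_datasets_experimental.databricks import ExternalTableDataset

# The table does not exist yet and no `location` is given -> DatasetError.
try:
    ExternalTableDataset(table="new_table")
except DatasetError as err:
    print(err)

# Upserts are only supported for the "delta" format -> DatasetError.
try:
    ExternalTableDataset(table="new_table", write_mode="upsert", format="parquet")
except DatasetError as err:
    print(err)

# A non-delta overwrite is accepted once an external location is provided.
dataset = ExternalTableDataset(
    table="new_table",
    format="parquet",
    write_mode="overwrite",
    location="abfss://container@account.dfs.core.windows.net/path/new_table",  # placeholder
)
```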