Python-CDK: Add CDK sql module for new MotherDuck destination (#47260)
Co-authored-by: Guen Prawiroatmodjo <guen@motherduck.com>
aaronsteers and guenp authored Oct 23, 2024
1 parent 155593d commit 8d6a7aa
Showing 16 changed files with 2,651 additions and 717 deletions.
293 changes: 293 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sql/_processors/duckdb.py
@@ -0,0 +1,293 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""A DuckDB implementation of the cache."""

from __future__ import annotations

import logging
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal

import pyarrow as pa
from airbyte_cdk import DestinationSyncMode
from airbyte_cdk.sql import exceptions as exc
from airbyte_cdk.sql.constants import AB_EXTRACTED_AT_COLUMN
from airbyte_cdk.sql.secrets import SecretString
from airbyte_cdk.sql.shared.sql_processor import SqlConfig, SqlProcessorBase, SQLRuntimeError
from duckdb_engine import DuckDBEngineWarning
from overrides import overrides
from pydantic import Field
from sqlalchemy import Executable, TextClause, text
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError

if TYPE_CHECKING:
from sqlalchemy.engine import Connection, Engine

logger = logging.getLogger(__name__)


class DuckDBConfig(SqlConfig):
"""Configuration for DuckDB."""

db_path: Path | str = Field()
"""Normally db_path is a Path object.
The database name will be inferred from the file name. For example, given a `db_path` of
    `/path/to/my_db.duckdb`, the database name is `my_db`.
"""

schema_name: str = Field(default="main")
"""The name of the schema to write to. Defaults to "main"."""

@overrides
def get_sql_alchemy_url(self) -> SecretString:
"""Return the SQLAlchemy URL to use."""
# Suppress warnings from DuckDB about reflection on indices.
# https://github.com/Mause/duckdb_engine/issues/905
warnings.filterwarnings(
"ignore",
message="duckdb-engine doesn't yet support reflection on indices",
category=DuckDBEngineWarning,
)
return SecretString(f"duckdb:///{self.db_path!s}")

@overrides
def get_database_name(self) -> str:
"""Return the name of the database."""
if self.db_path == ":memory:":
return "memory"

# Split the path on the appropriate separator ("/" or "\")
split_on: Literal["/", "\\"] = "\\" if "\\" in str(self.db_path) else "/"

# Return the file name without the extension
return str(self.db_path).split(sep=split_on)[-1].split(".")[0]
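
    # Illustrative examples (not part of the shipped module):
    #   "/path/to/my_db.duckdb"  -> splits on "/"  -> "my_db"
    #   "C:\data\my_db.duckdb"   -> splits on "\\" -> "my_db"
    #   ":memory:"               -> special-cased  -> "memory"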

def _is_file_based_db(self) -> bool:
"""Return whether the database is file-based."""
if isinstance(self.db_path, Path):
return True

db_path_str = str(self.db_path)
return (
("/" in db_path_str or "\\" in db_path_str)
and db_path_str != ":memory:"
and "md:" not in db_path_str
and "motherduck:" not in db_path_str
)
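
    # Illustrative classifications:
    #   Path("db.duckdb")   -> True  (any Path instance is treated as file-based)
    #   "/tmp/db.duckdb"    -> True
    #   ":memory:"          -> False (in-memory database)
    #   "md:my_db"          -> False (MotherDuck URI, not a local file)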

@overrides
def get_sql_engine(self) -> Engine:
"""Return the SQL Alchemy engine.
This method is overridden to ensure that the database parent directory is created if it
doesn't exist.
"""
if self._is_file_based_db():
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

return super().get_sql_engine()
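

# A minimal usage sketch (illustrative only; the path and names below are
# made up for demonstration):
#
#   config = DuckDBConfig(db_path="/tmp/example/my_db.duckdb")
#   assert config.get_database_name() == "my_db"
#   engine = config.get_sql_engine()  # also creates /tmp/example/ if missing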


class DuckDBSqlProcessor(SqlProcessorBase):
"""A DuckDB implementation of the cache.
    JSONL is used for local file storage before bulk loading.
    Unlike the Snowflake implementation, we can't use the COPY command to load data,
    so we insert values directly instead.
"""

supports_merge_insert = False
sql_config: DuckDBConfig

@overrides
def _setup(self) -> None:
"""Create the database parent folder if it doesn't yet exist."""
if self.sql_config.db_path == ":memory:":
return

Path(self.sql_config.db_path).parent.mkdir(parents=True, exist_ok=True)

def _create_table_if_not_exists(
self,
table_name: str,
column_definition_str: str,
primary_keys: list[str] | None = None,
) -> None:
if primary_keys:
pk_str = ", ".join(primary_keys)
column_definition_str += f",\n PRIMARY KEY ({pk_str})"

cmd = f"""
CREATE TABLE IF NOT EXISTS {self._fully_qualified(table_name)} (
{column_definition_str}
)
"""
_ = self._execute_sql(cmd)
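
    # For example (illustrative; exact quoting depends on _fully_qualified),
    # table_name="users", column_definition_str="id VARCHAR, name VARCHAR",
    # and primary_keys=["id"] produce roughly:
    #   CREATE TABLE IF NOT EXISTS main.users (
    #       id VARCHAR, name VARCHAR,
    #       PRIMARY KEY (id)
    #   )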

def _do_checkpoint(
self,
connection: Connection | None = None,
) -> None:
"""Checkpoint the given connection.
We override this method to ensure that the DuckDB WAL is checkpointed explicitly.
Otherwise DuckDB will lazily flush the WAL to disk, which can cause issues for users
who want to manipulate the DB files after writing them.
For more info:
- https://duckdb.org/docs/sql/statements/checkpoint.html
"""
if connection is not None:
connection.execute(text("CHECKPOINT"))
return

with self.get_sql_connection() as new_conn:
new_conn.execute(text("CHECKPOINT"))
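
    # Standalone equivalent (a sketch using the `duckdb` package directly):
    #   import duckdb
    #   duckdb.connect("my_db.duckdb").execute("CHECKPOINT")
    # This merges the write-ahead log into the main database file immediately.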

def _executemany(self, sql: str | TextClause | Executable, params: list[list[Any]]) -> None:
"""Execute the given SQL statement."""
if isinstance(sql, str):
sql = text(sql)

with self.get_sql_connection() as conn:
try:
entries = list(params)
conn.engine.pool.connect().executemany(str(sql), entries) # type: ignore
except (
ProgrammingError,
SQLAlchemyError,
) as ex:
msg = f"Error when executing SQL:\n{sql}\n{type(ex).__name__}{ex!s}"
raise SQLRuntimeError(msg) from None # from ex

def _write_with_executemany(self, buffer: Dict[str, Dict[str, List[Any]]], stream_name: str, table_name: str) -> None:
column_names_list = list(buffer[stream_name].keys())
column_names = ", ".join(column_names_list)
params = ", ".join(["?"] * len(column_names_list))
sql = f"""
-- Write with executemany
INSERT INTO {self._fully_qualified(table_name)}
({column_names})
VALUES ({params})
"""
entries_to_write = buffer[stream_name]
num_entries = len(entries_to_write[column_names_list[0]])
parameters = [[entries_to_write[column_name][n] for column_name in column_names_list] for n in range(num_entries)]
self._executemany(sql, parameters)
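
    # Illustration: buffer == {"s": {"id": [1, 2], "name": ["a", "b"]}}
    # yields column_names "id, name", params "?, ?", and
    # parameters [[1, "a"], [2, "b"]] -- i.e. the columnar buffer is pivoted
    # into one parameter row per record before calling _executemany.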

def _write_from_pa_table(self, table_name: str, pa_table: pa.Table) -> None:
full_table_name = self._fully_qualified(table_name)
sql = f"""
-- Write from PyArrow table
INSERT INTO {full_table_name} SELECT * FROM pa_table
"""
self._execute_sql(sql)
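
    # This relies on DuckDB's replacement scans: an unresolved table name in
    # the query is looked up among in-scope Python variables, so `pa_table`
    # in the SQL text resolves to the PyArrow table passed in above.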

def _write_temp_table_to_target_table(
self,
stream_name: str,
temp_table_name: str,
final_table_name: str,
sync_mode: DestinationSyncMode,
) -> None:
"""Write the temp table into the final table using the provided write strategy."""
if sync_mode == DestinationSyncMode.overwrite:
# Note: No need to check for schema compatibility
# here, because we are fully replacing the table.
self._swap_temp_table_with_final_table(
stream_name=stream_name,
temp_table_name=temp_table_name,
final_table_name=final_table_name,
)
return

if sync_mode == DestinationSyncMode.append:
self._ensure_compatible_table_schema(
stream_name=stream_name,
table_name=final_table_name,
)
self._append_temp_table_to_final_table(
stream_name=stream_name,
temp_table_name=temp_table_name,
final_table_name=final_table_name,
)
return

if sync_mode == DestinationSyncMode.append_dedup:
self._ensure_compatible_table_schema(
stream_name=stream_name,
table_name=final_table_name,
)
if not self.supports_merge_insert:
# Fallback to emulated merge if the database does not support merge natively.
self._emulated_merge_temp_table_to_final_table(
stream_name=stream_name,
temp_table_name=temp_table_name,
final_table_name=final_table_name,
)
return

self._merge_temp_table_to_final_table(
stream_name=stream_name,
temp_table_name=temp_table_name,
final_table_name=final_table_name,
)
return

raise exc.AirbyteInternalError(
message="Sync mode is not supported.",
context={
"sync_mode": sync_mode,
},
)

def _drop_duplicates(self, table_name: str, stream_name: str) -> str:
primary_keys = self.catalog_provider.get_primary_keys(stream_name)
new_table_name = f"{table_name}_deduped"
if primary_keys:
pks = ", ".join(primary_keys)
sql = f"""
-- Drop duplicates from temp table
CREATE TABLE {self._fully_qualified(new_table_name)} AS (
SELECT * FROM {self._fully_qualified(table_name)}
QUALIFY row_number() OVER (PARTITION BY ({pks}) ORDER BY {AB_EXTRACTED_AT_COLUMN} DESC) = 1
)
"""
self._execute_sql(sql)
return new_table_name
return table_name
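
    # Example (illustrative): with primary_keys=["id"], two staged rows
    # sharing id=1 collapse to the one with the greatest
    # AB_EXTRACTED_AT_COLUMN timestamp; QUALIFY filters on the window
    # function result without needing a wrapping subquery.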

def write_stream_data_from_buffer(
self,
buffer: Dict[str, Dict[str, List[Any]]],
stream_name: str,
sync_mode: DestinationSyncMode,
) -> None:
table_name = f"_airbyte_raw_{stream_name}"
temp_table_name = self._create_table_for_loading(stream_name, batch_id=None)
try:
pa_table = pa.Table.from_pydict(buffer[stream_name])
except Exception:
logger.exception(
"Writing with PyArrow table failed, falling back to writing with executemany. Expect some performance degradation."
)
self._write_with_executemany(buffer, stream_name, temp_table_name)
else:
# DuckDB will automatically find and SELECT from the `pa_table`
# local variable defined above.
self._write_from_pa_table(temp_table_name, pa_table)

temp_table_name_dedup = self._drop_duplicates(temp_table_name, stream_name)

try:
self._write_temp_table_to_target_table(
stream_name=stream_name,
temp_table_name=temp_table_name_dedup,
final_table_name=table_name,
sync_mode=sync_mode,
)
finally:
self._drop_temp_table(temp_table_name_dedup, if_exists=True)
self._drop_temp_table(temp_table_name, if_exists=True)
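

# Flow summary: write_stream_data_from_buffer stages records in a temp table
# (PyArrow fast path, executemany fallback), dedupes on primary keys, then
# swaps, appends, or merges into `_airbyte_raw_<stream>` according to the
# sync mode, dropping the temp tables in the `finally` block either way.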
81 changes: 81 additions & 0 deletions airbyte-cdk/python/airbyte_cdk/sql/_processors/motherduck.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
"""A MotherDuck implementation of the cache, built on DuckDB.
## Usage Example
```python
from airbyte as ab
from airbyte.caches import MotherDuckCache
cache = MotherDuckCache(
database="mydatabase",
schema_name="myschema",
api_key=ab.get_secret("MOTHERDUCK_API_KEY"),
)
"""

from __future__ import annotations

import warnings

from airbyte_cdk.sql._processors.duckdb import DuckDBConfig, DuckDBSqlProcessor
from airbyte_cdk.sql.secrets import SecretString
from duckdb_engine import DuckDBEngineWarning
from overrides import overrides
from pydantic import Field

# Suppress warnings from DuckDB about reflection on indices.
# https://github.com/Mause/duckdb_engine/issues/905
warnings.filterwarnings(
"ignore",
message="duckdb-engine doesn't yet support reflection on indices",
category=DuckDBEngineWarning,
)


class MotherDuckConfig(DuckDBConfig):
"""Configuration for the MotherDuck cache."""

database: str = Field()
api_key: SecretString = Field()
db_path: str = Field(default="md:")
custom_user_agent: str = Field(default="airbyte")

@overrides
def get_sql_alchemy_url(self) -> SecretString:
"""Return the SQLAlchemy URL to use."""
# Suppress warnings from DuckDB about reflection on indices.
# https://github.com/Mause/duckdb_engine/issues/905
warnings.filterwarnings(
"ignore",
message="duckdb-engine doesn't yet support reflection on indices",
category=DuckDBEngineWarning,
)

return SecretString(
f"duckdb:///md:{self.database}?motherduck_token={self.api_key}"
f"&custom_user_agent=={self.custom_user_agent}"
# Not sure why this doesn't work. We have to override later in the flow.
# f"&schema={self.schema_name}"
)
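
    # Example result (illustrative; the token is fake): database="my_db",
    # api_key="XXXX", custom_user_agent="airbyte" yields:
    #   duckdb:///md:my_db?motherduck_token=XXXX&custom_user_agent=airbyte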

@overrides
def get_database_name(self) -> str:
"""Return the name of the database."""
return self.database


class MotherDuckSqlProcessor(DuckDBSqlProcessor):
"""A cache implementation for MotherDuck."""

supports_merge_insert = False

@overrides
def _setup(self) -> None:
"""Do any necessary setup, if applicable.
Note: The DuckDB parent class requires pre-creation of local directory structure. We
        don't need to do that here, so we override the method to be a no-op.
"""
# No setup to do and no need to pre-create local file storage.
pass
