-
Notifications
You must be signed in to change notification settings - Fork 4.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Python-CDK: Add CDK
sql
module for new MotherDuck destination (#47260)
Co-authored-by: Guen Prawiroatmodjo <guen@motherduck.com>
- Loading branch information
1 parent
155593d
commit 8d6a7aa
Showing
16 changed files
with
2,651 additions
and
717 deletions.
There are no files selected for viewing
Empty file.
Empty file.
293 changes: 293 additions & 0 deletions
293
airbyte-cdk/python/airbyte_cdk/sql/_processors/duckdb.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,293 @@ | ||
# Copyright (c) 2023 Airbyte, Inc., all rights reserved. | ||
"""A DuckDB implementation of the cache.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import warnings | ||
from pathlib import Path | ||
from typing import TYPE_CHECKING, Any, Dict, List, Literal | ||
|
||
import pyarrow as pa | ||
from airbyte_cdk import DestinationSyncMode | ||
from airbyte_cdk.sql import exceptions as exc | ||
from airbyte_cdk.sql.constants import AB_EXTRACTED_AT_COLUMN | ||
from airbyte_cdk.sql.secrets import SecretString | ||
from airbyte_cdk.sql.shared.sql_processor import SqlConfig, SqlProcessorBase, SQLRuntimeError | ||
from duckdb_engine import DuckDBEngineWarning | ||
from overrides import overrides | ||
from pydantic import Field | ||
from sqlalchemy import Executable, TextClause, text | ||
from sqlalchemy.exc import ProgrammingError, SQLAlchemyError | ||
|
||
if TYPE_CHECKING: | ||
from sqlalchemy.engine import Connection, Engine | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
# @dataclass | ||
class DuckDBConfig(SqlConfig):
    """Configuration for a DuckDB database reached through SQLAlchemy.

    Holds the database location and the target schema, and derives the
    SQLAlchemy URL, the database name and the engine from them.
    """

    db_path: Path | str = Field()
    """Location of the database.

    Normally `db_path` is a `Path` object pointing at a database file; the literal
    `":memory:"` is recognized as well.

    The database name is inferred from the file name (the last path component, up to
    its first `.`). For example, given a `db_path` of `/path/to/my/my_db.duckdb`, the
    database name is `my_db`.
    """

    schema_name: str = Field(default="main")
    """The name of the schema to write to. Defaults to "main"."""

    @overrides
    def get_sql_alchemy_url(self) -> SecretString:
        """Return the SQLAlchemy URL to use.

        As a side effect, installs a process-wide warnings filter for the
        duckdb-engine "reflection on indices" warning.
        """
        # Suppress warnings from DuckDB about reflection on indices.
        # https://github.com/Mause/duckdb_engine/issues/905
        warnings.filterwarnings(
            "ignore",
            message="duckdb-engine doesn't yet support reflection on indices",
            category=DuckDBEngineWarning,
        )
        return SecretString(f"duckdb:///{self.db_path!s}")

    @overrides
    def get_database_name(self) -> str:
        """Return the name of the database.

        Returns:
            `"memory"` for the in-memory database, otherwise the last path
            component of `db_path` truncated at its first `.`.
        """
        db_path_str = str(self.db_path)

        # Compare the string form so that `Path(":memory:")` is recognized as well
        # as the plain string.
        if db_path_str == ":memory:":
            return "memory"

        # Split the path on the appropriate separator ("/" or "\")
        split_on: Literal["/", "\\"] = "\\" if "\\" in db_path_str else "/"
        file_name = db_path_str.split(sep=split_on)[-1]

        # Return the file name without the extension
        return file_name.split(".")[0]

    def _is_file_based_db(self) -> bool:
        """Return whether the database is file-based.

        The in-memory database is never file-based. Any other `Path` is. A plain
        string counts as file-based only when it contains a path separator and
        does not contain a `md:` / `motherduck:` marker.
        """
        db_path_str = str(self.db_path)
        if db_path_str == ":memory:":
            return False

        if isinstance(self.db_path, Path):
            return True

        return (
            ("/" in db_path_str or "\\" in db_path_str)
            and "md:" not in db_path_str
            and "motherduck:" not in db_path_str
        )

    @overrides
    def get_sql_engine(self) -> Engine:
        """Return the SQL Alchemy engine.

        This method is overridden to ensure that the database parent directory is
        created if it doesn't exist.
        """
        if self._is_file_based_db():
            Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)

        return super().get_sql_engine()
|
||
|
||
class DuckDBSqlProcessor(SqlProcessorBase):
    """SQL processor that writes buffered stream records into DuckDB.

    Records are first loaded into a temporary table: from a PyArrow table when
    the buffer converts cleanly, otherwise with a parameterized `INSERT` run
    through `executemany`. The temporary table is then de-duplicated on the
    stream's primary keys and written to the final table according to the
    destination sync mode.
    """

    supports_merge_insert = False
    sql_config: DuckDBConfig

    @overrides
    def _setup(self) -> None:
        """Create the database parent folder if it doesn't yet exist."""
        if self.sql_config.db_path == ":memory:":
            return

        Path(self.sql_config.db_path).parent.mkdir(parents=True, exist_ok=True)

    def _create_table_if_not_exists(
        self,
        table_name: str,
        column_definition_str: str,
        primary_keys: list[str] | None = None,
    ) -> None:
        """Create `table_name` with the given column definitions if it is missing.

        Args:
            table_name: Unqualified table name; it is qualified before use.
            column_definition_str: The column definition list, already rendered as SQL.
            primary_keys: Optional column names for a `PRIMARY KEY` constraint.
        """
        if primary_keys:
            pk_str = ", ".join(primary_keys)
            column_definition_str += f",\n  PRIMARY KEY ({pk_str})"

        cmd = f"""
        CREATE TABLE IF NOT EXISTS {self._fully_qualified(table_name)} (
            {column_definition_str}
        )
        """
        _ = self._execute_sql(cmd)

    def _do_checkpoint(
        self,
        connection: Connection | None = None,
    ) -> None:
        """Checkpoint the given connection.

        We override this method to ensure that the DuckDB WAL is checkpointed explicitly.
        Otherwise DuckDB will lazily flush the WAL to disk, which can cause issues for users
        who want to manipulate the DB files after writing them.

        For more info:
        - https://duckdb.org/docs/sql/statements/checkpoint.html

        Args:
            connection: Connection to checkpoint. When omitted, a new connection is
                opened for the checkpoint and closed afterwards.
        """
        if connection is not None:
            connection.execute(text("CHECKPOINT"))
            return

        with self.get_sql_connection() as new_conn:
            new_conn.execute(text("CHECKPOINT"))

    def _executemany(self, sql: str | TextClause | Executable, params: list[list[Any]]) -> None:
        """Execute the given SQL statement once per parameter set.

        Args:
            sql: The statement to run; its string form is handed to `executemany`.
            params: One list of positional values per execution.

        Raises:
            SQLRuntimeError: If a SQLAlchemy error is raised while executing.
        """
        if isinstance(sql, str):
            sql = text(sql)

        with self.get_sql_connection() as conn:
            try:
                # `executemany` is invoked on a connection checked out from the
                # engine's pool rather than on the SQLAlchemy `Connection`.
                pooled_conn = conn.engine.pool.connect()
                try:
                    pooled_conn.executemany(str(sql), list(params))  # type: ignore
                finally:
                    # Hand the connection back to the pool explicitly instead of
                    # leaving the checkout to be reclaimed by garbage collection.
                    pooled_conn.close()
            except (
                ProgrammingError,
                SQLAlchemyError,
            ) as ex:
                msg = f"Error when executing SQL:\n{sql}\n{type(ex).__name__}{ex!s}"
                raise SQLRuntimeError(msg) from None  # from ex

    def _write_with_executemany(self, buffer: Dict[str, Dict[str, List[Any]]], stream_name: str, table_name: str) -> None:
        """Insert the buffered records of `stream_name` into `table_name` row by row.

        Args:
            buffer: Maps stream name to a column-oriented mapping of column name to
                the list of values for that column.
            stream_name: The stream whose buffered columns are written.
            table_name: Unqualified name of the table to insert into.
        """
        entries_to_write = buffer[stream_name]
        column_names_list = list(entries_to_write.keys())
        if not column_names_list:
            # No columns means no rows; indexing the first column would fail.
            return

        # The first column determines the number of rows.
        num_entries = len(entries_to_write[column_names_list[0]])
        if num_entries == 0:
            # Nothing to insert, so skip issuing an empty batch.
            return

        column_names = ", ".join(column_names_list)
        params = ", ".join(["?"] * len(column_names_list))
        sql = f"""
        -- Write with executemany
        INSERT INTO {self._fully_qualified(table_name)}
        ({column_names})
        VALUES ({params})
        """
        # Pivot the column-oriented buffer into one positional value list per row.
        parameters = [[entries_to_write[column_name][n] for column_name in column_names_list] for n in range(num_entries)]
        self._executemany(sql, parameters)

    def _write_from_pa_table(self, table_name: str, pa_table: pa.Table) -> None:
        """Insert every row of `pa_table` into `table_name`.

        The SQL refers to `pa_table` by name and no value is bound explicitly, so
        the parameter must keep exactly this name.
        NOTE(review): this relies on DuckDB resolving the name to the Python
        variable, as the comment in `write_stream_data_from_buffer` states.
        """
        full_table_name = self._fully_qualified(table_name)
        sql = f"""
        -- Write from PyArrow table
        INSERT INTO {full_table_name} SELECT * FROM pa_table
        """
        self._execute_sql(sql)

    def _write_temp_table_to_target_table(
        self,
        stream_name: str,
        temp_table_name: str,
        final_table_name: str,
        sync_mode: DestinationSyncMode,
    ) -> None:
        """Write the temp table into the final table using the provided write strategy.

        Raises:
            exc.AirbyteInternalError: If `sync_mode` is not overwrite, append or
                append_dedup.
        """
        if sync_mode == DestinationSyncMode.overwrite:
            # Note: No need to check for schema compatibility
            # here, because we are fully replacing the table.
            self._swap_temp_table_with_final_table(
                stream_name=stream_name,
                temp_table_name=temp_table_name,
                final_table_name=final_table_name,
            )
            return

        if sync_mode == DestinationSyncMode.append:
            self._ensure_compatible_table_schema(
                stream_name=stream_name,
                table_name=final_table_name,
            )
            self._append_temp_table_to_final_table(
                stream_name=stream_name,
                temp_table_name=temp_table_name,
                final_table_name=final_table_name,
            )
            return

        if sync_mode == DestinationSyncMode.append_dedup:
            self._ensure_compatible_table_schema(
                stream_name=stream_name,
                table_name=final_table_name,
            )
            if not self.supports_merge_insert:
                # Fallback to emulated merge if the database does not support merge natively.
                self._emulated_merge_temp_table_to_final_table(
                    stream_name=stream_name,
                    temp_table_name=temp_table_name,
                    final_table_name=final_table_name,
                )
                return

            self._merge_temp_table_to_final_table(
                stream_name=stream_name,
                temp_table_name=temp_table_name,
                final_table_name=final_table_name,
            )
            return

        raise exc.AirbyteInternalError(
            message="Sync mode is not supported.",
            context={
                "sync_mode": sync_mode,
            },
        )

    def _drop_duplicates(self, table_name: str, stream_name: str) -> str:
        """Create a de-duplicated copy of `table_name` when the stream has primary keys.

        For each primary-key combination, only the row with the greatest
        `AB_EXTRACTED_AT_COLUMN` value is kept.

        Returns:
            The name of the new `<table_name>_deduped` table, or `table_name`
            unchanged when the stream has no primary keys.
        """
        primary_keys = self.catalog_provider.get_primary_keys(stream_name)
        if not primary_keys:
            return table_name

        new_table_name = f"{table_name}_deduped"
        pks = ", ".join(primary_keys)
        sql = f"""
        -- Drop duplicates from temp table
        CREATE TABLE {self._fully_qualified(new_table_name)} AS (
            SELECT * FROM {self._fully_qualified(table_name)}
            QUALIFY row_number() OVER (PARTITION BY ({pks}) ORDER BY {AB_EXTRACTED_AT_COLUMN} DESC) = 1
        )
        """
        self._execute_sql(sql)
        return new_table_name

    def write_stream_data_from_buffer(
        self,
        buffer: Dict[str, Dict[str, List[Any]]],
        stream_name: str,
        sync_mode: DestinationSyncMode,
    ) -> None:
        """Write the buffered records of `stream_name` to its `_airbyte_raw_` table.

        The records are loaded into a temporary table, de-duplicated, and written
        to the final table according to `sync_mode`. The temporary tables are
        dropped afterwards, whether or not any of these steps failed.

        Args:
            buffer: Maps stream name to a column-oriented mapping of column name to
                the list of values for that column.
            stream_name: The stream to write.
            sync_mode: How the loaded records are combined with the final table.
        """
        table_name = f"_airbyte_raw_{stream_name}"
        temp_table_name = self._create_table_for_loading(stream_name, batch_id=None)
        # Until de-duplication has produced a separate table, there is only one
        # temporary table to clean up.
        temp_table_name_dedup = temp_table_name

        try:
            try:
                pa_table = pa.Table.from_pydict(buffer[stream_name])
            except Exception:
                logger.exception(
                    "Writing with PyArrow table failed, falling back to writing with executemany. Expect some performance degradation."
                )
                self._write_with_executemany(buffer, stream_name, temp_table_name)
            else:
                # DuckDB will automatically find and SELECT from the `pa_table`
                # local variable defined above.
                self._write_from_pa_table(temp_table_name, pa_table)

            temp_table_name_dedup = self._drop_duplicates(temp_table_name, stream_name)

            self._write_temp_table_to_target_table(
                stream_name=stream_name,
                temp_table_name=temp_table_name_dedup,
                final_table_name=table_name,
                sync_mode=sync_mode,
            )
        finally:
            # Cleanup also covers a failed load or de-duplication, so a failed
            # write does not leave temporary tables behind.
            if temp_table_name_dedup != temp_table_name:
                self._drop_temp_table(temp_table_name_dedup, if_exists=True)
            self._drop_temp_table(temp_table_name, if_exists=True)
81 changes: 81 additions & 0 deletions
81
airbyte-cdk/python/airbyte_cdk/sql/_processors/motherduck.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) 2024 Airbyte, Inc., all rights reserved. | ||
"""A MotherDuck implementation of the cache, built on DuckDB. | ||
## Usage Example | ||
```python | ||
import airbyte as ab | ||
from airbyte.caches import MotherDuckCache | ||
cache = MotherDuckCache( | ||
database="mydatabase", | ||
schema_name="myschema", | ||
api_key=ab.get_secret("MOTHERDUCK_API_KEY"), | ||
) | ||
``` | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import warnings | ||
|
||
from airbyte_cdk.sql._processors.duckdb import DuckDBConfig, DuckDBSqlProcessor | ||
from airbyte_cdk.sql.secrets import SecretString | ||
from duckdb_engine import DuckDBEngineWarning | ||
from overrides import overrides | ||
from pydantic import Field | ||
|
||
# Suppress warnings from DuckDB about reflection on indices. | ||
# https://github.com/Mause/duckdb_engine/issues/905 | ||
warnings.filterwarnings( | ||
"ignore", | ||
message="duckdb-engine doesn't yet support reflection on indices", | ||
category=DuckDBEngineWarning, | ||
) | ||
|
||
|
||
class MotherDuckConfig(DuckDBConfig):
    """Configuration for the MotherDuck cache.

    Extends the DuckDB configuration with the MotherDuck database name, the API
    key and a user agent, all of which are encoded into the SQLAlchemy URL.
    """

    # Name of the MotherDuck database; also returned by `get_database_name`.
    database: str = Field()
    # Sent as the `motherduck_token` URL parameter.
    api_key: SecretString = Field()
    # Overrides the parent field with a default; the URL built below is derived
    # from `database`, not from this value.
    db_path: str = Field(default="md:")
    # Sent as the `custom_user_agent` URL parameter.
    custom_user_agent: str = Field(default="airbyte")

    @overrides
    def get_sql_alchemy_url(self) -> SecretString:
        """Return the SQLAlchemy URL to use.

        The URL embeds the API key, which is why it is returned as a
        `SecretString`.
        """
        # Suppress warnings from DuckDB about reflection on indices.
        # https://github.com/Mause/duckdb_engine/issues/905
        warnings.filterwarnings(
            "ignore",
            message="duckdb-engine doesn't yet support reflection on indices",
            category=DuckDBEngineWarning,
        )

        return SecretString(
            f"duckdb:///md:{self.database}?motherduck_token={self.api_key}"
            # A single "=" separates the key from the value; a doubled "==" would
            # make the value start with a stray "=".
            f"&custom_user_agent={self.custom_user_agent}"
            # Not sure why this doesn't work. We have to override later in the flow.
            # f"&schema={self.schema_name}"
        )

    @overrides
    def get_database_name(self) -> str:
        """Return the name of the database."""
        return self.database
|
||
|
||
class MotherDuckSqlProcessor(DuckDBSqlProcessor):
    """SQL processor for MotherDuck, reusing the DuckDB implementation."""

    supports_merge_insert = False

    @overrides
    def _setup(self) -> None:
        """Skip all setup.

        Note: The DuckDB parent class pre-creates the local directory structure for
        the database file. We don't need that here, so this override is a no-op.
        """
        # Intentionally empty: there is no local file storage to pre-create.
        return None
Empty file.
Oops, something went wrong.