From b29e7e81b251f7b5eb84e0e55fa95fc37e8e9375 Mon Sep 17 00:00:00 2001 From: Cor Date: Tue, 4 Jun 2024 18:21:34 +0200 Subject: [PATCH] Support dbt version 1.8 (#58) * Move dbt version to DBT_INSTALLED_VERSION global as tuple * Register adapter for dbt version 1.8.0 and later * Add REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES to args * Set context invocation in manifest * Add dbt v1.8.0 to testing --- .gitignore | 3 +++ .../dbt_project/macros/spark_adapter.sql | 7 +++++ setup.cfg | 6 +++-- src/pytest_dbt_core/fixtures.py | 27 ++++++++++++++++--- 4 files changed, 38 insertions(+), 5 deletions(-) create mode 100644 docs/source/_static/dbt_project/macros/spark_adapter.sql diff --git a/.gitignore b/.gitignore index a1d3f8b..19c891e 100644 --- a/.gitignore +++ b/.gitignore @@ -136,3 +136,6 @@ src/pytest_dbt_core/_version.py # dbt .user.yml dbt_packages/ + +# Editor +.idea/ diff --git a/docs/source/_static/dbt_project/macros/spark_adapter.sql b/docs/source/_static/dbt_project/macros/spark_adapter.sql new file mode 100644 index 0000000..2f32665 --- /dev/null +++ b/docs/source/_static/dbt_project/macros/spark_adapter.sql @@ -0,0 +1,7 @@ +{% macro spark__list_relations_without_caching(relation) %} + {% call statement('list_relations_without_caching', fetch_result=True) -%} + show table extended in {{ relation }} like '*' + {% endcall %} + + {% do return(load_result('list_relations_without_caching').table) %} +{% endmacro %} diff --git a/setup.cfg b/setup.cfg index 4b3fba9..46e10fb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -72,8 +72,8 @@ spark_options = [tox:tox] envlist = py3.7-dbt-spark{11,12,13,14,15} # dbt-spark v1.6 or later does not support Python 3.7 - py{3.8,3.9,3.10}-dbt-spark{11,12,13,14,15,16,17} - py3.11-dbt-spark{14,15,16,17} # Previous dbt-spark versions fail when using Python 3.11 + py{3.8,3.9,3.10}-dbt-spark{11,12,13,14,15,16,17,18} + py3.11-dbt-spark{14,15,16,17,18} # Previous dbt-spark versions fail when using Python 3.11 isolated_build = true skip_missing_interpreters = true @@ -95,6 +95,7 @@ deps = dbt-spark15: dbt-spark[ODBC]~=1.5.0 dbt-spark16: dbt-spark[ODBC]~=1.6.0 dbt-spark17: dbt-spark[ODBC]~=1.7.0 + dbt-spark18: dbt-spark[ODBC]~=1.8.0 pip >= 19.3.1 extras = test commands = pytest {posargs:tests} @@ -115,6 +116,7 @@ deps = dbt-spark15: dbt-spark[ODBC]~=1.5.0 dbt-spark16: dbt-spark[ODBC]~=1.6.0 dbt-spark17: dbt-spark[ODBC]~=1.7.0 + dbt-spark18: dbt-spark[ODBC]~=1.8.0 pip >= 19.3.1 extras = test commands_pre = dbt deps --project-dir {toxinidir}/docs/source/_static/dbt_project --profiles-dir {toxinidir}/docs/source/_static/dbt_project diff --git a/src/pytest_dbt_core/fixtures.py b/src/pytest_dbt_core/fixtures.py index 6cfe19e..d416528 100644 --- a/src/pytest_dbt_core/fixtures.py +++ b/src/pytest_dbt_core/fixtures.py @@ -15,7 +15,6 @@ from dbt.context import providers from dbt.contracts.graph.manifest import Manifest from dbt.parser.manifest import ManifestLoader -from dbt.semver import VersionSpecifier from dbt.tracking import User from dbt.adapters.factory import ( # isort:skip @@ -28,6 +27,15 @@ dbt.tracking.active_user = User(os.getcwd()) +def _get_installed_dbt_version() -> tuple[int, int]: + """Cast a dbt version to a tuple with major and minor version.""" + installed_dbt_version = version.get_installed_version() + return int(installed_dbt_version.major), int(installed_dbt_version.minor) + + +DBT_INSTALLED_VERSION = _get_installed_dbt_version() + + @dataclasses.dataclass(frozen=True) class Args: """ @@ -47,6 +55,8 @@ class Args: target: str | None profile: str | None threads: int | None + # Required from dbt version 1.8 onwards + REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES = False @pytest.fixture @@ -74,7 +84,7 @@ def config(request: SubRequest) -> RuntimeConfig: threads=None, ) - if VersionSpecifier("1", "5", "12") < version.get_installed_version(): + if DBT_INSTALLED_VERSION > (1, 5): # See https://github.com/dbt-labs/dbt-core/issues/9183 project_flags = project.read_project_flags( args.project_dir, args.profiles_dir @@ -102,7 +112,12 @@ def adapter(config: RuntimeConfig) -> AdapterContainer: AdapterContainer The adapter. """ - register_adapter(config) + if DBT_INSTALLED_VERSION > (1, 7): + from dbt.mp_context import get_mp_context + + register_adapter(config, get_mp_context()) + else: + register_adapter(config) adapter = get_adapter(config) adapter.acquire_connection() return adapter @@ -125,6 +140,12 @@ def manifest( Manifest The manifest. """ + if DBT_INSTALLED_VERSION > (1, 7): + from dbt_common.clients.system import get_env + from dbt_common.context import set_invocation_context + + set_invocation_context(get_env()) + manifest = ManifestLoader.get_full_manifest(adapter.config) return manifest