From a035cd9bf8d099ffec1e1cb6d9d8ffe36f9452ec Mon Sep 17 00:00:00 2001 From: Adrian Date: Mon, 23 Dec 2024 05:08:59 +0200 Subject: [PATCH] Added equals macro that handles null value comparison (#394) Co-authored-by: Mila Page <67295367+VersusFacit@users.noreply.github.com> --- .../Under the Hood-20241217-110536.yaml | 6 ++++++ .../incremental/test_incremental_unique_id.py | 21 ++++++++++++------- .../dbt/tests/adapter/utils/base_utils.py | 19 ++++++++--------- .../dbt/tests/adapter/utils/test_equals.py | 8 +------ .../models/incremental/merge.sql | 19 ++++++++++++++--- .../materializations/snapshots/helpers.sql | 19 ++++++++++++++--- .../snapshots/snapshot_merge.sql | 10 +++++++-- .../global_project/macros/utils/equals.sql | 12 +++++++++++ 8 files changed, 81 insertions(+), 33 deletions(-) create mode 100644 .changes/unreleased/Under the Hood-20241217-110536.yaml create mode 100644 dbt/include/global_project/macros/utils/equals.sql diff --git a/.changes/unreleased/Under the Hood-20241217-110536.yaml b/.changes/unreleased/Under the Hood-20241217-110536.yaml new file mode 100644 index 00000000..5716da5e --- /dev/null +++ b/.changes/unreleased/Under the Hood-20241217-110536.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Added new equals macro that handles null value checks in sql +time: 2024-12-17T11:05:36.363421+02:00 +custom: + Author: adrianburusdbt + Issue: "159" diff --git a/dbt-tests-adapter/dbt/tests/adapter/incremental/test_incremental_unique_id.py b/dbt-tests-adapter/dbt/tests/adapter/incremental/test_incremental_unique_id.py index bddf407e..34807062 100644 --- a/dbt-tests-adapter/dbt/tests/adapter/incremental/test_incremental_unique_id.py +++ b/dbt-tests-adapter/dbt/tests/adapter/incremental/test_incremental_unique_id.py @@ -240,6 +240,8 @@ select 'NY','New York','Manhattan','2021-04-01' union all select 'PA','Philadelphia','Philadelphia','2021-05-21' +union all +select 'CO','Denver',null,'2021-06-18' """ @@ -265,6 +267,8 @@ select 'NY','New York','Manhattan','2021-04-01' union all select 'PA','Philadelphia','Philadelphia','2021-05-21' +union all +select 'CO','Denver',null,'2021-06-18' """ @@ -288,6 +292,7 @@ NY,Kings,Brooklyn,2021-04-02 NY,New York,Manhattan,2021-04-01 PA,Philadelphia,Philadelphia,2021-05-21 +CO,Denver,,2021-06-18 """ seeds__add_new_rows_sql = """ @@ -439,7 +444,7 @@ def fail_to_build_inc_missing_unique_key_column(self, incremental_model_name): def test__no_unique_keys(self, project): """with no unique keys, seed and model should match""" - expected_fields = self.get_expected_fields(relation="seed", seed_rows=8) + expected_fields = self.get_expected_fields(relation="seed", seed_rows=9) test_case_fields = self.get_test_fields( project, seed="seed", incremental_model="no_unique_key", update_sql_file="add_new_rows" ) @@ -449,7 +454,7 @@ def test__no_unique_keys(self, project): def test__empty_str_unique_key(self, project): """with empty string for unique key, seed and model should match""" - expected_fields = self.get_expected_fields(relation="seed", seed_rows=8) + expected_fields = self.get_expected_fields(relation="seed", seed_rows=9) test_case_fields = self.get_test_fields( project, seed="seed", @@ -462,7 +467,7 @@ def test__one_unique_key(self, project): """with one unique key, model will overwrite existing row""" expected_fields = self.get_expected_fields( - relation="one_str__overwrite", seed_rows=7, opt_model_count=1 + relation="one_str__overwrite", seed_rows=8, opt_model_count=1 ) test_case_fields = self.get_test_fields( project, @@ -487,7 +492,7 @@ def test__bad_unique_key(self, project): def test__empty_unique_key_list(self, project): """with no unique keys, seed and model should match""" - expected_fields = self.get_expected_fields(relation="seed", seed_rows=8) + expected_fields = self.get_expected_fields(relation="seed", seed_rows=9) test_case_fields = self.get_test_fields( project, seed="seed", @@ -500,7 +505,7 @@ def test__unary_unique_key_list(self, project): """with one unique key, model will overwrite existing row""" expected_fields = self.get_expected_fields( - relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1 + relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1 ) test_case_fields = self.get_test_fields( project, @@ -515,7 +520,7 @@ def test__duplicated_unary_unique_key_list(self, project): """with two of the same unique key, model will overwrite existing row""" expected_fields = self.get_expected_fields( - relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1 + relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1 ) test_case_fields = self.get_test_fields( project, @@ -530,7 +535,7 @@ def test__trinary_unique_key_list(self, project): """with three unique keys, model will overwrite existing row""" expected_fields = self.get_expected_fields( - relation="unique_key_list__inplace_overwrite", seed_rows=7, opt_model_count=1 + relation="unique_key_list__inplace_overwrite", seed_rows=8, opt_model_count=1 ) test_case_fields = self.get_test_fields( project, @@ -545,7 +550,7 @@ def test__trinary_unique_key_list_no_update(self, project): """even with three unique keys, adding distinct rows to seed does not cause seed and model to diverge""" - expected_fields = self.get_expected_fields(relation="seed", seed_rows=8) + expected_fields = self.get_expected_fields(relation="seed", seed_rows=9) test_case_fields = self.get_test_fields( project, seed="seed", diff --git a/dbt-tests-adapter/dbt/tests/adapter/utils/base_utils.py b/dbt-tests-adapter/dbt/tests/adapter/utils/base_utils.py index 23e1ca7f..15c042ec 100644 --- a/dbt-tests-adapter/dbt/tests/adapter/utils/base_utils.py +++ b/dbt-tests-adapter/dbt/tests/adapter/utils/base_utils.py @@ -1,16 +1,6 @@ import pytest from dbt.tests.util import run_dbt - -macros__equals_sql = """ -{% macro equals(expr1, expr2) -%} -case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null)) - then 0 - else 1 -end = 0 -{% endmacro %} -""" - macros__test_assert_equal_sql = """ {% test assert_equal(model, actual, expected) %} select * from {{ model }} @@ -27,6 +17,15 @@ {% endmacro %} """ +macros__equals_sql = """ +{% macro equals(expr1, expr2) -%} +case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null)) + then 0 + else 1 +end = 0 +{% endmacro %} +""" + class BaseUtils: # setup diff --git a/dbt-tests-adapter/dbt/tests/adapter/utils/test_equals.py b/dbt-tests-adapter/dbt/tests/adapter/utils/test_equals.py index c61f6fdf..d8596dc0 100644 --- a/dbt-tests-adapter/dbt/tests/adapter/utils/test_equals.py +++ b/dbt-tests-adapter/dbt/tests/adapter/utils/test_equals.py @@ -1,16 +1,10 @@ import pytest -from dbt.tests.adapter.utils import base_utils, fixture_equals +from dbt.tests.adapter.utils import fixture_equals from dbt.tests.util import relation_from_name, run_dbt class BaseEquals: - @pytest.fixture(scope="class") - def macros(self): - return { - "equals.sql": base_utils.macros__equals_sql, - } - @pytest.fixture(scope="class") def seeds(self): return { diff --git a/dbt/include/global_project/macros/materializations/models/incremental/merge.sql b/dbt/include/global_project/macros/materializations/models/incremental/merge.sql index ca972c9f..d7e8af70 100644 --- a/dbt/include/global_project/macros/materializations/models/incremental/merge.sql +++ b/dbt/include/global_project/macros/materializations/models/incremental/merge.sql @@ -21,8 +21,14 @@ {% do predicates.append(this_key_match) %} {% endfor %} {% else %} + {% set source_unique_key %} + DBT_INTERNAL_SOURCE.{{ unique_key }} + {% endset %} + {% set target_unique_key %} + DBT_INTERNAL_DEST.{{ unique_key }} + {% endset %} {% set unique_key_match %} - DBT_INTERNAL_SOURCE.{{ unique_key }} = DBT_INTERNAL_DEST.{{ unique_key }} + {{ equals(source_unique_key, target_unique_key) }} {% endset %} {% do predicates.append(unique_key_match) %} {% endif %} @@ -62,11 +68,18 @@ {% if unique_key %} {% if unique_key is sequence and unique_key is not string %} - delete from {{target }} + delete from {{ target }} using {{ source }} where ( {% for key in unique_key %} - {{ source }}.{{ key }} = {{ target }}.{{ key }} + {% set source_unique_key %} + {{ source }}.{{ key }} + {% endset %} + {% set target_unique_key %} + {{ target }}.{{ key }} + {% endset %} + + {{ equals(source_unique_key, target_unique_key) }} {{ "and " if not loop.last}} {% endfor %} {% if incremental_predicates %} diff --git a/dbt/include/global_project/macros/materializations/snapshots/helpers.sql b/dbt/include/global_project/macros/materializations/snapshots/helpers.sql index 33492cc9..905ab136 100644 --- a/dbt/include/global_project/macros/materializations/snapshots/helpers.sql +++ b/dbt/include/global_project/macros/materializations/snapshots/helpers.sql @@ -53,8 +53,14 @@ from {{ target_relation }} where {% if config.get('dbt_valid_to_current') %} - {# Check for either dbt_valid_to_current OR null, in order to correctly update records with nulls #} - ( {{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or {{ columns.dbt_valid_to }} is null) + {% set source_unique_key %} + columns.dbt_valid_to + {% endset %} + {% set target_unique_key %} + config.get('dbt_valid_to_current') + {% endset %} + + {{ equals(source_unique_key, target_unique_key) }} {% else %} {{ columns.dbt_valid_to }} is null {% endif %} @@ -276,7 +282,14 @@ {% macro unique_key_join_on(unique_key, identifier, from_identifier) %} {% if unique_key | is_list %} {% for key in unique_key %} - {{ identifier }}.dbt_unique_key_{{ loop.index }} = {{ from_identifier }}.dbt_unique_key_{{ loop.index }} + {% set source_unique_key %} + {{ identifier }}.dbt_unique_key_{{ loop.index }} + {% endset %} + {% set target_unique_key %} + {{ from_identifier }}.dbt_unique_key_{{ loop.index }} + {% endset %} + + {{ equals(source_unique_key, target_unique_key) }} {%- if not loop.last %} and {%- endif %} {% endfor %} {% else %} diff --git a/dbt/include/global_project/macros/materializations/snapshots/snapshot_merge.sql b/dbt/include/global_project/macros/materializations/snapshots/snapshot_merge.sql index cf787e4f..19a67f6b 100644 --- a/dbt/include/global_project/macros/materializations/snapshots/snapshot_merge.sql +++ b/dbt/include/global_project/macros/materializations/snapshots/snapshot_merge.sql @@ -15,8 +15,14 @@ when matched {% if config.get("dbt_valid_to_current") %} - and (DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} = {{ config.get('dbt_valid_to_current') }} or - DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null) + {% set source_unique_key %} + DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} + {% endset %} + {% set target_unique_key %} + {{ config.get('dbt_valid_to_current') }} + {% endset %} + and {{ equals(source_unique_key, target_unique_key) }} + {% else %} and DBT_INTERNAL_DEST.{{ columns.dbt_valid_to }} is null {% endif %} diff --git a/dbt/include/global_project/macros/utils/equals.sql b/dbt/include/global_project/macros/utils/equals.sql new file mode 100644 index 00000000..d63b6cc1 --- /dev/null +++ b/dbt/include/global_project/macros/utils/equals.sql @@ -0,0 +1,12 @@ +{% macro equals(expr1, expr2) %} + {{ return(adapter.dispatch('equals', 'dbt') (expr1, expr2)) }} +{%- endmacro %} + +{% macro default__equals(expr1, expr2) -%} + + case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null)) + then 0 + else 1 + end = 0 + +{% endmacro %}