Skip to content

Commit

Permalink
Fix not equals comparison to be null-safe for adapters/utils tests (#…
Browse files Browse the repository at this point in the history
…7776)

* Fix names within functional test

* Changelog entry

* Test for implementation of null-safe equals comparison

* Remove duplicated where filter

* Fix null-safe equals comparison

* Fix tests for `concat` and `hash` by using empty strings () instead of `null`

* Remove macro namespace interpolation
  • Loading branch information
dbeatty10 authored Jun 6, 2023
1 parent dc35f56 commit 8e1c4ec
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 13 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230604-080052.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Fix null-safe equals comparison via `equals`
time: 2023-06-04T08:00:52.537967-06:00
custom:
Author: dbeatty10
Issue: "7778"
18 changes: 15 additions & 3 deletions tests/adapter/dbt/tests/adapter/utils/base_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from dbt.tests.util import run_dbt

macros__equals_sql = """
{% macro equals(actual, expected) %}
{# -- actual is not distinct from expected #}
(({{ actual }} = {{ expected }}) or ({{ actual }} is null and {{ expected }} is null))
{% macro equals(expr1, expr2) -%}
case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null))
then 0
else 1
end = 0
{% endmacro %}
"""

Expand All @@ -15,6 +17,15 @@
{% endtest %}
"""

macros__replace_empty_sql = """
{% macro replace_empty(expr) -%}
case
when {{ expr }} = 'EMPTY' then ''
else {{ expr }}
end
{% endmacro %}
"""


class BaseUtils:
# setup
Expand All @@ -23,6 +34,7 @@ def macros(self):
return {
"equals.sql": macros__equals_sql,
"test_assert_equal.sql": macros__test_assert_equal_sql,
"replace_empty.sql": macros__replace_empty_sql,
}

# make it possible to dynamically update the macro call with a namespace
Expand Down
19 changes: 15 additions & 4 deletions tests/adapter/dbt/tests/adapter/utils/fixture_concat.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
# concat

# https://github.com/dbt-labs/dbt-core/issues/4725
seeds__data_concat_csv = """input_1,input_2,output
a,b,ab
a,,a
,b,b
,,
a,EMPTY,a
EMPTY,b,b
EMPTY,EMPTY,EMPTY
"""


models__test_concat_sql = """
with data as (
with seed_data as (
select * from {{ ref('data_concat') }}
),
data as (
select
{{ replace_empty('input_1') }} as input_1,
{{ replace_empty('input_2') }} as input_2,
{{ replace_empty('output') }} as output
from seed_data
)
select
Expand Down
41 changes: 41 additions & 0 deletions tests/adapter/dbt/tests/adapter/utils/fixture_equals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# equals

SEEDS__DATA_EQUALS_CSV = """key_name,x,y,expected
1,1,1,same
2,1,2,different
3,1,null,different
4,2,1,different
5,2,2,same
6,2,null,different
7,null,1,different
8,null,2,different
9,null,null,same
"""


MODELS__EQUAL_VALUES_SQL = """
with data as (
select * from {{ ref('data_equals') }}
)
select *
from data
where
{{ equals('x', 'y') }}
"""


MODELS__NOT_EQUAL_VALUES_SQL = """
with data as (
select * from {{ ref('data_equals') }}
)
select *
from data
where
not {{ equals('x', 'y') }}
"""
14 changes: 12 additions & 2 deletions tests/adapter/dbt/tests/adapter/utils/fixture_hash.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,28 @@
# hash

# https://github.com/dbt-labs/dbt-core/issues/4725
seeds__data_hash_csv = """input_1,output
ab,187ef4436122d1cc2f40dc2b92f0eba0
a,0cc175b9c0f1b6a831c399e269772661
1,c4ca4238a0b923820dcc509a6f75849b
,d41d8cd98f00b204e9800998ecf8427e
EMPTY,d41d8cd98f00b204e9800998ecf8427e
"""


models__test_hash_sql = """
with data as (
with seed_data as (
select * from {{ ref('data_hash') }}
),
data as (
select
{{ replace_empty('input_1') }} as input_1,
{{ replace_empty('output') }} as output
from seed_data
)
select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
MODELS__TEST_MIXED_NULL_COMPARE_YML = """
version: 2
models:
- name: test_null_compare
- name: test_mixed_null_compare
tests:
- assert_equal:
actual: actual
Expand Down
54 changes: 54 additions & 0 deletions tests/adapter/dbt/tests/adapter/utils/test_equals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
from dbt.tests.adapter.utils.base_utils import macros__equals_sql
from dbt.tests.adapter.utils.fixture_equals import (
SEEDS__DATA_EQUALS_CSV,
MODELS__EQUAL_VALUES_SQL,
MODELS__NOT_EQUAL_VALUES_SQL,
)
from dbt.tests.util import run_dbt, relation_from_name


class BaseEquals:
@pytest.fixture(scope="class")
def macros(self):
return {
"equals.sql": macros__equals_sql,
}

@pytest.fixture(scope="class")
def seeds(self):
return {
"data_equals.csv": SEEDS__DATA_EQUALS_CSV,
}

@pytest.fixture(scope="class")
def models(self):
return {
"equal_values.sql": MODELS__EQUAL_VALUES_SQL,
"not_equal_values.sql": MODELS__NOT_EQUAL_VALUES_SQL,
}

def test_equal_values(self, project):
run_dbt(["seed"])
run_dbt(["run"])

# There are 9 cases total; 3 are equal and 6 are not equal

# 3 are equal
relation = relation_from_name(project.adapter, "equal_values")
result = project.run_sql(
f"select count(*) as num_rows from {relation} where expected = 'same'", fetch="one"
)
assert result[0] == 3

# 6 are not equal
relation = relation_from_name(project.adapter, "not_equal_values")
result = project.run_sql(
f"select count(*) as num_rows from {relation} where expected = 'different'",
fetch="one",
)
assert result[0] == 6


class TestEquals(BaseEquals):
pass
6 changes: 3 additions & 3 deletions tests/adapter/dbt/tests/adapter/utils/test_null_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ class BaseMixedNullCompare(BaseUtils):
@pytest.fixture(scope="class")
def models(self):
return {
"test_mixed_null_compare.yml": MODELS__TEST_MIXED_NULL_COMPARE_SQL,
"test_mixed_null_compare.sql": MODELS__TEST_MIXED_NULL_COMPARE_YML,
"test_mixed_null_compare.yml": MODELS__TEST_MIXED_NULL_COMPARE_YML,
"test_mixed_null_compare.sql": MODELS__TEST_MIXED_NULL_COMPARE_SQL,
}

def test_build_assert_equal(self, project):
Expand All @@ -32,7 +32,7 @@ def models(self):
}


class TestMixedNullCompare(BaseNullCompare):
class TestMixedNullCompare(BaseMixedNullCompare):
pass


Expand Down

0 comments on commit 8e1c4ec

Please sign in to comment.