Skip to content

Commit

Permalink
update Number class to handle integer values (#8306)
Browse files Browse the repository at this point in the history
* add show test for json data

* oh changie my changie

* revert unecessary cahnge to fixture

* keep decimal class for precision methods, but return __int__ value

* jerco updates

* update integer type

* update other tests

* Update .changes/unreleased/Fixes-20230803-093502.yaml

---------

Co-authored-by: Emily Rockman <emily.rockman@dbtlabs.com>
  • Loading branch information
dave-connors-3 and emmyoop committed Sep 6, 2023
1 parent 7e2a08f commit 82f086b
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 30 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230803-093502.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: Add explicit support for integers for the show command
time: 2023-08-03T09:35:02.163968-05:00
custom:
Author: dave-connors-3
Issue: "8153"
12 changes: 12 additions & 0 deletions core/dbt/clients/agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@
BOM = BOM_UTF8.decode("utf-8") # '\ufeff'


class Integer(agate.data_types.DataType):
def cast(self, d):
if type(d) == int:
return d
else:
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)

def jsonify(self, d):
return d


class Number(agate.data_types.Number):
# undo the change in https://github.com/wireservice/agate/pull/733
# i.e. do not cast True and False to numeric 1 and 0
Expand Down Expand Up @@ -47,6 +58,7 @@ def build_type_tester(
) -> agate.TypeTester:

types = [
Integer(null_values=("null", "")),
Number(null_values=("null", "")),
agate.data_types.Date(null_values=("null", ""), date_format="%Y-%m-%d"),
agate.data_types.DateTime(null_values=("null", ""), datetime_format="%Y-%m-%d %H:%M:%S"),
Expand Down
8 changes: 8 additions & 0 deletions tests/functional/show/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
select * from {{ ref('sample_seed') }}
"""

models__sample_number_model = """
select
cast(1.0 as int) as float_to_int_field,
3.0 as float_field,
4.3 as float_with_dec_field,
5 as int_field
"""

models__second_model = """
select
sample_num as col_one,
Expand Down
35 changes: 15 additions & 20 deletions tests/functional/show/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
models__second_ephemeral_model,
seeds__sample_seed,
models__sample_model,
models__sample_number_model,
models__second_model,
models__ephemeral_model,
schema_yml,
Expand All @@ -14,11 +15,12 @@
)


class BaseTestShow:
class TestShow:
@pytest.fixture(scope="class")
def models(self):
return {
"sample_model.sql": models__sample_model,
"sample_number_model.sql": models__sample_number_model,
"second_model.sql": models__second_model,
"ephemeral_model.sql": models__ephemeral_model,
"sql_header.sql": models__sql_header,
Expand All @@ -28,17 +30,13 @@ def models(self):
def seeds(self):
return {"sample_seed.csv": seeds__sample_seed}


class TestNone(BaseTestShow):
def test_none(self, project):
with pytest.raises(
DbtRuntimeError, match="Either --select or --inline must be passed to show"
):
run_dbt(["seed"])
run_dbt(["show"])


class TestSelectModelText(BaseTestShow):
def test_select_model_text(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(["show", "--select", "second_model"])
Expand All @@ -48,8 +46,6 @@ def test_select_model_text(self, project):
assert "col_two" in log_output
assert "answer" in log_output


class TestSelectMultModelText(BaseTestShow):
def test_select_multiple_model_text(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(
Expand All @@ -59,8 +55,6 @@ def test_select_multiple_model_text(self, project):
assert "sample_num" in log_output
assert "sample_bool" in log_output


class TestSelectSingleMultModelJson(BaseTestShow):
def test_select_single_model_json(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(
Expand All @@ -70,8 +64,19 @@ def test_select_single_model_json(self, project):
assert "sample_num" in log_output
assert "sample_bool" in log_output

def test_numeric_values(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(
["show", "--select", "sample_number_model", "--output", "json"]
)
assert "Previewing node 'sample_number_model'" not in log_output
assert "1.0" not in log_output
assert "1" in log_output
assert "3.0" in log_output
assert "4.3" in log_output
assert "5" in log_output
assert "5.0" not in log_output

class TestInlinePass(BaseTestShow):
def test_inline_pass(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(
Expand All @@ -81,8 +86,6 @@ def test_inline_pass(self, project):
assert "sample_num" in log_output
assert "sample_bool" in log_output


class TestShowExceptions(BaseTestShow):
def test_inline_fail(self, project):
with pytest.raises(DbtException, match="Error parsing inline query"):
run_dbt(["show", "--inline", "select * from {{ ref('third_model') }}"])
Expand All @@ -91,8 +94,6 @@ def test_inline_fail_database_error(self, project):
with pytest.raises(DbtRuntimeError, match="Database Error"):
run_dbt(["show", "--inline", "slect asdlkjfsld;j"])


class TestEphemeralModels(BaseTestShow):
def test_ephemeral_model(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(["show", "--select", "ephemeral_model"])
Expand All @@ -105,8 +106,6 @@ def test_second_ephemeral_model(self, project):
)
assert "col_hundo" in log_output


class TestLimit(BaseTestShow):
@pytest.mark.parametrize(
"args,expected",
[
Expand All @@ -121,14 +120,10 @@ def test_limit(self, project, args, expected):
results, log_output = run_dbt_and_capture(dbt_args)
assert len(results.results[0].agate_table) == expected


class TestSeed(BaseTestShow):
def test_seed(self, project):
(results, log_output) = run_dbt_and_capture(["show", "--select", "sample_seed"])
assert "Previewing node 'sample_seed'" in log_output


class TestSqlHeader(BaseTestShow):
def test_sql_header(self, project):
run_dbt(["build"])
(results, log_output) = run_dbt_and_capture(["show", "--select", "sql_header"])
Expand Down
20 changes: 10 additions & 10 deletions tests/unit/test_agate_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,37 +121,37 @@ def test_datetime_formats(self):
self.assertEqual(tbl[0][0], expected)

def test_merge_allnull(self):
t1 = agate.Table([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate.Table([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[0], agate_helper.Integer)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Number)
self.assertEqual(len(result), 4)

def test_merge_mixed(self):
t1 = agate.Table([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate.Table([(3, "c", "dog"), (4, "d", "cat")], ("a", "b", "c"))
t3 = agate.Table([(3, "c", None), (4, "d", None)], ("a", "b", "c"))
t1 = agate_helper.table_from_rows([(1, "a", None), (2, "b", None)], ("a", "b", "c"))
t2 = agate_helper.table_from_rows([(3, "c", "dog"), (4, "d", "cat")], ("a", "b", "c"))
t3 = agate_helper.table_from_rows([(3, "c", None), (4, "d", None)], ("a", "b", "c"))

result = agate_helper.merge_tables([t1, t2])
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[0], agate_helper.Integer)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)

result = agate_helper.merge_tables([t2, t3])
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[0], agate_helper.Integer)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 4)

result = agate_helper.merge_tables([t1, t2, t3])
self.assertEqual(result.column_names, ("a", "b", "c"))
assert isinstance(result.column_types[0], agate.data_types.Number)
assert isinstance(result.column_types[0], agate_helper.Integer)
assert isinstance(result.column_types[1], agate.data_types.Text)
assert isinstance(result.column_types[2], agate.data_types.Text)
self.assertEqual(len(result), 6)
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_nocast_bool_01(self):
self.assertEqual(len(tbl), len(result_set))

assert isinstance(tbl.column_types[0], agate.data_types.Boolean)
assert isinstance(tbl.column_types[1], agate.data_types.Number)
assert isinstance(tbl.column_types[1], agate_helper.Integer)

expected = [
[True, Decimal(1)],
Expand Down

0 comments on commit 82f086b

Please sign in to comment.