Skip to content

Commit

Permalink
Fix: Coercion of string to integers when converting csv to agate tabl…
Browse files Browse the repository at this point in the history
…es (#2918)
  • Loading branch information
izeigerman committed Jul 18, 2024
1 parent 4f1d735 commit ec498a5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
25 changes: 13 additions & 12 deletions sqlmesh/dbt/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,22 @@ def to_sqlmesh(self, context: DbtContext) -> Model:
)


class Integer(agate.data_types.DataType):
def cast(self, d: str) -> t.Optional[int]:
if d is None:
return d
try:
return int(d)
except ValueError:
raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d)

def jsonify(self, d: str) -> str:
class Integer(agate_helper.Integer):
    """Integer column type that also accepts integers written as strings.

    String inputs are handled here; every other input is delegated to the
    parent ``agate_helper.Integer`` implementation.
    """

    def cast(self, d: t.Any) -> t.Optional[int]:
        """Convert a raw value to ``int``, or to ``None`` for a null marker.

        Args:
            d: The raw value to convert.

        Returns:
            ``None`` when ``d`` is a string whose stripped, lower-cased form is
            in ``self.null_values``; the parsed integer when ``d`` is any other
            string; otherwise whatever the parent ``cast`` returns.

        Raises:
            agate.exceptions.CastError: If ``d`` is a string that ``int()``
                cannot parse.
        """
        if not isinstance(d, str):
            return super().cast(d)

        # dbt's implementation doesn't support coercion of strings to integers,
        # so strings are resolved here instead of being passed to the parent.
        if d.strip().lower() in self.null_values:
            return None
        try:
            return int(d)
        except ValueError as e:
            # Chain explicitly so the original parse failure stays attached.
            raise agate.exceptions.CastError(
                f'Can not parse value "{d}" as Integer.'
            ) from e

    def jsonify(self, d: t.Any) -> t.Any:
        """Return ``d`` unchanged as its JSON representation."""
        return d


# The dbt version has a bug in which they check whether the type of the input value
# is int, while the input value is actually always a string.
agate_helper.Integer = Integer # type: ignore


Expand Down
22 changes: 21 additions & 1 deletion tests/dbt/test_transformation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import agate
from datetime import datetime
import json
import logging
import typing as t
Expand Down Expand Up @@ -34,7 +36,7 @@
from sqlmesh.dbt.model import Materialization, ModelConfig
from sqlmesh.dbt.project import Project
from sqlmesh.dbt.relation import Policy
from sqlmesh.dbt.seed import SeedConfig
from sqlmesh.dbt.seed import SeedConfig, Integer
from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig
from sqlmesh.dbt.test import TestConfig
from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError
Expand Down Expand Up @@ -402,6 +404,7 @@ def test_seed_column_inference(tmp_path):
fd.write("int_col,double_col,datetime_col,date_col,boolean_col,text_col\n")
fd.write("1,1.2,2021-01-01 00:00:00,2021-01-01,true,foo\n")
fd.write("2,2.3,2021-01-02 00:00:00,2021-01-02,false,bar\n")
fd.write("null,,null,,,null\n")

seed = SeedConfig(
name="test_model",
Expand All @@ -423,6 +426,23 @@ def test_seed_column_inference(tmp_path):
}


def test_agate_integer_cast():
    """Check coercion, null handling and rejection in ``Integer.cast``."""
    integer_type = Integer(null_values=("null", ""))

    # Integers come back as ``int`` whether given as text or natively.
    assert integer_type.cast("1") == 1
    assert integer_type.cast(1) == 1

    # Each configured null marker maps to ``None``.
    assert integer_type.cast("null") is None
    assert integer_type.cast("") is None

    # Non-integral text, floats and unrelated objects must all be rejected.
    for rejected in ("1.2", 1.2, datetime.now()):
        with pytest.raises(agate.exceptions.CastError):
            integer_type.cast(rejected)


@pytest.mark.xdist_group("dbt_manifest")
def test_model_dialect(sushi_test_project: Project, assert_exp_eq):
model_config = ModelConfig(
Expand Down

0 comments on commit ec498a5

Please sign in to comment.