Skip to content

Commit

Permalink
Fix: Quote columns when getting column level lineage (#2907)
Browse files Browse the repository at this point in the history
  • Loading branch information
vchan authored Jul 16, 2024
1 parent 1fa1a4c commit d92bf84
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
38 changes: 38 additions & 0 deletions tests/web/test_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,41 @@ def test_get_lineage_constants(project_context: Context) -> None:
response_json = response.json()
assert response_json['"foo"']["col"]["models"] == {'"bar"': ["col"]}
assert response_json['"bar"']["col"]["models"] == {'"external_table"': ["col"]}


def test_get_lineage_quoted_columns(project_context: Context) -> None:
    """Check the lineage endpoint for a column whose name needs quoting (``@col``).

    Writes two models, ``foo`` (selecting ``"@col"`` through a CTE built from a
    three-way UNION) and ``bar``, loads the project, then queries the lineage
    API in both the default mode and the ``models_only`` mode.
    """
    models_path = project_context.path / "models"
    models_path.mkdir()

    (models_path / "foo.sql").write_text(
        """MODEL (name foo);
WITH my_cte AS (
SELECT col as "@col" FROM bar
UNION
SELECT NULL::TIMESTAMP as "@col" FROM bar
UNION
SELECT 1 as "@col" FROM external_table
)
SELECT "@col" FROM my_cte;"""
    )
    (models_path / "bar.sql").write_text(
        """MODEL (name bar);
SELECT col FROM external_table;"""
    )
    project_context.load()

    # Default mode: the CTE shows up as its own node between foo and bar.
    full_response = client.get("/api/lineage/foo/@col")
    assert full_response.status_code == 200, full_response.json()
    full_lineage = full_response.json()
    assert full_lineage['"foo"']["@col"]["models"] == {'"foo": my_cte': ["@col"]}
    assert full_lineage['"foo": my_cte']["@col"]["models"] == {'"bar"': ["col"]}
    assert full_lineage['"bar"']["col"]["models"] == {'"external_table"': ["col"]}

    # Models only: foo is expected to link straight to bar, with no CTE entry checked.
    models_only_response = client.get("/api/lineage/foo/@col?models_only=1")
    assert models_only_response.status_code == 200, models_only_response.json()
    models_only_lineage = models_only_response.json()
    assert models_only_lineage['"foo"']["@col"]["models"] == {'"bar"': ["col"]}
    assert models_only_lineage['"bar"']["col"]["models"] == {'"external_table"': ["col"]}
11 changes: 9 additions & 2 deletions web/server/api/endpoints/lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
router = APIRouter()


def quote_column(column: str, dialect: str) -> str:
    """Render ``column`` as a quoted identifier for the given SQL dialect.

    Args:
        column: Column name to turn into an identifier.
        dialect: Dialect name forwarded to the SQL generator.

    Returns:
        The SQL text of the identifier, built with ``quoted=True``.
    """
    identifier = exp.to_identifier(column, quoted=True)
    return identifier.sql(dialect=dialect)


def get_source_name(
node: Node, default_catalog: t.Optional[str], dialect: str, model_name: str
) -> str:
Expand Down Expand Up @@ -57,7 +61,8 @@ def create_lineage_adjacency_list(
models={},
)
continue
root = lineage(column, model)

root = lineage(quote_column(column, model.dialect), model)

for node in root.walk():
if root.name == "UNION" and node is root:
Expand Down Expand Up @@ -109,7 +114,9 @@ def create_models_only_lineage_adjacency_list(
model = context.get_model(model_name)
dependencies = defaultdict(set)
if model:
for table, column_names in column_dependencies(context, model_name, column).items():
for table, column_names in column_dependencies(
context, model_name, quote_column(column, model.dialect)
).items():
for column_name in column_names:
dependencies[table].add(column_name)
nodes.append((table, column_name))
Expand Down

0 comments on commit d92bf84

Please sign in to comment.