From 0adf9a2260a1c1b817cabdea718ec4433e06bb17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?griff=20=D1=96=E2=8A=99?= <346896+griffio@users.noreply.github.com> Date: Thu, 12 Sep 2024 19:37:16 +0100 Subject: [PATCH] Fix 5122 add PostgreSql lateral join operator for subquery (#5337) * Fix 5122 Add LATERAL to grammar LATERAL is used in two positions in the join_operator - for table subqueries and for table joins Add Mixin to expose table/columns in a subquery - avoid recursion stackoverflow if child is same as subquery * Lateral Tests Lateral Tests Fixture test Integration test InterfaceGeneration test --- .../postgresql/grammar/PostgreSql.bnf | 22 ++++ .../grammar/mixins/SqlJoinClauseMixin.kt | 22 ++++ .../fixtures_postgresql/lateral/Test.s | 98 +++++++++++++++++ .../core/queries/InterfaceGeneration.kt | 102 ++++++++++++++++++ .../postgresql/integration/Lateral.sq | 48 +++++++++ .../postgresql/integration/PostgreSqlTest.kt | 13 +++ 6 files changed, 305 insertions(+) create mode 100644 dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/mixins/SqlJoinClauseMixin.kt create mode 100644 dialects/postgresql/src/testFixtures/resources/fixtures_postgresql/lateral/Test.s create mode 100644 sqldelight-gradle-plugin/src/test/integration-postgresql/src/main/sqldelight/app/cash/sqldelight/postgresql/integration/Lateral.sq diff --git a/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/PostgreSql.bnf b/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/PostgreSql.bnf index 7f0e9fe3d8d..73f9b22c6e8 100644 --- a/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/PostgreSql.bnf +++ b/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/PostgreSql.bnf @@ -25,6 +25,7 @@ "static com.alecstrong.sql.psi.core.psi.SqlTypes.CONFLICT" "static com.alecstrong.sql.psi.core.psi.SqlTypes.CONSTRAINT" "static com.alecstrong.sql.psi.core.psi.SqlTypes.CREATE" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.CROSS" "static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_DATE" "static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_TIME" "static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_TIMESTAMP" @@ -52,8 +53,11 @@ "static com.alecstrong.sql.psi.core.psi.SqlTypes.IS" "static com.alecstrong.sql.psi.core.psi.SqlTypes.IGNORE" "static com.alecstrong.sql.psi.core.psi.SqlTypes.INDEX" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.INDEXED" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.INNER" "static com.alecstrong.sql.psi.core.psi.SqlTypes.INSERT" "static com.alecstrong.sql.psi.core.psi.SqlTypes.INTO" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.JOIN" "static com.alecstrong.sql.psi.core.psi.SqlTypes.KEY" "static com.alecstrong.sql.psi.core.psi.SqlTypes.LIKE" "static com.alecstrong.sql.psi.core.psi.SqlTypes.LIMIT" @@ -61,6 +65,7 @@ "static com.alecstrong.sql.psi.core.psi.SqlTypes.MATCH" "static com.alecstrong.sql.psi.core.psi.SqlTypes.MINUS" "static com.alecstrong.sql.psi.core.psi.SqlTypes.MULTIPLY" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.NATURAL" "static com.alecstrong.sql.psi.core.psi.SqlTypes.NO" "static com.alecstrong.sql.psi.core.psi.SqlTypes.NOT" "static com.alecstrong.sql.psi.core.psi.SqlTypes.NOTHING" @@ -69,6 +74,7 @@ "static com.alecstrong.sql.psi.core.psi.SqlTypes.ON" "static com.alecstrong.sql.psi.core.psi.SqlTypes.OR" "static com.alecstrong.sql.psi.core.psi.SqlTypes.ORDER" + "static com.alecstrong.sql.psi.core.psi.SqlTypes.OUTER" "static com.alecstrong.sql.psi.core.psi.SqlTypes.PARTITION" "static com.alecstrong.sql.psi.core.psi.SqlTypes.PLUS" "static com.alecstrong.sql.psi.core.psi.SqlTypes.PRIMARY" @@ -106,9 +112,12 @@ overrides ::= type_name | insert_stmt | update_stmt_limited | generated_clause + | join_operator + | join_clause | result_column | alter_table_add_column | alter_table_rules + | table_or_subquery | compound_select_stmt | extension_expr | extension_stmt @@ -410,6 +419,19 @@ select_stmt ::= SELECT ( distinct_on_expr | [ DISTINCT | ALL ] ) {result_column} pin = 1 } +lateral ::= 'LATERAL' +join_operator ::= ( COMMA [ lateral ] + | [ NATURAL ] [ ( {left_join_operator} | {right_join_operator} | {full_join_operator} ) [ OUTER ] | INNER | CROSS ] JOIN [ lateral ] ) { + extends = "com.alecstrong.sql.psi.core.psi.impl.SqlJoinOperatorImpl" + implements = "com.alecstrong.sql.psi.core.psi.SqlJoinOperator" + override = true +} + +join_clause ::= {table_or_subquery} ( {join_operator} {table_or_subquery} [ {join_constraint} ] ) * { + mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.SqlJoinClauseMixin" + override = true +} + compound_select_stmt ::= [ {with_clause} ] {select_stmt} ( {compound_operator} {select_stmt} ) * [ ORDER BY {ordering_term} ( COMMA {ordering_term} ) * ] [ LIMIT {limiting_term} ] [ ( OFFSET | COMMA ) {limiting_term} ] [ FOR UPDATE [ 'SKIP' 'LOCKED' ] ] { extends = "com.alecstrong.sql.psi.core.psi.impl.SqlCompoundSelectStmtImpl" implements = "com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt" diff --git a/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/mixins/SqlJoinClauseMixin.kt b/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/mixins/SqlJoinClauseMixin.kt new file mode 100644 index 00000000000..ef942ec5c86 --- /dev/null +++ b/dialects/postgresql/src/main/kotlin/app/cash/sqldelight/dialects/postgresql/grammar/mixins/SqlJoinClauseMixin.kt @@ -0,0 +1,22 @@ +package app.cash.sqldelight.dialects.postgresql.grammar.mixins + +import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypes +import com.alecstrong.sql.psi.core.psi.QueryElement +import com.alecstrong.sql.psi.core.psi.impl.SqlJoinClauseImpl +import com.intellij.lang.ASTNode +import com.intellij.psi.PsiElement +import com.intellij.psi.util.elementType + +internal open class SqlJoinClauseMixin(node: ASTNode) : SqlJoinClauseImpl(node) { + + override fun queryAvailable(child: PsiElement): Collection { + return if (joinOperatorList + .flatMap { it.children.toList() } + .find { it.elementType == PostgreSqlTypes.LATERAL } != null + ) { + tableOrSubqueryList.takeWhile { it != child }.flatMap { it.queryExposed() } + } else { + super.queryAvailable(child) + } + } +} diff --git a/dialects/postgresql/src/testFixtures/resources/fixtures_postgresql/lateral/Test.s b/dialects/postgresql/src/testFixtures/resources/fixtures_postgresql/lateral/Test.s new file mode 100644 index 00000000000..6d1dfe53ff4 --- /dev/null +++ b/dialects/postgresql/src/testFixtures/resources/fixtures_postgresql/lateral/Test.s @@ -0,0 +1,98 @@ +CREATE TABLE A ( + b_id INTEGER +); + +CREATE TABLE B ( + id INTEGER +); + +SELECT * FROM A, LATERAL (SELECT * FROM B WHERE B.id = A.b_id) AB; + +CREATE TABLE Author ( + id INTEGER PRIMARY KEY, + name TEXT +); + +CREATE TABLE Genre ( + id INTEGER PRIMARY KEY, + name TEXT +); + +CREATE TABLE Book ( + id INTEGER PRIMARY KEY, + title TEXT, + author_id INTEGER REFERENCES Author(id), + genre_id INTEGER REFERENCES Genre(id) +); + +SELECT + Author.name AS author_name, + Genre.name AS genre_name, + book_count +FROM + Author, + Genre, + LATERAL ( + SELECT + COUNT(*) AS book_count + FROM + Book + WHERE + Book.author_id = Author.id + AND Book.genre_id = Genre.id + ) AS book_counts; + +CREATE TABLE Kickstarter_Data ( + pledged INTEGER, + fx_rate NUMERIC, + backers_count INTEGER, + launched_at NUMERIC, + deadline NUMERIC, + goal INTEGER +); + +SELECT + pledged_usd, + avg_pledge_usd, + duration, + (usd_from_goal / duration) AS usd_needed_daily +FROM Kickstarter_Data, + LATERAL (SELECT pledged / fx_rate AS pledged_usd) pu, + LATERAL (SELECT pledged_usd / backers_count AS avg_pledge_usd) apu, + LATERAL (SELECT goal / fx_rate AS goal_usd) gu, + LATERAL (SELECT goal_usd - pledged_usd AS usd_from_goal) ufg, + LATERAL (SELECT (deadline - launched_at) / 86400.00 AS duration) dr; + +CREATE TABLE Regions ( + id INTEGER, + name VARCHAR(255) +); + +CREATE TABLE SalesPeople ( + id INTEGER, + full_name VARCHAR(255), + home_region_id INTEGER +); + +CREATE TABLE Sales ( + id INTEGER, + amount NUMERIC, + product_id INTEGER, + salesperson_id INTEGER, + region_id INTEGER +); + +SELECT + sp.id salesperson_id, + sp.full_name, + sp.home_region_id, + rg.name AS home_region_name, + home_region_sales.total_sales +FROM SalesPeople sp + JOIN Regions rg ON sp.home_region_id = rg.id + JOIN LATERAL ( + SELECT SUM(amount) AS total_sales + FROM Sales s + WHERE s.salesperson_id = sp.id + AND s.region_id = sp.home_region_id + ) home_region_sales ON TRUE; diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt index ab7a309ed10..e2be5ec40bd 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt @@ -1394,6 +1394,108 @@ class InterfaceGeneration { ) } + @Test + fun `postgres lateral sub select has correct result columns`() { + val result = FixtureCompiler.compileSql( + """ + |CREATE TABLE Test ( + | p INTEGER, + | f NUMERIC, + | b INTEGER, + | l NUMERIC, + | d NUMERIC, + | g INTEGER + |); + | + |select: + |SELECT * + |FROM Test, + |LATERAL (SELECT p / f AS pf) _pf, + |LATERAL (SELECT pf / b AS pfb) _pfb, + |LATERAL (SELECT g / f AS gf) _gf, + |LATERAL (SELECT gf - pf AS gfpf) _gfpf, + |LATERAL (SELECT (d - l) / 60000.00 AS dl) _dl; + """.trimMargin(), + temporaryFolder, + fileName = "Lateral.sq", + overrideDialect = PostgreSqlDialect(), + ) + assertThat(result.errors).isEmpty() + val generatedInterface = result.compilerOutput.get(File(result.outputDirectory, "com/example/LateralQueries.kt")) + assertThat(generatedInterface).isNotNull() + assertThat(generatedInterface.toString()).isEqualTo( + """ + |package com.example + | + |import app.cash.sqldelight.Query + |import app.cash.sqldelight.TransacterImpl + |import app.cash.sqldelight.db.SqlDriver + |import app.cash.sqldelight.driver.jdbc.JdbcCursor + |import java.math.BigDecimal + |import kotlin.Any + |import kotlin.Int + | + |public class LateralQueries( + | driver: SqlDriver, + |) : TransacterImpl(driver) { + | public fun select(mapper: ( + | p: Int?, + | f: BigDecimal?, + | b: Int?, + | l: BigDecimal?, + | d: BigDecimal?, + | g: Int?, + | pf: BigDecimal?, + | pfb: BigDecimal?, + | gf: BigDecimal?, + | gfpf: BigDecimal?, + | dl: BigDecimal?, + | ) -> T): Query = Query(89_549_764, arrayOf("Test"), driver, "Lateral.sq", "select", ""${'"'} + | |SELECT Test.p, Test.f, Test.b, Test.l, Test.d, Test.g, pf, pfb, gf, gfpf, dl + | |FROM Test, + | |LATERAL (SELECT p / f AS pf) _pf, + | |LATERAL (SELECT pf / b AS pfb) _pfb, + | |LATERAL (SELECT g / f AS gf) _gf, + | |LATERAL (SELECT gf - pf AS gfpf) _gfpf, + | |LATERAL (SELECT (d - l) / 60000.00 AS dl) _dl + | ""${'"'}.trimMargin()) { cursor -> + | check(cursor is JdbcCursor) + | mapper( + | cursor.getInt(0), + | cursor.getBigDecimal(1), + | cursor.getInt(2), + | cursor.getBigDecimal(3), + | cursor.getBigDecimal(4), + | cursor.getInt(5), + | cursor.getBigDecimal(6), + | cursor.getBigDecimal(7), + | cursor.getBigDecimal(8), + | cursor.getBigDecimal(9), + | cursor.getBigDecimal(10) + | ) + | } + | + | public fun select(): Query