Skip to content

Commit

Permalink
Fix 5122 add PostgreSql lateral join operator for subquery (sqldeligh…
Browse files Browse the repository at this point in the history
…t#5337)

* Fix 5122 Add LATERAL to grammar

LATERAL is used in two positions in the join_operator - for table subqueries and for table joins

Add Mixin to expose table/columns in a subquery - avoid recursion stackoverflow if child is same as subquery

* Lateral Tests

Lateral Tests

Fixture test

Integration test

InterfaceGeneration test
  • Loading branch information
griffio authored Sep 12, 2024
1 parent cda4915 commit 0adf9a2
Show file tree
Hide file tree
Showing 6 changed files with 305 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CONFLICT"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CONSTRAINT"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CREATE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CROSS"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_DATE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_TIME"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.CURRENT_TIMESTAMP"
Expand Down Expand Up @@ -52,15 +53,19 @@
"static com.alecstrong.sql.psi.core.psi.SqlTypes.IS"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.IGNORE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INDEX"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INDEXED"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INNER"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INSERT"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.INTO"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.JOIN"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.KEY"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.LIKE"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.LIMIT"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.LP"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.MATCH"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.MINUS"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.MULTIPLY"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.NATURAL"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.NO"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.NOT"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.NOTHING"
Expand All @@ -69,6 +74,7 @@
"static com.alecstrong.sql.psi.core.psi.SqlTypes.ON"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.OR"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.ORDER"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.OUTER"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.PARTITION"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.PLUS"
"static com.alecstrong.sql.psi.core.psi.SqlTypes.PRIMARY"
Expand Down Expand Up @@ -106,9 +112,12 @@ overrides ::= type_name
| insert_stmt
| update_stmt_limited
| generated_clause
| join_operator
| join_clause
| result_column
| alter_table_add_column
| alter_table_rules
| table_or_subquery
| compound_select_stmt
| extension_expr
| extension_stmt
Expand Down Expand Up @@ -410,6 +419,19 @@ select_stmt ::= SELECT ( distinct_on_expr | [ DISTINCT | ALL ] ) {result_column}
pin = 1
}

lateral ::= 'LATERAL'
join_operator ::= ( COMMA [ lateral ]
| [ NATURAL ] [ ( {left_join_operator} | {right_join_operator} | {full_join_operator} ) [ OUTER ] | INNER | CROSS ] JOIN [ lateral ] ) {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlJoinOperatorImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlJoinOperator"
override = true
}

join_clause ::= {table_or_subquery} ( {join_operator} {table_or_subquery} [ {join_constraint} ] ) * {
mixin = "app.cash.sqldelight.dialects.postgresql.grammar.mixins.SqlJoinClauseMixin"
override = true
}

compound_select_stmt ::= [ {with_clause} ] {select_stmt} ( {compound_operator} {select_stmt} ) * [ ORDER BY {ordering_term} ( COMMA {ordering_term} ) * ] [ LIMIT {limiting_term} ] [ ( OFFSET | COMMA ) {limiting_term} ] [ FOR UPDATE [ 'SKIP' 'LOCKED' ] ] {
extends = "com.alecstrong.sql.psi.core.psi.impl.SqlCompoundSelectStmtImpl"
implements = "com.alecstrong.sql.psi.core.psi.SqlCompoundSelectStmt"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package app.cash.sqldelight.dialects.postgresql.grammar.mixins

import app.cash.sqldelight.dialects.postgresql.grammar.psi.PostgreSqlTypes
import com.alecstrong.sql.psi.core.psi.QueryElement
import com.alecstrong.sql.psi.core.psi.impl.SqlJoinClauseImpl
import com.intellij.lang.ASTNode
import com.intellij.psi.PsiElement
import com.intellij.psi.util.elementType

internal open class SqlJoinClauseMixin(node: ASTNode) : SqlJoinClauseImpl(node) {

override fun queryAvailable(child: PsiElement): Collection<QueryElement.QueryResult> {
return if (joinOperatorList
.flatMap { it.children.toList() }
.find { it.elementType == PostgreSqlTypes.LATERAL } != null
) {
tableOrSubqueryList.takeWhile { it != child }.flatMap { it.queryExposed() }
} else {
super.queryAvailable(child)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
CREATE TABLE A (
b_id INTEGER
);

CREATE TABLE B (
id INTEGER
);

SELECT * FROM A, LATERAL (SELECT * FROM B WHERE B.id = A.b_id) AB;

CREATE TABLE Author (
id INTEGER PRIMARY KEY,
name TEXT
);

CREATE TABLE Genre (
id INTEGER PRIMARY KEY,
name TEXT
);

CREATE TABLE Book (
id INTEGER PRIMARY KEY,
title TEXT,
author_id INTEGER REFERENCES Author(id),
genre_id INTEGER REFERENCES Genre(id)
);

SELECT
Author.name AS author_name,
Genre.name AS genre_name,
book_count
FROM
Author,
Genre,
LATERAL (
SELECT
COUNT(*) AS book_count
FROM
Book
WHERE
Book.author_id = Author.id
AND Book.genre_id = Genre.id
) AS book_counts;

CREATE TABLE Kickstarter_Data (
pledged INTEGER,
fx_rate NUMERIC,
backers_count INTEGER,
launched_at NUMERIC,
deadline NUMERIC,
goal INTEGER
);

SELECT
pledged_usd,
avg_pledge_usd,
duration,
(usd_from_goal / duration) AS usd_needed_daily
FROM Kickstarter_Data,
LATERAL (SELECT pledged / fx_rate AS pledged_usd) pu,
LATERAL (SELECT pledged_usd / backers_count AS avg_pledge_usd) apu,
LATERAL (SELECT goal / fx_rate AS goal_usd) gu,
LATERAL (SELECT goal_usd - pledged_usd AS usd_from_goal) ufg,
LATERAL (SELECT (deadline - launched_at) / 86400.00 AS duration) dr;

CREATE TABLE Regions (
id INTEGER,
name VARCHAR(255)
);

CREATE TABLE SalesPeople (
id INTEGER,
full_name VARCHAR(255),
home_region_id INTEGER
);

CREATE TABLE Sales (
id INTEGER,
amount NUMERIC,
product_id INTEGER,
salesperson_id INTEGER,
region_id INTEGER
);

SELECT
sp.id salesperson_id,
sp.full_name,
sp.home_region_id,
rg.name AS home_region_name,
home_region_sales.total_sales
FROM SalesPeople sp
JOIN Regions rg ON sp.home_region_id = rg.id
JOIN LATERAL (
SELECT SUM(amount) AS total_sales
FROM Sales s
WHERE s.salesperson_id = sp.id
AND s.region_id = sp.home_region_id
) home_region_sales ON TRUE;
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,108 @@ class InterfaceGeneration {
)
}

@Test
fun `postgres lateral sub select has correct result columns`() {
val result = FixtureCompiler.compileSql(
"""
|CREATE TABLE Test (
| p INTEGER,
| f NUMERIC,
| b INTEGER,
| l NUMERIC,
| d NUMERIC,
| g INTEGER
|);
|
|select:
|SELECT *
|FROM Test,
|LATERAL (SELECT p / f AS pf) _pf,
|LATERAL (SELECT pf / b AS pfb) _pfb,
|LATERAL (SELECT g / f AS gf) _gf,
|LATERAL (SELECT gf - pf AS gfpf) _gfpf,
|LATERAL (SELECT (d - l) / 60000.00 AS dl) _dl;
""".trimMargin(),
temporaryFolder,
fileName = "Lateral.sq",
overrideDialect = PostgreSqlDialect(),
)
assertThat(result.errors).isEmpty()
val generatedInterface = result.compilerOutput.get(File(result.outputDirectory, "com/example/LateralQueries.kt"))
assertThat(generatedInterface).isNotNull()
assertThat(generatedInterface.toString()).isEqualTo(
"""
|package com.example
|
|import app.cash.sqldelight.Query
|import app.cash.sqldelight.TransacterImpl
|import app.cash.sqldelight.db.SqlDriver
|import app.cash.sqldelight.driver.jdbc.JdbcCursor
|import java.math.BigDecimal
|import kotlin.Any
|import kotlin.Int
|
|public class LateralQueries(
| driver: SqlDriver,
|) : TransacterImpl(driver) {
| public fun <T : Any> select(mapper: (
| p: Int?,
| f: BigDecimal?,
| b: Int?,
| l: BigDecimal?,
| d: BigDecimal?,
| g: Int?,
| pf: BigDecimal?,
| pfb: BigDecimal?,
| gf: BigDecimal?,
| gfpf: BigDecimal?,
| dl: BigDecimal?,
| ) -> T): Query<T> = Query(89_549_764, arrayOf("Test"), driver, "Lateral.sq", "select", ""${'"'}
| |SELECT Test.p, Test.f, Test.b, Test.l, Test.d, Test.g, pf, pfb, gf, gfpf, dl
| |FROM Test,
| |LATERAL (SELECT p / f AS pf) _pf,
| |LATERAL (SELECT pf / b AS pfb) _pfb,
| |LATERAL (SELECT g / f AS gf) _gf,
| |LATERAL (SELECT gf - pf AS gfpf) _gfpf,
| |LATERAL (SELECT (d - l) / 60000.00 AS dl) _dl
| ""${'"'}.trimMargin()) { cursor ->
| check(cursor is JdbcCursor)
| mapper(
| cursor.getInt(0),
| cursor.getBigDecimal(1),
| cursor.getInt(2),
| cursor.getBigDecimal(3),
| cursor.getBigDecimal(4),
| cursor.getInt(5),
| cursor.getBigDecimal(6),
| cursor.getBigDecimal(7),
| cursor.getBigDecimal(8),
| cursor.getBigDecimal(9),
| cursor.getBigDecimal(10)
| )
| }
|
| public fun select(): Query<Select> = select { p, f, b, l, d, g, pf, pfb, gf, gfpf, dl ->
| Select(
| p,
| f,
| b,
| l,
| d,
| g,
| pf,
| pfb,
| gf,
| gfpf,
| dl
| )
| }
|}
|
""".trimMargin(),
)
}

private fun checkFixtureCompiles(fixtureRoot: String) {
val result = FixtureCompiler.compileFixture(
fixtureRoot = "src/test/query-interface-fixtures/$fixtureRoot",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
CREATE TABLE Regions (
id INTEGER,
name VARCHAR(255)
);

CREATE TABLE Sales_People (
id INTEGER,
full_name VARCHAR(255),
home_region_id INTEGER
);

CREATE TABLE Sales (
id INTEGER,
amount NUMERIC,
product_id INTEGER,
sales_person_id INTEGER,
region_id INTEGER
);

insertSales {
INSERT INTO Regions (id, name) VALUES (1, 'North America');
INSERT INTO Regions (id, name) VALUES (2, 'Europe');
INSERT INTO Regions (id, name) VALUES (3, 'Asia');

INSERT INTO Sales_People (id, full_name, home_region_id) VALUES (1, 'A D', 1);
INSERT INTO Sales_People (id, full_name, home_region_id) VALUES (2, 'L S', 2);
INSERT INTO Sales_People (id, full_name, home_region_id) VALUES (3, 'M J', 3);

INSERT INTO Sales (id, amount, product_id, sales_person_id, region_id) VALUES (1, 1000.50, 101, 1, 1);
INSERT INTO Sales (id, amount, product_id, sales_person_id, region_id) VALUES (2, 2500.75, 102, 2, 2);
INSERT INTO Sales (id, amount, product_id, sales_person_id, region_id) VALUES (3, 1250.25, 103, 3, 3);
}

selectSales:
SELECT
sp.id salesperson_id,
sp.full_name,
sp.home_region_id,
rg.name AS home_region_name,
home_region_sales.total_sales
FROM Sales_People AS sp
JOIN Regions rg ON sp.home_region_id = rg.id
JOIN LATERAL (
SELECT SUM(amount) AS total_sales
FROM Sales AS s
WHERE s.sales_person_id = sp.id
AND s.region_id = sp.home_region_id
) home_region_sales ON TRUE;
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import app.cash.sqldelight.Query
import app.cash.sqldelight.db.OptimisticLockException
import app.cash.sqldelight.driver.jdbc.JdbcDriver
import com.google.common.truth.Truth.assertThat
import java.math.BigDecimal
import java.sql.Connection
import java.sql.DriverManager
import java.time.Instant
Expand Down Expand Up @@ -1092,4 +1093,16 @@ class PostgreSqlTest {
assertThat(grade_date).isEqualTo(gradeExpected.grade_date)
}
}

@Test
fun testLateralJoin() {
database.lateralQueries.insertSales()
with(database.lateralQueries.selectSales().executeAsList()) {
assertThat(first().salesperson_id).isEqualTo(1)
assertThat(first().full_name).isEqualTo("A D")
assertThat(first().home_region_id).isEqualTo(1)
assertThat(first().home_region_name).isEqualTo("North America")
assertThat(first().total_sales).isEqualTo(BigDecimal("1000.50"))
}
}
}

0 comments on commit 0adf9a2

Please sign in to comment.