Skip to content

Commit

Permalink
chore(embedded/sql): add support for LEFT JOIN
Browse files Browse the repository at this point in the history
Signed-off-by: Stefano Scafiti <stefano.scafiti96@gmail.com>
  • Loading branch information
ostafen committed Dec 9, 2024
1 parent 4cf4af7 commit 344275d
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 14 deletions.
138 changes: 138 additions & 0 deletions embedded/sql/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5986,6 +5986,123 @@ func TestNestedJoins(t *testing.T) {
require.NoError(t, err)
}

func TestLeftJoins(t *testing.T) {
e := setupCommonTest(t)

_, _, err := e.Exec(
context.Background(),
nil,
`
CREATE TABLE customers (
customer_id INTEGER,
customer_name VARCHAR(50),
email VARCHAR(100),
PRIMARY KEY customer_id
);
CREATE TABLE products (
product_id INTEGER,
product_name VARCHAR(50),
price FLOAT,
PRIMARY KEY product_id
);
CREATE TABLE orders (
order_id INTEGER,
customer_id INTEGER,
order_date TIMESTAMP,
PRIMARY KEY order_id
);
CREATE TABLE order_items (
order_item_id INTEGER,
order_id INTEGER,
product_id INTEGER,
quantity INTEGER,
PRIMARY KEY order_item_id
);
INSERT INTO customers (customer_id, customer_name, email)
VALUES
(1, 'Alice Johnson', 'alice@example.com'),
(2, 'Bob Smith', 'bob@example.com'),
(3, 'Charlie Brown', 'charlie@example.com');
INSERT INTO products (product_id, product_name, price)
VALUES
(1, 'Laptop', 1200.00),
(2, 'Smartphone', 800.00),
(3, 'Tablet', 400.00);
INSERT INTO orders (order_id, customer_id, order_date)
VALUES
(101, 1, '2024-11-01'::TIMESTAMP),
(102, 2, '2024-11-02'::TIMESTAMP),
(103, 1, '2024-11-03'::TIMESTAMP);
INSERT INTO order_items (order_item_id, order_id, product_id, quantity)
VALUES
(1, 101, 1, 2),
(2, 101, 2, 1),
(3, 102, 3, 3),
(4, 103, 2, 2);
`,
nil,
)
require.NoError(t, err)

assertQueryShouldProduceResults(
t,
e,
`SELECT c.customer_id, c.customer_name, c.email, o.order_id, o.order_date
FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id
ORDER BY c.customer_id, o.order_date;`,
`
SELECT *
FROM (
VALUES
(1, 'Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP),
(1, 'Alice Johnson', 'alice@example.com', 103, '2024-11-03'::TIMESTAMP),
(2, 'Bob Smith', 'bob@example.com', 102, '2024-11-02'::TIMESTAMP),
(3, 'Charlie Brown', 'charlie@example.com', NULL, NULL)
)`,
)

assertQueryShouldProduceResults(
t,
e,
`
SELECT
c.customer_name,
c.email,
o.order_id,
o.order_date,
p.product_name,
oi.quantity,
p.price,
(oi.quantity * p.price) AS total_price
FROM
products p
LEFT JOIN order_Items oi ON p.product_id = oi.product_id
LEFT JOIN orders o ON oi.order_id = o.order_id
LEFT JOIN customers c ON o.customer_id = c.customer_id
ORDER BY o.order_date, c.customer_name;`,
`
SELECT *
FROM (
VALUES
('Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP, 'Laptop', 2, 1200.00, 2400.00),
('Alice Johnson', 'alice@example.com', 101, '2024-11-01'::TIMESTAMP, 'Smartphone', 1, 800.00, 800.00),
('Bob Smith', 'bob@example.com', 102, '2024-11-02'::TIMESTAMP, 'Tablet', 3, 400.00, 1200.00),
('Alice Johnson', 'alice@example.com', 103, '2024-11-03'::TIMESTAMP, 'Smartphone', 2, 800.00, 1600.00)
)`,
)
}

func TestReOpening(t *testing.T) {
st, err := store.Open(t.TempDir(), store.DefaultOptions().WithMultiIndexing(true))
require.NoError(t, err)
Expand Down Expand Up @@ -9434,3 +9551,24 @@ func TestFunctions(t *testing.T) {
require.Equal(t, "OBJECT", rows[0].ValuesByPosition[0].RawValue().(string))
})
}

func assertQueryShouldProduceResults(t *testing.T, e *Engine, query, resultQuery string) {
queryReader, err := e.Query(context.Background(), nil, query, nil)
require.NoError(t, err)
defer queryReader.Close()

resultReader, err := e.Query(context.Background(), nil, resultQuery, nil)
require.NoError(t, err)
defer resultReader.Close()

for {
actualRow, actualErr := queryReader.Read(context.Background())
expectedRow, expectedErr := resultReader.Read(context.Background())
require.Equal(t, expectedErr, actualErr)

if errors.Is(actualErr, ErrNoMoreRows) {
break
}
require.Equal(t, expectedRow.ValuesByPosition, actualRow.ValuesByPosition)
}
}
43 changes: 31 additions & 12 deletions embedded/sql/joint_row_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ func newJointRowReader(rowReader RowReader, joins []*JoinSpec) (*jointRowReader,
}

for _, jspec := range joins {
if jspec.joinType != InnerJoin {
switch jspec.joinType {
case InnerJoin, LeftJoin:
default:
return nil, ErrUnsupportedJoinType
}
}
Expand Down Expand Up @@ -113,7 +115,6 @@ func (jointr *jointRowReader) colsBySelector(ctx context.Context) (map[string]Co
colDescriptors[sel] = des
}
}

return colDescriptors, nil
}

Expand Down Expand Up @@ -240,17 +241,35 @@ func (jointr *jointRowReader) Read(ctx context.Context) (row *Row, err error) {

r, err := reader.Read(ctx)
if err == ErrNoMoreRows {
// previous reader will need to read next row
unsolvedFK = true

err = reader.Close()
if err != nil {
return nil, err
if jspec.joinType == InnerJoin {
// previous reader will need to read next row
unsolvedFK = true

err = reader.Close()
if err != nil {
return nil, err
}

break
} else { // LEFT JOIN: fill column values with NULLs
cols, err := reader.Columns(ctx)
if err != nil {
return nil, err
}

r = &Row{
ValuesByPosition: make([]TypedValue, len(cols)),
ValuesBySelector: make(map[string]TypedValue, len(cols)),
}

for i, col := range cols {
nullValue := NewNull(col.Type)

r.ValuesByPosition[i] = nullValue
r.ValuesBySelector[col.Selector()] = nullValue
}
}

break
}
if err != nil {
} else if err != nil {
reader.Close()
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion embedded/sql/joint_row_reader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ func TestJointRowReader(t *testing.T) {
r, err := newRawRowReader(tx, nil, table, period{}, "", &ScanSpecs{Index: table.primaryIndex})
require.NoError(t, err)

_, err = newJointRowReader(r, []*JoinSpec{{joinType: LeftJoin}})
_, err = newJointRowReader(r, []*JoinSpec{{joinType: RightJoin}})
require.ErrorIs(t, err, ErrUnsupportedJoinType)

_, err = newJointRowReader(r, []*JoinSpec{{joinType: LeftJoin}})
require.NoError(t, err)

_, err = newJointRowReader(r, []*JoinSpec{{joinType: InnerJoin, ds: &SelectStmt{}}})
require.NoError(t, err)

Expand Down
1 change: 0 additions & 1 deletion embedded/sql/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -3409,7 +3409,6 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin
rowReader = newLimitRowReader(rowReader, limit)
}
}

return rowReader, nil
}

Expand Down

0 comments on commit 344275d

Please sign in to comment.