Skip to content

Commit

Permalink
[Coral-Hive] [Coral-Trino] Make named_struct a Coral IR operator and …
Browse files Browse the repository at this point in the history
…Migrate GenericProject Function (#431)

* Initial commit for genericProject Migration

* remaining genericProject and some from namedStruct

* initial commit for timestamp op migrations

* rename SqlShuttle class

* enable test and rename var

* initial commit for namedstruct from PR#412

* rename transformer and add UT

* build fix

* fix for nested named_struct()

* add documentation

---------

Co-authored-by: Walaa Eldin Moustafa <wmoustafa@linkedin.com>
  • Loading branch information
aastha25 and wmoustafa authored Jul 10, 2023
1 parent e808370 commit 633474f
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 133 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel;

import java.util.ArrayList;
import java.util.List;

import com.google.common.base.Preconditions;
Expand All @@ -17,7 +16,6 @@
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.fun.SqlCastFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql2rel.ReflectiveConvertletTable;
import org.apache.calcite.sql2rel.SqlRexContext;
import org.apache.calcite.sql2rel.SqlRexConvertlet;
Expand All @@ -26,7 +24,6 @@
import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.FunctionFieldReferenceOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveInOperator;
import com.linkedin.coral.hive.hive2rel.functions.HiveNamedStructFunction;


/**
Expand All @@ -35,17 +32,6 @@
*/
public class HiveConvertletTable extends ReflectiveConvertletTable {

/**
* Converts a named_struct call into a ROW expression that is cast to the call's validated type.
*
* Operands are walked in pairs; only the second operand of each pair (index i + 1) is converted to a
* RexNode. The first operand of each pair is not converted here -- presumably it is the field name,
* which is already carried by the validated return type (TODO confirm).
*
* @param cx context used to convert operand expressions and to reach the validator and RexBuilder
* @param func the named_struct operator; not read in the body -- presumably only used to select this
*             convertlet by type (TODO confirm against ReflectiveConvertletTable)
* @param call the named_struct call whose operands are converted
* @return a CAST of ROW(converted operands) to the validated type of {@code call}
*/
// NOTE(review): no direct caller is visible here; presumably invoked reflectively, hence "unused".
@SuppressWarnings("unused")
public RexNode convertNamedStruct(SqlRexContext cx, HiveNamedStructFunction func, SqlCall call) {
// One converted expression per operand pair, so presize to half the operand count.
List<RexNode> operandExpressions = new ArrayList<>(call.operandCount() / 2);
// Step by 2 and read operand i + 1; assumes an even operand count -- an odd count would request
// an operand index past the end (presumably rejected before this point; TODO confirm).
for (int i = 0; i < call.operandCount(); i += 2) {
operandExpressions.add(cx.convertExpression(call.operand(i + 1)));
}
// The validator's type for the whole call is used both as the ROW call type and as the cast target.
RelDataType retType = cx.getValidator().getValidatedNodeType(call);
RexNode rowNode = cx.getRexBuilder().makeCall(retType, SqlStdOperatorTable.ROW, operandExpressions);
return cx.getRexBuilder().makeCast(retType, rowNode);
}

@SuppressWarnings("unused")
public RexNode convertHiveInOperator(SqlRexContext cx, HiveInOperator operator, SqlCall call) {
List<SqlNode> operandList = call.getOperandList();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -497,12 +497,11 @@ public void testStructPeekDisallowed() {
public void testStructReturnFieldAccess() {
final String sql = "select named_struct('field_a', 10, 'field_b', 'abc').field_b";
RelNode rel = toRel(sql);
final String expectedRel = "LogicalProject(EXPR$0=[CAST(ROW(10, 'abc')):"
+ "RecordType(INTEGER NOT NULL field_a, CHAR(3) NOT NULL field_b) NOT NULL.field_b])\n"
final String expectedRel = "LogicalProject(EXPR$0=[named_struct('field_a', 10, 'field_b', 'abc').field_b])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(relToStr(rel), expectedRel);
final String expectedSql = "SELECT CAST(ROW(10, 'abc') AS ROW(field_a INTEGER, field_b CHAR(3))).field_b\n"
+ "FROM (VALUES (0)) t (ZERO)";
final String expectedSql =
"SELECT named_struct('field_a', 10, 'field_b', 'abc').field_b\n" + "FROM (VALUES (0)) t (ZERO)";
assertEquals(relToHql(rel), expectedSql);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2018-2022 LinkedIn Corporation. All rights reserved.
* Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
Expand Down Expand Up @@ -44,8 +44,7 @@ public void testMixedTypes() {
final String sql = "SELECT named_struct('abc', 123, 'def', 'xyz')";
RelNode rel = toRel(sql);
final String generated = relToStr(rel);
final String expected = ""
+ "LogicalProject(EXPR$0=[CAST(ROW(123, 'xyz')):RecordType(INTEGER NOT NULL abc, CHAR(3) NOT NULL def) NOT NULL])\n"
final String expected = "" + "LogicalProject(EXPR$0=[named_struct('abc', 123, 'def', 'xyz')])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -54,9 +53,8 @@ public void testMixedTypes() {
public void testNullFieldValue() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', 150)";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, 150)):RecordType(INTEGER abc, INTEGER NOT NULL def) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected = "LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', 150)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}

Expand All @@ -65,7 +63,7 @@ public void testAllNullValues() {
final String sql = "SELECT named_struct('abc', cast(NULL as int), 'def', cast(NULL as double))";
final String generated = sqlToRelStr(sql);
final String expected =
"LogicalProject(EXPR$0=[CAST(ROW(CAST(null:NULL):INTEGER, CAST(null:NULL):DOUBLE)):RecordType(INTEGER abc, DOUBLE def) NOT NULL])\n"
"LogicalProject(EXPR$0=[named_struct('abc', CAST(null:NULL):INTEGER, 'def', CAST(null:NULL):DOUBLE)])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
assertEquals(generated, expected);
}
Expand All @@ -74,10 +72,9 @@ public void testAllNullValues() {
public void testNestedComplexTypes() {
final String sql = "SELECT named_struct('arr', array(10, 15), 's', named_struct('f1', 123, 'f2', array(20.5)))";
final String generated = sqlToRelStr(sql);
final String expected = "LogicalProject(EXPR$0=[CAST(ROW(ARRAY(10, 15), CAST(ROW(123, ARRAY(20.5:DECIMAL(3, 1)))):"
+ "RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL)):"
+ "RecordType(INTEGER NOT NULL ARRAY NOT NULL arr, RecordType(INTEGER NOT NULL f1, DECIMAL(3, 1) NOT NULL ARRAY NOT NULL f2) NOT NULL s) NOT NULL])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
final String expected =
"LogicalProject(EXPR$0=[named_struct('arr', ARRAY(10, 15), 's', named_struct('f1', 123, 'f2', ARRAY(20.5:DECIMAL(3, 1))))])\n"
+ " LogicalValues(tuples=[[{ 0 }]])\n";
// verified by human that expected string is correct and retained here to protect from future changes
assertEquals(generated, expected);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.collect.ImmutableList;
import com.linkedin.coral.common.functions.GenericProjectFunction;
import com.linkedin.coral.trino.rel2trino.functions.GenericProjectToTrinoConverter;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
import static org.apache.calcite.sql.type.ReturnTypes.explicit;
Expand Down Expand Up @@ -160,14 +158,6 @@ public TrinoRexConverter(RelNode node, Map<String, Boolean> configs) {

@Override
public RexNode visitCall(RexCall call) {
// GenericProject requires a nontrivial function rewrite because of the following:
// - makes use of Trino built-in UDFs transform_values for map objects and transform for array objects
// which has lambda functions as parameters
// - syntax is difficult for Calcite to parse
// - the return type varies based on a desired schema to be projected
if (call.getOperator() instanceof GenericProjectFunction) {
return GenericProjectToTrinoConverter.convertGenericProject(rexBuilder, call, node);
}

final String operatorName = call.getOperator().getName();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;


/**
Expand All @@ -31,7 +33,8 @@ public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) {
SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator();
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil));
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* If a column, colA, has a RelDataType, relDataTypeA, with a Trino type string, trinoTypeStringA = buildStructDataTypeString(relDataTypeA),
* then the following operation is syntactically and semantically correct in Trino: CAST(colA as trinoTypeStringA)
*/
class RelDataTypeToTrinoTypeStringConverter {
public class RelDataTypeToTrinoTypeStringConverter {
private RelDataTypeToTrinoTypeStringConverter() {
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* Instead, we represent the input to this UDF as a string, and its return type is passed as a parameter
* on creation.
*/
class TrinoMapTransformValuesFunction extends GenericTemplateFunction {
public class TrinoMapTransformValuesFunction extends GenericTemplateFunction {
public TrinoMapTransformValuesFunction(RelDataType transformValuesDataType) {
super(transformValuesDataType, "transform_values");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
* Instead, we represent the input to this UDF as a string, and its return type is passed as a parameter
* on creation.
*/
class TrinoStructCastRowFunction extends GenericTemplateFunction {
public class TrinoStructCastRowFunction extends GenericTemplateFunction {
public TrinoStructCastRowFunction(RelDataType structDataType) {
super(structDataType, "cast");
}
Expand Down
Loading

0 comments on commit 633474f

Please sign in to comment.