Skip to content

Commit

Permalink
[Coral-Trino] Migrate Trino Concat Operator transformer from RexShuttle to SqlShuttle (#378)
Browse files Browse the repository at this point in the history

* initial commit for timestamp op migrations

* rename SqlShuttle class

* enable test and rename var

* Enable type derivation in Coral IR to Trino SQL translation path

* Extract enhanced type derivation logic to a separate class & refactor

* rebase fix

* refactor sqlSelect

* rebase with PR#426

* converge with DataTypeDerivedSqlCallConverter

* rebase

* fix for derive relDataType for dynamic udfs

* add comments
  • Loading branch information
aastha25 authored Jul 13, 2023
1 parent 633474f commit 69b26c5
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ public abstract class SqlCallTransformer {
private TypeDerivationUtil typeDerivationUtil;

/**
 * Creates a transformer without a {@code TypeDerivationUtil}; the
 * {@code typeDerivationUtil} field remains null, so subclasses using this
 * constructor must not rely on type derivation.
 */
public SqlCallTransformer() {

}

public SqlCallTransformer(TypeDerivationUtil typeDerivationUtil) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public HiveToRelConverter(Map<String, Map<String, List<String>>> localMetaStore)
this.parseTreeBuilder = new ParseTreeBuilder(functionResolver);
}

/**
 * Exposes this converter's {@code HiveFunctionResolver} so callers can register
 * additional functions (e.g. dynamically resolved versioned UDFs) before validation.
 *
 * @return the resolver backing this converter's parse-tree building
 */
public HiveFunctionResolver getFunctionResolver() {
return functionResolver;
}

@Override
protected SqlRexConvertletTable getConvertletTable() {
return new HiveConvertletTable();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,12 @@ public Collection<Function> tryResolveAsDaliFunction(String functionName, @Nonnu
.collect(Collectors.toList());
}

/**
 * Registers a dynamically resolved function in the registry keyed by its class name,
 * so subsequent lookups can reuse it without re-resolving. No-op when an entry for
 * {@code funcClassName} is already present.
 *
 * @param funcClassName fully qualified UDF class name used as the registry key
 * @param function      resolved function to register under that key
 */
public void addDynamicFunctionToTheRegistry(String funcClassName, Function function) {
// NOTE(review): assumes dynamicFunctionRegistry.contains(...) performs a key-containment
// check (the registry's type is not visible here) -- confirm against its declaration.
if (!dynamicFunctionRegistry.contains(funcClassName)) {
dynamicFunctionRegistry.put(funcClassName, function);
}
}

private @Nonnull Collection<Function> resolveDaliFunctionDynamically(String functionName, String funcClassName,
HiveTable hiveTable, int numOfOperands) {
if (dynamicFunctionRegistry.contains(funcClassName)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package com.linkedin.coral.trino.rel2trino;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -175,34 +174,9 @@ public RexNode visitCall(RexCall call) {
}
}

if (operatorName.equalsIgnoreCase("concat")) {
Optional<RexNode> modifiedCall = visitConcat(call);
if (modifiedCall.isPresent()) {
return modifiedCall.get();
}
}

return super.visitCall(call);
}

/**
 * Rewrites a CONCAT call so every operand is VARCHAR/CHAR, casting any other
 * operand type to VARCHAR for Trino compatibility.
 *
 * @param call the CONCAT RexCall to rewrite
 * @return a rewritten call with casts inserted where needed (always present)
 */
private Optional<RexNode> visitConcat(RexCall call) {
// Hive supports operations like CONCAT(date, varchar) while Trino only supports CONCAT(varchar, varchar)
// So we need to cast the unsupported types to varchar
final SqlOperator op = call.getOperator();
// Recursively transform the operands first, then inspect their derived types.
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
List<RexNode> castOperands = new ArrayList<>();

for (RexNode inputOperand : convertedOperands) {
// Only non-VARCHAR/CHAR operands need an explicit cast; others pass through unchanged.
if (inputOperand.getType().getSqlTypeName() != VARCHAR && inputOperand.getType().getSqlTypeName() != CHAR) {
final RexNode castOperand = rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), inputOperand);
castOperands.add(castOperand);
} else {
castOperands.add(inputOperand);
}
}
return Optional.of(rexBuilder.makeCall(op, castOperands));
}

// Hive allows passing in a byte array or String to substr/substring, so we can make an effort to emulate the
// behavior by casting non-String input to String
// https://cwiki.apache.org/confluence/display/hive/languagemanual+udf
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,18 @@
*/
package com.linkedin.coral.trino.rel2trino;

import org.apache.calcite.sql.SqlBasicCall;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlValidator;

import com.linkedin.coral.common.HiveMetastoreClient;
import com.linkedin.coral.common.functions.Function;
import com.linkedin.coral.common.transformers.SqlCallTransformers;
import com.linkedin.coral.common.utils.TypeDerivationUtil;
import com.linkedin.coral.hive.hive2rel.HiveToRelConverter;
import com.linkedin.coral.hive.hive2rel.functions.VersionedSqlUserDefinedFunction;
import com.linkedin.coral.trino.rel2trino.transformers.ConcatOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;
Expand All @@ -29,16 +32,34 @@
*/
public class DataTypeDerivedSqlCallConverter extends SqlShuttle {
private final SqlCallTransformers operatorTransformerList;
private final HiveToRelConverter toRelConverter;

public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode topSqlNode) {
SqlValidator sqlValidator = new HiveToRelConverter(mscClient).getSqlValidator();
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(sqlValidator, topSqlNode);
toRelConverter = new HiveToRelConverter(mscClient);
topSqlNode.accept(new RegisterDynamicFunctionsForTypeDerivation());

TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil));
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil),
new ConcatOperatorTransformer(typeDerivationUtil));
}

@Override
public SqlNode visit(final SqlCall call) {
return operatorTransformerList.apply((SqlCall) super.visit(call));
}

private class RegisterDynamicFunctionsForTypeDerivation extends SqlShuttle {
@Override
public SqlNode visit(SqlCall sqlCall) {
if (sqlCall instanceof SqlBasicCall && sqlCall.getOperator() instanceof VersionedSqlUserDefinedFunction
&& sqlCall.getOperator().getName().contains(".")) {
// Register versioned SqlUserDefinedFunctions in RelConverter's dynamicFunctionRegistry.
// This enables the SqlValidator to derive RelDataType for SqlCalls that involve these operators.
Function function = new Function(sqlCall.getOperator().getName(), sqlCall.getOperator());
toRelConverter.getFunctionResolver().addDynamicFunctionToTheRegistry(sqlCall.getOperator().getName(), function);
}
return super.visit(sqlCall);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,11 @@ public Result visit(Project e) {
final List<SqlNode> selectList = new ArrayList<>();
for (RexNode ref : e.getChildExps()) {
SqlNode sqlExpr = builder.context.toSql(null, ref);

// Append the CAST operator when the derived data type is NON-NULL.
RelDataTypeField targetField = e.getRowType().getFieldList().get(selectList.size());
if (SqlUtil.isNullLiteral(sqlExpr, false) && !targetField.getValue().getSqlTypeName().equals(SqlTypeName.NULL)) {
sqlExpr = SqlStdOperatorTable.CAST.createCall(POS, sqlExpr, dialect.getCastSpec(targetField.getType()));
}

addSelect(selectList, sqlExpr, e.getRowType());
}
builder.setSelect(new SqlNodeList(selectList, POS));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeName;

import com.linkedin.coral.common.HiveTypeSystem;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;

import static org.apache.calcite.rel.rel2sql.SqlImplementor.*;
import static org.apache.calcite.sql.parser.SqlParserPos.*;


/**
* This transformer is designed for SqlCalls that use the CONCAT operator.
* Its purpose is to convert the data types of the operands to be compatible with Trino.
* Trino only allows VARCHAR type operands for the CONCAT operator. Therefore, if there are any other data type operands present,
* an extra CAST operator is added around the operand to cast it to VARCHAR.
*/
public class ConcatOperatorTransformer extends SqlCallTransformer {
  private static final String OPERATOR_NAME = "concat";
  // Operand types Trino's CONCAT accepts as-is; anything else gets wrapped in a CAST.
  private static final Set<SqlTypeName> TRINO_COMPATIBLE_OPERAND_TYPES =
      new HashSet<>(Arrays.asList(SqlTypeName.VARCHAR, SqlTypeName.CHAR));
  private static final int VARCHAR_PRECISION = new HiveTypeSystem().getDefaultPrecision(SqlTypeName.VARCHAR);
  private static final SqlDataTypeSpec VARCHAR_SPEC =
      new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, VARCHAR_PRECISION, ZERO), ZERO);

  /**
   * @param typeDerivationUtil utility used to derive operand RelDataTypes during transformation
   */
  public ConcatOperatorTransformer(TypeDerivationUtil typeDerivationUtil) {
    super(typeDerivationUtil);
  }

  /** Matches any call whose operator is CONCAT (case-insensitive). */
  @Override
  protected boolean condition(SqlCall sqlCall) {
    return sqlCall.getOperator().getName().equalsIgnoreCase(OPERATOR_NAME);
  }

  /**
   * Rebuilds the CONCAT call, replacing each operand whose derived type is neither
   * VARCHAR nor CHAR with CAST(operand AS VARCHAR), since Trino's CONCAT only
   * accepts VARCHAR operands (Hive also allows e.g. CONCAT(date, varchar)).
   */
  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    List<SqlNode> rebuiltOperands = new ArrayList<>(sqlCall.getOperandList().size());
    for (SqlNode operand : sqlCall.getOperandList()) {
      rebuiltOperands.add(castToVarcharIfNeeded(operand));
    }
    return sqlCall.getOperator().createCall(POS, rebuiltOperands);
  }

  /** Returns the operand unchanged when already VARCHAR/CHAR; otherwise wraps it in a CAST to VARCHAR. */
  private SqlNode castToVarcharIfNeeded(SqlNode operand) {
    RelDataType derivedType = deriveRelDatatype(operand);
    if (TRINO_COMPATIBLE_OPERAND_TYPES.contains(derivedType.getSqlTypeName())) {
      return operand;
    }
    return SqlStdOperatorTable.CAST.createCall(POS, operand, VARCHAR_SPEC);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,19 @@ public void testDateFormatFunction() {
assertEquals(expandedSql, targetSql);
}

@Test
public void testConcatWithUnionAndStar() {
  // CONCAT over a UNION ALL with SELECT *: non-varchar operands (date, int) must be cast to VARCHAR.
  String inputSql =
      "SELECT * from test.tableA union all SELECT * from test.tableB where concat(current_date(), '|', tableB.a) = 'invalid'";
  RelNode relNode = TestUtils.getHiveToRelConverter().convertSql(inputSql);
  String actualSql = TestUtils.getRelToTrinoConverter().convert(relNode);

  String expected =
      "SELECT *\nFROM \"test\".\"tablea\" AS \"tablea\"\nUNION ALL\nSELECT *\nFROM \"test\".\"tableb\" AS \"tableb\"\n"
          + "WHERE \"concat\"(CAST(CURRENT_DATE AS VARCHAR(65535)), '|', CAST(\"tableb\".\"a\" AS VARCHAR(65535))) = 'invalid'";
  assertEquals(actualSql, expected);
}

@Test
public void testConcatFunction() {
RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();
Expand Down Expand Up @@ -771,8 +784,9 @@ public void testRegexpTransformation() {
assertEquals(expandedSql, targetSql);
}

@Test
public void testSqlSelectAliasAppenderTransformer() {
//test.tableA(a int, b struct<b1:string>
// test.tableA(a int, b struct<b1:string>
RelNode relNode = TestUtils.getHiveToRelConverter().convertSql("SELECT tableA.b.b1 FROM test.tableA where a > 5");
RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();
String expandedSql = relToTrinoConverter.convert(relNode);
Expand Down

0 comments on commit 69b26c5

Please sign in to comment.