Skip to content

Commit

Permalink
[Coral-Trino] Migrate SUBSTR() operator from RexShuttle to SqlShuttle (
Browse files Browse the repository at this point in the history
…#432)

* [Coral-Trino] Initial commit for migrating SUBSTR() operator

* rebase

* register Coral IR functions

* add UTs
  • Loading branch information
aastha25 authored Jul 17, 2023
1 parent 8c4b054 commit 482fe22
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ public boolean isOptional(int i) {
or(family(SqlTypeFamily.STRING), family(SqlTypeFamily.BINARY)));
createAddUserDefinedFunction("crc32", BIGINT, or(family(SqlTypeFamily.STRING), family(SqlTypeFamily.BINARY)));
createAddUserDefinedFunction("from_utf8", explicit(SqlTypeName.VARCHAR), or(CHARACTER, BINARY));
createAddUserDefinedFunction("at_timezone", explicit(SqlTypeName.TIMESTAMP),
family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.STRING));
createAddUserDefinedFunction("with_timezone", explicit(SqlTypeName.TIMESTAMP),
family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.STRING));
createAddUserDefinedFunction("to_unixtime", explicit(SqlTypeName.DOUBLE), family(SqlTypeFamily.TIMESTAMP));
createAddUserDefinedFunction("from_unixtime_nanos", explicit(SqlTypeName.TIMESTAMP), NUMERIC);
createAddUserDefinedFunction("$canonicalize_hive_timezone_id", explicit(SqlTypeName.VARCHAR), STRING);

// xpath functions
createAddUserDefinedFunction("xpath", FunctionReturnTypes.arrayOfType(SqlTypeName.VARCHAR), STRING_STRING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.validate.SqlUserDefinedFunction;

import com.linkedin.coral.com.google.common.collect.ImmutableList;

import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*;
import static org.apache.calcite.sql.type.ReturnTypes.explicit;
import static org.apache.calcite.sql.type.SqlTypeName.*;
Expand Down Expand Up @@ -167,34 +165,9 @@ public RexNode visitCall(RexCall call) {
}
}

if (operatorName.equalsIgnoreCase("substr")) {
Optional<RexNode> modifiedCall = visitSubstring(call);
if (modifiedCall.isPresent()) {
return modifiedCall.get();
}
}

return super.visitCall(call);
}

// Hive allows passing in a byte array or String to substr/substring, so we can make an effort to emulate the
// behavior by casting non-String input to String
// https://cwiki.apache.org/confluence/display/hive/languagemanual+udf
private Optional<RexNode> visitSubstring(RexCall call) {
final SqlOperator op = call.getOperator();
List<RexNode> convertedOperands = visitList(call.getOperands(), (boolean[]) null);
RexNode inputOperand = convertedOperands.get(0);

if (inputOperand.getType().getSqlTypeName() != VARCHAR && inputOperand.getType().getSqlTypeName() != CHAR) {
List<RexNode> operands = new ImmutableList.Builder<RexNode>()
.add(rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), inputOperand))
.addAll(convertedOperands.subList(1, convertedOperands.size())).build();
return Optional.of(rexBuilder.makeCall(op, operands));
}

return Optional.empty();
}

private Optional<RexNode> visitCast(RexCall call) {
final SqlOperator op = call.getOperator();
if (op.getKind() != SqlKind.CAST) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer;


/**
Expand All @@ -41,7 +42,7 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to
TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode);
operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil),
new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil),
new ConcatOperatorTransformer(typeDerivationUtil));
new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlBasicTypeNameSpec;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlDataTypeSpec;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;

import com.linkedin.coral.common.HiveTypeSystem;
import com.linkedin.coral.common.transformers.SqlCallTransformer;
import com.linkedin.coral.common.utils.TypeDerivationUtil;

import static org.apache.calcite.sql.parser.SqlParserPos.*;
import static org.apache.calcite.sql.type.SqlTypeName.*;


/**
 * Transformer that rewrites Coral IR {@code SUBSTR} SqlCalls into a form Trino can execute.
 *
 * Hive's substr/substring accepts byte arrays (and other non-character types) as the first
 * argument; Trino's substr does not. When the first operand is not already a character type,
 * this transformer wraps it in a CAST to VARCHAR so Trino sees a string input.
 *
 * For example, given a table:
 *   t1(int_col INTEGER, time_col timestamp)
 * the Coral IR call
 *   {@code SUBSTR(time_col, 12, 8)}
 * becomes
 *   {@code SUBSTR(CAST(time_col AS VARCHAR(65535)), 12, 8)}
 */
public class SubstrOperatorTransformer extends SqlCallTransformer {

  // Operator this transformer matches (case-insensitive).
  private static final String SUBSTR_OPERATOR_NAME = "substr";
  // Default VARCHAR precision from Hive's type system, used for the injected CAST.
  private static final int DEFAULT_VARCHAR_PRECISION = new HiveTypeSystem().getDefaultPrecision(SqlTypeName.VARCHAR);
  // First-operand types that Trino's substr already accepts; no CAST is needed for these.
  private static final Set<SqlTypeName> OPERAND_SQL_TYPE_NAMES =
      new HashSet<>(Arrays.asList(SqlTypeName.VARCHAR, SqlTypeName.CHAR));
  // Reusable VARCHAR(DEFAULT_VARCHAR_PRECISION) type spec for the injected CAST.
  private static final SqlDataTypeSpec VARCHAR_SQL_DATA_TYPE_SPEC =
      new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, DEFAULT_VARCHAR_PRECISION, ZERO), ZERO);

  public SubstrOperatorTransformer(TypeDerivationUtil typeDerivationUtil) {
    super(typeDerivationUtil);
  }

  /** Matches any call whose operator is named "substr", ignoring case. */
  @Override
  protected boolean condition(SqlCall sqlCall) {
    return SUBSTR_OPERATOR_NAME.equalsIgnoreCase(sqlCall.getOperator().getName());
  }

  /**
   * Returns the call unchanged when the first operand is already VARCHAR/CHAR;
   * otherwise rebuilds it with the first operand cast to VARCHAR.
   */
  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    final List<SqlNode> operands = sqlCall.getOperandList();
    final SqlNode firstOperand = operands.get(0);
    final RelDataType firstOperandType = deriveRelDatatype(firstOperand);

    // Already a character type — Trino's substr handles this directly.
    if (OPERAND_SQL_TYPE_NAMES.contains(firstOperandType.getSqlTypeName())) {
      return sqlCall;
    }

    // Coral IR accepts a byte array or String as input for `substr`; emulate Hive's
    // behavior by casting the non-String input to String before the call.
    // https://cwiki.apache.org/confluence/display/hive/languagemanual+udf
    final List<SqlNode> rewrittenOperands = new ArrayList<>();
    rewrittenOperands.add(SqlStdOperatorTable.CAST.createCall(ZERO, firstOperand, VARCHAR_SQL_DATA_TYPE_SPEC));
    rewrittenOperands.addAll(operands.subList(1, operands.size()));
    return sqlCall.getOperator().createCall(ZERO, rewrittenOperands);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,35 @@ public void testCastNestedTimestampToDecimal() {
assertEquals(expandedSql, targetSql);
}

@Test
public void testSubstrWithTimestampOperator() {
// Verifies that SubstrOperatorTransformer casts non-character substr inputs to VARCHAR
// when translating Hive SQL to Trino SQL.
RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();

// Case 1: substring over from_utc_timestamp(BIGINT) — the timestamp result is not a
// character type, so the translated SQL must wrap it in CAST(... AS VARCHAR(65535)).
RelNode relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_bigint,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
String targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_bigint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp\"";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);

// Case 2: same check with a DECIMAL(3) input column; the inner conversion differs
// (from_unixtime/format_datetime path) but the outer VARCHAR cast must still appear.
relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_decimal_three,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(CAST(\"format_datetime\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp0\".\"a_decimal_three\" AS DOUBLE)), 'yyyy-MM-dd HH:mm:ss') AS TIMESTAMP), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp0\"";
expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);

// Case 3: TIMESTAMP input column; translation goes through with_timezone/to_unixtime,
// and again the substr input is cast to VARCHAR(65535).
relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_timestamp,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(CAST(\"format_datetime\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(\"table_from_utc_timestamp1\".\"a_timestamp\", 'UTC'))), 'yyyy-MM-dd HH:mm:ss') AS TIMESTAMP), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp1\"";
expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}

@Test
public void testTranslateFunction() {
RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();
Expand Down

0 comments on commit 482fe22

Please sign in to comment.