diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java index a898a5df6..62b905ca9 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java @@ -298,6 +298,13 @@ public boolean isOptional(int i) { or(family(SqlTypeFamily.STRING), family(SqlTypeFamily.BINARY))); createAddUserDefinedFunction("crc32", BIGINT, or(family(SqlTypeFamily.STRING), family(SqlTypeFamily.BINARY))); createAddUserDefinedFunction("from_utf8", explicit(SqlTypeName.VARCHAR), or(CHARACTER, BINARY)); + createAddUserDefinedFunction("at_timezone", explicit(SqlTypeName.TIMESTAMP), + family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.STRING)); + createAddUserDefinedFunction("with_timezone", explicit(SqlTypeName.TIMESTAMP), + family(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.STRING)); + createAddUserDefinedFunction("to_unixtime", explicit(SqlTypeName.DOUBLE), family(SqlTypeFamily.TIMESTAMP)); + createAddUserDefinedFunction("from_unixtime_nanos", explicit(SqlTypeName.TIMESTAMP), NUMERIC); + createAddUserDefinedFunction("$canonicalize_hive_timezone_id", explicit(SqlTypeName.VARCHAR), STRING); // xpath functions createAddUserDefinedFunction("xpath", FunctionReturnTypes.arrayOfType(SqlTypeName.VARCHAR), STRING_STRING); diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java index 6cb29c496..b8fb269d5 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/Calcite2TrinoUDFConverter.java @@ -38,8 +38,6 @@ import org.apache.calcite.sql.type.SqlReturnTypeInference; import org.apache.calcite.sql.validate.SqlUserDefinedFunction; -import com.linkedin.coral.com.google.common.collect.ImmutableList; - import static com.linkedin.coral.trino.rel2trino.CoralTrinoConfigKeys.*; import static org.apache.calcite.sql.type.ReturnTypes.explicit; import static org.apache.calcite.sql.type.SqlTypeName.*; @@ -167,34 +165,9 @@ public RexNode visitCall(RexCall call) { } } - if (operatorName.equalsIgnoreCase("substr")) { - Optional modifiedCall = visitSubstring(call); - if (modifiedCall.isPresent()) { - return modifiedCall.get(); - } - } - return super.visitCall(call); } - // Hive allows passing in a byte array or String to substr/substring, so we can make an effort to emulate the - // behavior by casting non-String input to String - // https://cwiki.apache.org/confluence/display/hive/languagemanual+udf - private Optional visitSubstring(RexCall call) { - final SqlOperator op = call.getOperator(); - List convertedOperands = visitList(call.getOperands(), (boolean[]) null); - RexNode inputOperand = convertedOperands.get(0); - - if (inputOperand.getType().getSqlTypeName() != VARCHAR && inputOperand.getType().getSqlTypeName() != CHAR) { - List operands = new ImmutableList.Builder() - .add(rexBuilder.makeCast(typeFactory.createSqlType(VARCHAR), inputOperand)) - .addAll(convertedOperands.subList(1, convertedOperands.size())).build(); - return Optional.of(rexBuilder.makeCall(op, operands)); - } - - return Optional.empty(); - } - private Optional visitCast(RexCall call) { final SqlOperator op = call.getOperator(); if (op.getKind() != SqlKind.CAST) { diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java index c6334c953..546edec86 100644 --- a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/DataTypeDerivedSqlCallConverter.java @@ -20,6 +20,7 @@ import com.linkedin.coral.trino.rel2trino.transformers.FromUtcTimestampOperatorTransformer; import com.linkedin.coral.trino.rel2trino.transformers.GenericProjectTransformer; import com.linkedin.coral.trino.rel2trino.transformers.NamedStructToCastTransformer; +import com.linkedin.coral.trino.rel2trino.transformers.SubstrOperatorTransformer; /** @@ -41,7 +42,7 @@ public DataTypeDerivedSqlCallConverter(HiveMetastoreClient mscClient, SqlNode to TypeDerivationUtil typeDerivationUtil = new TypeDerivationUtil(toRelConverter.getSqlValidator(), topSqlNode); operatorTransformerList = SqlCallTransformers.of(new FromUtcTimestampOperatorTransformer(typeDerivationUtil), new GenericProjectTransformer(typeDerivationUtil), new NamedStructToCastTransformer(typeDerivationUtil), - new ConcatOperatorTransformer(typeDerivationUtil)); + new ConcatOperatorTransformer(typeDerivationUtil), new SubstrOperatorTransformer(typeDerivationUtil)); } @Override diff --git a/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/SubstrOperatorTransformer.java b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/SubstrOperatorTransformer.java new file mode 100644 index 000000000..6e4c3d834 --- /dev/null +++ b/coral-trino/src/main/java/com/linkedin/coral/trino/rel2trino/transformers/SubstrOperatorTransformer.java @@ -0,0 +1,80 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.trino.rel2trino.transformers; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlBasicTypeNameSpec; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeName; + +import com.linkedin.coral.common.HiveTypeSystem; +import com.linkedin.coral.common.transformers.SqlCallTransformer; +import com.linkedin.coral.common.utils.TypeDerivationUtil; + +import static org.apache.calcite.sql.parser.SqlParserPos.*; +import static org.apache.calcite.sql.type.SqlTypeName.*; + + +/** + * This class implements the transformation of SqlCalls with Coral IR function `SUBSTR` + * to their corresponding Trino-compatible versions. + * + * For example: + * Given table: + * t1(int_col INTEGER, time_col timestamp) + * and a Coral IR SqlCall: + * `SUBSTR(time_col, 12, 8)` + * + * The transformed SqlCall would be: + * `SUBSTR(CAST(time_col AS VARCHAR(65535)), 12, 8)` + */ +public class SubstrOperatorTransformer extends SqlCallTransformer { + + private static final int DEFAULT_VARCHAR_PRECISION = new HiveTypeSystem().getDefaultPrecision(SqlTypeName.VARCHAR); + private static final String SUBSTR_OPERATOR_NAME = "substr"; + private static final Set OPERAND_SQL_TYPE_NAMES = + new HashSet<>(Arrays.asList(SqlTypeName.VARCHAR, SqlTypeName.CHAR)); + private static final SqlDataTypeSpec VARCHAR_SQL_DATA_TYPE_SPEC = + new SqlDataTypeSpec(new SqlBasicTypeNameSpec(SqlTypeName.VARCHAR, DEFAULT_VARCHAR_PRECISION, ZERO), ZERO); + + public SubstrOperatorTransformer(TypeDerivationUtil typeDerivationUtil) { + super(typeDerivationUtil); + } + + @Override + protected boolean condition(SqlCall sqlCall) { + return sqlCall.getOperator().getName().equalsIgnoreCase(SUBSTR_OPERATOR_NAME); + } + + @Override + protected SqlCall transform(SqlCall sqlCall) { + List operands = sqlCall.getOperandList(); + RelDataType relDataTypeOfOperand = deriveRelDatatype(operands.get(0)); + + // Coral IR accepts a byte array or String as an input for the `substr` operator. + // This behavior is emulated by casting non-String input to String in this transformer + // https://cwiki.apache.org/confluence/display/hive/languagemanual+udf + if (!OPERAND_SQL_TYPE_NAMES.contains(relDataTypeOfOperand.getSqlTypeName())) { + List modifiedOperands = new ArrayList<>(); + + modifiedOperands.add(SqlStdOperatorTable.CAST.createCall(ZERO, operands.get(0), VARCHAR_SQL_DATA_TYPE_SPEC)); + modifiedOperands.addAll(operands.subList(1, operands.size())); + + return sqlCall.getOperator().createCall(SqlParserPos.ZERO, modifiedOperands); + } + return sqlCall; + } +} diff --git a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java index 2bad16a47..754ccecfc 100644 --- a/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java +++ b/coral-trino/src/test/java/com/linkedin/coral/trino/rel2trino/HiveToTrinoConverterTest.java @@ -486,6 +486,35 @@ public void testCastNestedTimestampToDecimal() { assertEquals(expandedSql, targetSql); } + @Test + public void testSubstrWithTimestampOperator() { + RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter(); + + RelNode relNode = TestUtils.getHiveToRelConverter().convertSql( + "SELECT substring(from_utc_timestamp(a_bigint,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp"); + String targetSql = + "SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_bigint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n" + + "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp\""; + String expandedSql = relToTrinoConverter.convert(relNode); + assertEquals(expandedSql, targetSql); + + relNode = TestUtils.getHiveToRelConverter().convertSql( + "SELECT substring(from_utc_timestamp(a_decimal_three,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp"); + targetSql = + "SELECT \"substr\"(CAST(CAST(\"at_timezone\"(CAST(\"format_datetime\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp0\".\"a_decimal_three\" AS DOUBLE)), 'yyyy-MM-dd HH:mm:ss') AS TIMESTAMP), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n" + + "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp0\""; + expandedSql = relToTrinoConverter.convert(relNode); + assertEquals(expandedSql, targetSql); + + relNode = TestUtils.getHiveToRelConverter().convertSql( + "SELECT substring(from_utc_timestamp(a_timestamp,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp"); + targetSql = + "SELECT \"substr\"(CAST(CAST(\"at_timezone\"(CAST(\"format_datetime\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(\"table_from_utc_timestamp1\".\"a_timestamp\", 'UTC'))), 'yyyy-MM-dd HH:mm:ss') AS TIMESTAMP), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n" + + "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp1\""; + expandedSql = relToTrinoConverter.convert(relNode); + assertEquals(expandedSql, targetSql); + } + @Test public void testTranslateFunction() { RelToTrinoConverter relToTrinoConverter = TestUtils.getRelToTrinoConverter();