[Coral-Trino] Fix substring start index issue (#499)
* use greatest between 1 and substring start index

* update test

* spotless apply

* spotless apply

* SUBSTRING operator UT

* dedicated substring index transformer + delete old transformer + update tests

* match tests

* clean up

* use ImmutableSet for substring operator names
KevinGe00 authored Apr 11, 2024
1 parent e09015e commit 0d5dd3f
Showing 5 changed files with 126 additions and 41 deletions.
@@ -1,5 +1,5 @@
/**
- * Copyright 2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
@@ -1,5 +1,5 @@
/**
- * Copyright 2017-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2017-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
@@ -34,6 +34,7 @@
import com.linkedin.coral.trino.rel2trino.transformers.NullOrderingTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ReturnTypeAdjustmentTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.SqlSelectAliasAppenderTransformer;
+ import com.linkedin.coral.trino.rel2trino.transformers.SubstrIndexTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.ToDateOperatorTransformer;
import com.linkedin.coral.trino.rel2trino.transformers.UnnestOperatorTransformer;

@@ -72,17 +73,6 @@ protected SqlCall transform(SqlCall sqlCall) {
"[{\"op\":\"*\",\"operands\":[{\"input\":1},{\"op\":\"^\",\"operands\":[{\"value\":10},{\"input\":2}]}]}]",
"{\"op\":\"/\",\"operands\":[{\"input\":0},{\"op\":\"^\",\"operands\":[{\"value\":10},{\"input\":2}]}]}",
null),
- // string functions
- new JsonTransformSqlCallTransformer(SqlStdOperatorTable.SUBSTRING, 2, "substr",
- "[{\"input\": 1}, {\"op\": \"+\", \"operands\": [{\"input\": 2}, {\"value\": 1}]}]", null, null),
- new JsonTransformSqlCallTransformer(SqlStdOperatorTable.SUBSTRING, 3, "substr",
- "[{\"input\": 1}, {\"op\": \"+\", \"operands\": [{\"input\": 2}, {\"value\": 1}]}, {\"input\": 3}]", null,
- null),
- new JsonTransformSqlCallTransformer(hiveToCoralSqlOperator("substr"), 2, "substr",
- "[{\"input\": 1}, {\"op\": \"+\", \"operands\": [{\"input\": 2}, {\"value\": 1}]}]", null, null),
- new JsonTransformSqlCallTransformer(hiveToCoralSqlOperator("substr"), 3, "substr",
- "[{\"input\": 1}, {\"op\": \"+\", \"operands\": [{\"input\": 2}, {\"value\": 1}]}, {\"input\": 3}]", null,
- null),
// JSON functions
new CoralRegistryOperatorRenameSqlCallTransformer("get_json_object", 2, "json_extract"),
// map various hive functions
@@ -135,7 +125,7 @@ protected SqlCall transform(SqlCall sqlCall) {
new GenericCoralRegistryOperatorRenameSqlCallTransformer(),

new ReturnTypeAdjustmentTransformer(configs), new UnnestOperatorTransformer(), new AsOperatorTransformer(),
- new JoinSqlCallTransformer(), new NullOrderingTransformer());
+ new JoinSqlCallTransformer(), new NullOrderingTransformer(), new SubstrIndexTransformer());
}

private SqlOperator hiveToCoralSqlOperator(String functionName) {
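The transformers deleted above unconditionally rewrote the start index of 2- and 3-argument substr/substring calls as "input + 1" when generating Trino SQL, shifting start positions that were already valid 1-based indices; this is why the test expectations below change from "substr"(..., 1 + 1, 10) to "substr"(..., 1, 10). A minimal before/after sketch of the translation (the query is hypothetical; the literal comes from the tests below):

-- Hive input
SELECT substr('2021-08-20', 1, 10) FROM t
-- previous Trino translation (off by one: starts at the second character)
SELECT "substr"('2021-08-20', 1 + 1, 10) FROM t
-- Trino translation after this commit
SELECT "substr"('2021-08-20', 1, 10) FROM t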
@@ -0,0 +1,65 @@
/**
* Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.trino.rel2trino.transformers;

import java.util.Arrays;
import java.util.List;
import java.util.Set;

import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlIdentifier;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNumericLiteral;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;

import com.linkedin.coral.com.google.common.collect.ImmutableSet;
import com.linkedin.coral.common.calcite.CalciteUtil;
import com.linkedin.coral.common.transformers.SqlCallTransformer;

import static org.apache.calcite.rel.rel2sql.SqlImplementor.*;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.*;


/**
* This class transforms the substr indexing in the input SqlCall to be compatible with the Trino engine.
* Trino uses 1-based indexing for substr, so the lowest possible start index is 1, while other engines,
* such as Hive, also accept 0 as a valid start index.
*
* This transformer guarantees that the starting index is always 1 or greater.
*/
public class SubstrIndexTransformer extends SqlCallTransformer {
  private static final Set<String> SUBSTRING_OPERATORS = ImmutableSet.of("substr", "substring");

  @Override
  protected boolean condition(SqlCall sqlCall) {
    return SUBSTRING_OPERATORS.contains(sqlCall.getOperator().getName().toLowerCase());
  }

  @Override
  protected SqlCall transform(SqlCall sqlCall) {
    final List<SqlNode> operandList = sqlCall.getOperandList();
    SqlNode start = operandList.get(1);
    if (start instanceof SqlNumericLiteral) {
      int startInt = ((SqlNumericLiteral) start).getValueAs(Integer.class);

      if (startInt == 0) {
        SqlNumericLiteral newStart = SqlNumericLiteral.createExactNumeric(String.valueOf(1), POS);
        sqlCall.setOperand(1, newStart);
      }

    } else if (start instanceof SqlIdentifier) {
      // If we don't have a literal start index value, we need to use a CASE statement with the column
      // identifier to ensure the value is always 1 or greater. So instead of just "col_name" as the
      // start index, we emit "CASE WHEN col_name = 0 THEN 1 ELSE col_name END".
      List<SqlNode> whenClauses = Arrays
          .asList(SqlStdOperatorTable.EQUALS.createCall(POS, start, SqlNumericLiteral.createExactNumeric("0", POS)));
      List<SqlNode> thenClauses = Arrays.asList(SqlNumericLiteral.createExactNumeric("1", POS));

      sqlCall.setOperand(1, CASE.createCall(null, POS, null, CalciteUtil.createSqlNodeList(whenClauses, POS),
          CalciteUtil.createSqlNodeList(thenClauses, POS), start));
    }

    return sqlCall;
  }
}
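In effect, the new transformer applies two rewrites, sketched below against a hypothetical string column str_col and integer column start_idx; literal start indices of 1 or greater, and start operands that are neither numeric literals nor plain column references, pass through unchanged:

-- a literal start index of 0 is bumped to 1
substr(str_col, 0, 10)      =>  substr(str_col, 1, 10)
-- a column used as the start index is guarded with a CASE expression
substr(str_col, start_idx)  =>  substr(str_col, CASE WHEN start_idx = 0 THEN 1 ELSE start_idx END)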
@@ -162,7 +162,7 @@ public Object[][] viewTestCasesProvider() {
{ "test", "view_from_utc_timestamp", "SELECT CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_tinyint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_smallint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_integer\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_bigint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp\".\"a_float\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp\".\"a_double\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp\".\"a_decimal_three\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp\".\"a_decimal_zero\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(\"table_from_utc_timestamp\".\"a_timestamp\", 'UTC'))), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3)), CAST(\"at_timezone\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(CAST(\"table_from_utc_timestamp\".\"a_date\" AS TIMESTAMP), 'UTC'))), \"$canonicalize_hive_timezone_id\"('America/Los_Angeles')) AS TIMESTAMP(3))\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp\"" },

{ "test", "date_calculation_view", "SELECT \"date\"(CAST(\"substr\"('2021-08-20', 1 + 1, 10) AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP)), \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP)), \"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))), \"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))), \"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))), \"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-21' AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS INTEGER), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-19' AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS INTEGER), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-19 23:59:59' AS TIMESTAMP)), \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS INTEGER)\n"
{ "test", "date_calculation_view", "SELECT \"date\"(CAST(\"substr\"('2021-08-20', 1, 10) AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP)), \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP)), \"date_add\"('day', 1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))), \"date_add\"('day', 1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))), \"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20' AS TIMESTAMP))), \"date_add\"('day', 1 * -1, \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-21' AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS INTEGER), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-19' AS TIMESTAMP)), \"date\"(CAST('2021-08-20' AS TIMESTAMP))) AS INTEGER), CAST(\"date_diff\"('day', \"date\"(CAST('2021-08-19 23:59:59' AS TIMESTAMP)), \"date\"(CAST('2021-08-20 00:00:00' AS TIMESTAMP))) AS INTEGER)\n"
+ "FROM \"test\".\"tablea\" AS \"tablea\"" },

{ "test", "pmod_view", "SELECT MOD(MOD(- 9, 4) + 4, 4)\n" + "FROM \"test\".\"tablea\" AS \"tablea\"" },
@@ -361,7 +361,7 @@ public void testLateralViewOuterPosExplodeWithAlias() {
public void testAvoidTransformToDate() {
RelNode relNode = TestUtils.getHiveToRelConverter()
.convertSql("SELECT to_date(substr('2021-08-20', 1, 10)), to_date('2021-08-20')" + "FROM test.tableA");
String targetSql = "SELECT \"to_date\"(\"substr\"('2021-08-20', 1 + 1, 10)), \"to_date\"('2021-08-20')\n"
String targetSql = "SELECT \"to_date\"(\"substr\"('2021-08-20', 1, 10)), \"to_date\"('2021-08-20')\n"
+ "FROM \"test\".\"tablea\" AS \"tablea\"";

RelToTrinoConverter relToTrinoConverter =
@@ -598,23 +598,23 @@ public void testSubstrWithTimestampOperator() {
RelNode relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_bigint,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
String targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_bigint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1 + 1, 10) AS \"d\"\n"
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime_nanos\"(CAST(\"table_from_utc_timestamp\".\"a_bigint\" AS BIGINT) * 1000000), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp\"";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);

relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_decimal_three,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp0\".\"a_decimal_three\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1 + 1, 10) AS \"d\"\n"
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime\"(CAST(\"table_from_utc_timestamp0\".\"a_decimal_three\" AS DOUBLE)), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp0\"";
expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);

relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT substring(from_utc_timestamp(a_timestamp,'PST'),1,10) AS d\nFROM test.table_from_utc_timestamp");
targetSql =
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(\"table_from_utc_timestamp1\".\"a_timestamp\", 'UTC'))), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1 + 1, 10) AS \"d\"\n"
"SELECT \"substr\"(CAST(CAST(\"at_timezone\"(\"from_unixtime\"(\"to_unixtime\"(\"with_timezone\"(\"table_from_utc_timestamp1\".\"a_timestamp\", 'UTC'))), \"$canonicalize_hive_timezone_id\"('PST')) AS TIMESTAMP(3)) AS VARCHAR(65535)), 1, 10) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp1\"";
expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
@@ -750,15 +750,15 @@ public void testSubstrWithTimestamp() {
RelNode relNode = TestUtils.getHiveToRelConverter()
.convertSql("SELECT SUBSTR(a_timestamp, 12, 8) AS d\nFROM test.table_from_utc_timestamp");
String targetSql =
"SELECT \"substr\"(CAST(\"table_from_utc_timestamp\".\"a_timestamp\" AS VARCHAR(65535)), 12 + 1, 8) AS \"d\"\n"
"SELECT \"substr\"(CAST(\"table_from_utc_timestamp\".\"a_timestamp\" AS VARCHAR(65535)), 12, 8) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp\"";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);

relNode = TestUtils.getHiveToRelConverter()
.convertSql("SELECT SUBSTRING(a_timestamp, 12, 8) AS d\nFROM test.table_from_utc_timestamp");
targetSql =
"SELECT \"substr\"(CAST(\"table_from_utc_timestamp0\".\"a_timestamp\" AS VARCHAR(65535)), 12 + 1, 8) AS \"d\"\n"
"SELECT \"substr\"(CAST(\"table_from_utc_timestamp0\".\"a_timestamp\" AS VARCHAR(65535)), 12, 8) AS \"d\"\n"
+ "FROM \"test\".\"table_from_utc_timestamp\" AS \"table_from_utc_timestamp0\"";
expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
@@ -771,8 +771,8 @@ public void testAliasOrderByDESC() {
RelNode relNode = TestUtils.getHiveToRelConverter()
.convertSql("SELECT a, SUBSTR(b, 1, 1) AS aliased_column, c FROM test.tabler ORDER BY aliased_column DESC");
String targetSql =
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1 + 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"tabler\" AS \"tabler\"\n" + "ORDER BY \"substr\"(\"tabler\".\"b\", 1 + 1, 1) DESC";
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"tabler\" AS \"tabler\"\n" + "ORDER BY \"substr\"(\"tabler\".\"b\", 1, 1) DESC";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}
@@ -784,9 +784,9 @@ public void testAliasOrderByDESCMultipleOrderings() {
RelNode relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT a, SUBSTR(b, 1, 1) AS aliased_column, c FROM test.tabler ORDER BY aliased_column DESC, a DESC, c DESC");
String targetSql =
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1 + 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"tabler\" AS \"tabler\"\n"
+ "ORDER BY \"substr\"(\"tabler\".\"b\", 1 + 1, 1) DESC, \"tabler\".\"a\" DESC, \"tabler\".\"c\" DESC";
+ "ORDER BY \"substr\"(\"tabler\".\"b\", 1, 1) DESC, \"tabler\".\"a\" DESC, \"tabler\".\"c\" DESC";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}
@@ -812,9 +812,8 @@ public void testAliasOrderByASC() {
.convertSql("SELECT a, SUBSTR(b, 1, 1) AS aliased_column, c FROM test.tabler ORDER BY aliased_column ASC");
// We want NULLS FIRST since we're translating from Hive and that is the default null ordering for ASC in Hive
String targetSql =
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1 + 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"tabler\" AS \"tabler\"\n"
+ "ORDER BY \"substr\"(\"tabler\".\"b\", 1 + 1, 1) NULLS FIRST";
"SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1, 1) AS \"aliased_column\", \"tabler\".\"c\" AS \"c\"\n"
+ "FROM \"test\".\"tabler\" AS \"tabler\"\n" + "ORDER BY \"substr\"(\"tabler\".\"b\", 1, 1) NULLS FIRST";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}
@@ -825,10 +824,9 @@ public void testAliasHaving() {

RelNode relNode = TestUtils.getHiveToRelConverter().convertSql(
"SELECT a, SUBSTR(b, 1, 1) AS aliased_column FROM test.tabler GROUP BY a, b HAVING aliased_column in ('dummy_value')");
- String targetSql =
- "SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1 + 1, 1) AS \"aliased_column\"\n"
- + "FROM \"test\".\"tabler\" AS \"tabler\"\n" + "GROUP BY \"tabler\".\"a\", \"tabler\".\"b\"\n"
- + "HAVING \"substr\"(\"tabler\".\"b\", 1 + 1, 1)\n" + "IN ('dummy_value')";
+ String targetSql = "SELECT \"tabler\".\"a\" AS \"a\", \"substr\"(\"tabler\".\"b\", 1, 1) AS \"aliased_column\"\n"
+ + "FROM \"test\".\"tabler\" AS \"tabler\"\n" + "GROUP BY \"tabler\".\"a\", \"tabler\".\"b\"\n"
+ + "HAVING \"substr\"(\"tabler\".\"b\", 1, 1)\n" + "IN ('dummy_value')";
String expandedSql = relToTrinoConverter.convert(relNode);
assertEquals(expandedSql, targetSql);
}