From 08e43df50bed6e89876a68fa3f9dfcba6efb7953 Mon Sep 17 00:00:00 2001 From: Aastha Agrrawal Date: Fri, 15 Dec 2023 10:59:31 -0800 Subject: [PATCH] Introduce OrdinalReturnTypeInferenceV2 to infer RexCall's return type (#481) * /OrdinalReturnTypeInferenceV2 to infer RexCall return type * add cast_nullability as an inbuilt function * spotless fix * coral-hive spotlesscheck * modify UDF name * coral-schema spotlessApply --- .../OrdinalReturnTypeInferenceV2.java | 26 +++++++++ .../functions/StaticHiveFunctionRegistry.java | 2 + .../schema/avro/RelDataTypeToAvroType.java | 2 + .../schema/avro/RelToAvroSchemaConverter.java | 35 ++++++++--- .../linkedin/coral/schema/avro/TestUtils.java | 2 + .../avro/ViewToAvroSchemaConverterTests.java | 14 +++++ .../testLiGrootCastNullability-expected.avsc | 58 +++++++++++++++++++ 7 files changed, 131 insertions(+), 8 deletions(-) create mode 100644 coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java create mode 100644 coral-schema/src/test/resources/testLiGrootCastNullability-expected.avsc diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java new file mode 100644 index 000000000..f41079ab1 --- /dev/null +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/OrdinalReturnTypeInferenceV2.java @@ -0,0 +1,26 @@ +/** + * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.hive.hive2rel.functions; + +import org.apache.calcite.sql.type.OrdinalReturnTypeInference; + + +/** + * Custom implementation of {@link OrdinalReturnTypeInference} which allows inferring the return type + * based on the ordinal of a given input argument and also exposes the ordinal. + */ +public class OrdinalReturnTypeInferenceV2 extends OrdinalReturnTypeInference { + private final int ordinal; + + public OrdinalReturnTypeInferenceV2(int ordinal) { + super(ordinal); + this.ordinal = ordinal; + } + + public int getOrdinal() { + return ordinal; + } +} diff --git a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java index c882e4fe0..c9e9fd67b 100644 --- a/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java +++ b/coral-hive/src/main/java/com/linkedin/coral/hive/hive2rel/functions/StaticHiveFunctionRegistry.java @@ -667,6 +667,8 @@ public boolean isOptional(int i) { family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP)); createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", ARG1, family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY)); + createAddUserDefinedFunction("li_groot_cast_nullability", new OrdinalReturnTypeInferenceV2(1), + family(SqlTypeFamily.ANY, SqlTypeFamily.ANY)); createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactSecondarySchemaFieldIf", ARG1, family( SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY)); diff --git a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java index 5b2f04406..53963c321 100644 --- a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java +++ b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelDataTypeToAvroType.java @@ -84,6 +84,8 @@ static Schema relDataTypeToAvroTypeNonNullable(@Nonnull RelDataType relDataType, private static Schema relDataTypeToAvroType(RelDataType relDataType, String recordName) { final Schema avroSchema = relDataTypeToAvroTypeNonNullable(relDataType, recordName); + // TODO: Current logic ALWAYS sets the inner fields of RelDataType record nullable. + // Modify this to be applied only when RelDataType record was generated from a HIVE_UDF RexCall return SchemaUtilities.makeNullable(avroSchema, false); } diff --git a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java index 6c94ac56c..01109a372 100644 --- a/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java +++ b/coral-schema/src/main/java/com/linkedin/coral/schema/avro/RelToAvroSchemaConverter.java @@ -59,6 +59,7 @@ import com.linkedin.coral.com.google.common.base.Preconditions; import com.linkedin.coral.common.HiveMetastoreClient; import com.linkedin.coral.common.HiveUncollect; +import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2; /** @@ -407,14 +408,7 @@ public SchemaRexShuttle(Schema inputSchema, RelNode inputNode, Queue sug @Override public RexNode visitInputRef(RexInputRef rexInputRef) { RexNode rexNode = super.visitInputRef(rexInputRef); - - Schema.Field field = inputSchema.getFields().get(rexInputRef.getIndex()); - String oldFieldName = field.name(); - String suggestNewFieldName = suggestedFieldNames.poll(); - String newFieldName = SchemaUtilities.getFieldName(oldFieldName, suggestNewFieldName); - - SchemaUtilities.appendField(newFieldName, field, fieldAssembler); - + appendRexInputRefField(rexInputRef); return rexNode; } @@ -442,6 +436,22 @@ public RexNode visitCall(RexCall rexCall) { * For SqlUserDefinedFunction and SqlOperator RexCall, no need to handle it recursively * and only return type of udf or sql operator is relevant */ + + /** + * If the return type of RexCall is based on the ordinal of its input argument + * and the corresponding input argument refers to a field from the input schema, + * use the field's schema as is. + */ + if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) { + int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference()).getOrdinal(); + RexNode operand = rexCall.operands.get(index); + + if (operand instanceof RexInputRef) { + appendRexInputRefField((RexInputRef) operand); + return rexCall; + } + } + RelDataType fieldType = rexCall.getType(); boolean isNullable = SchemaUtilities.isFieldNullable(rexCall, inputSchema); @@ -545,6 +555,15 @@ public RexNode visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) { return super.visitPatternFieldRef(rexPatternFieldRef); } + private void appendRexInputRefField(RexInputRef rexInputRef) { + Schema.Field field = inputSchema.getFields().get(rexInputRef.getIndex()); + String oldFieldName = field.name(); + String suggestNewFieldName = suggestedFieldNames.poll(); + String newFieldName = SchemaUtilities.getFieldName(oldFieldName, suggestNewFieldName); + + SchemaUtilities.appendField(newFieldName, field, fieldAssembler); + } + private void appendField(RelDataType fieldType, boolean isNullable, String doc) { String fieldName = SchemaUtilities.getFieldName("", suggestedFieldNames.poll()); SchemaUtilities.appendField(fieldName, fieldType, doc, fieldAssembler, isNullable); diff --git a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java index 73340ec98..cb79495f4 100644 --- a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java +++ b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/TestUtils.java @@ -99,6 +99,7 @@ private static void initializeTables() { String baseComplexUnionTypeSchema = loadSchema("base-complex-union-type.avsc"); String baseNestedUnionSchema = loadSchema("base-nested-union.avsc"); String baseComplexLowercase = loadSchema("base-complex-lowercase.avsc"); + String baseComplexNonNullable = loadSchema("base-complex-non-nullable.avsc"); String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc"); String basePrimitive = loadSchema("base-primitive.avsc"); String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc"); @@ -121,6 +122,7 @@ private static void initializeTables() { executeCreateTableWithPartitionFieldSchemaQuery("default", "basecomplexfieldschema", baseComplexFieldSchema); executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema); executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults); + executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable); String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc"); String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc"); diff --git a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java index 69d28eef3..4bf8130c5 100644 --- a/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java +++ b/coral-schema/src/test/java/com/linkedin/coral/schema/avro/ViewToAvroSchemaConverterTests.java @@ -1102,5 +1102,19 @@ public void testDivideReturnType() { Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testDivideReturnType-expected.avsc")); } + @Test + public void testLiGrootCastNullability() { + ViewToAvroSchemaConverter viewToAvroSchemaConverter = ViewToAvroSchemaConverter.create(hiveMetastoreClient); + + Schema schemaWithUDF = viewToAvroSchemaConverter + .toAvroSchema("SELECT li_groot_cast_nullability(Struct_Col, Struct_Col) AS modCol FROM basecomplexnonnullable"); + Schema schemaWithField = + viewToAvroSchemaConverter.toAvroSchema("SELECT Struct_Col AS modCol FROM basecomplexnonnullable"); + + Assert.assertEquals(schemaWithUDF.toString(true), TestUtils.loadSchema("testLiGrootCastNullability-expected.avsc")); + Assert.assertEquals(schemaWithField.toString(true), + TestUtils.loadSchema("testLiGrootCastNullability-expected.avsc")); + } + // TODO: add more unit tests } diff --git a/coral-schema/src/test/resources/testLiGrootCastNullability-expected.avsc b/coral-schema/src/test/resources/testLiGrootCastNullability-expected.avsc new file mode 100644 index 000000000..c5b5f18ca --- /dev/null +++ b/coral-schema/src/test/resources/testLiGrootCastNullability-expected.avsc @@ -0,0 +1,58 @@ +{ + "type" : "record", + "name" : "basecomplexnonnullable", + "namespace" : "coral.schema.avro.base.complex.nonnullable", + "fields" : [ { + "name" : "modCol", + "type" : { + "type" : "record", + "name" : "Struct_col", + "namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable", + "fields" : [ { + "name" : "Bool_Field", + "type" : "boolean" + }, { + "name" : "Int_Field", + "type" : "int" + }, { + "name" : "Bigint_Field", + "type" : "long" + }, { + "name" : "Float_Field", + "type" : "float" + }, { + "name" : "Double_Field", + "type" : "double" + }, { + "name" : "Date_String_Field", + "type" : "string" + }, { + "name" : "String_Field", + "type" : "string" + }, { + "name" : "Array_Col", + "type" : { + "type" : "array", + "items" : { + "type" : "record", + "name" : "Struct_col", + "namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable.basecomplexnonnullable", + "fields" : [ { + "name" : "key", + "type" : "string" + }, { + "name" : "value", + "type" : "string" + } ] + } + } + }, { + "name" : "Map_Col", + "type" : { + "type" : "map", + "values" : "string" + } + } ] + } + } ] +} \ No newline at end of file