Skip to content

Commit

Permalink
Introduce OrdinalReturnTypeInferenceV2 to infer RexCall's return type (
Browse files Browse the repository at this point in the history
…#481)

* Introduce OrdinalReturnTypeInferenceV2 to infer RexCall return type

* add cast_nullability as an inbuilt function

* spotless fix

* coral-hive spotlesscheck

* modify UDF name

* coral-schema spotlessApply
  • Loading branch information
aastha25 authored Dec 15, 2023
1 parent c03b1e3 commit 08e43df
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* Copyright 2023 LinkedIn Corporation. All rights reserved.
* Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information.
*/
package com.linkedin.coral.hive.hive2rel.functions;

import org.apache.calcite.sql.type.OrdinalReturnTypeInference;


/**
 * An {@link OrdinalReturnTypeInference} that remembers the operand ordinal it was built with.
 *
 * <p>The ordinal is forwarded unchanged to the parent constructor and is also stored locally
 * so that callers can read it back through {@link #getOrdinal()}.
 */
public class OrdinalReturnTypeInferenceV2 extends OrdinalReturnTypeInference {
  /** Ordinal passed to the constructor; kept here so it can be read back through the getter. */
  private final int argumentOrdinal;

  /**
   * Creates the inference for the given operand ordinal.
   *
   * @param ordinal ordinal of the operand; passed as-is to the superclass and retained
   */
  public OrdinalReturnTypeInferenceV2(int ordinal) {
    super(ordinal);
    argumentOrdinal = ordinal;
  }

  /**
   * Returns the ordinal this instance was constructed with.
   *
   * @return the operand ordinal supplied to the constructor
   */
  public int getOrdinal() {
    return argumentOrdinal;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,8 @@ public boolean isOptional(int i) {
family(SqlTypeFamily.STRING, SqlTypeFamily.ANY, SqlTypeFamily.TIMESTAMP));
createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactFieldIf", ARG1,
family(SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.STRING, SqlTypeFamily.ANY));
// Registers "li_groot_cast_nullability" as a two-operand function accepting any type family for both
// operands. Its return type inference is an OrdinalReturnTypeInferenceV2 built with ordinal 1, which
// also exposes that ordinal through getOrdinal().
createAddUserDefinedFunction("li_groot_cast_nullability", new OrdinalReturnTypeInferenceV2(1),
family(SqlTypeFamily.ANY, SqlTypeFamily.ANY));

createAddUserDefinedFunction("com.linkedin.policy.decoration.udfs.RedactSecondarySchemaFieldIf", ARG1, family(
SqlTypeFamily.BOOLEAN, SqlTypeFamily.ANY, SqlTypeFamily.ARRAY, SqlTypeFamily.CHARACTER, SqlTypeFamily.ANY));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ static Schema relDataTypeToAvroTypeNonNullable(@Nonnull RelDataType relDataType,

/**
 * Builds the Avro schema for {@code relDataType} via {@code relDataTypeToAvroTypeNonNullable} and
 * returns the result of passing it through {@code SchemaUtilities.makeNullable} (second argument
 * {@code false}).
 *
 * @param relDataType type to convert
 * @param recordName record name forwarded to the non-nullable conversion
 * @return the schema returned by {@code SchemaUtilities.makeNullable}
 */
private static Schema relDataTypeToAvroType(RelDataType relDataType, String recordName) {
  // TODO: Current logic ALWAYS sets the inner fields of RelDataType record nullable.
  // Modify this to be applied only when RelDataType record was generated from a HIVE_UDF RexCall
  return SchemaUtilities.makeNullable(relDataTypeToAvroTypeNonNullable(relDataType, recordName), false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import com.linkedin.coral.com.google.common.base.Preconditions;
import com.linkedin.coral.common.HiveMetastoreClient;
import com.linkedin.coral.common.HiveUncollect;
import com.linkedin.coral.hive.hive2rel.functions.OrdinalReturnTypeInferenceV2;


/**
Expand Down Expand Up @@ -407,14 +408,7 @@ public SchemaRexShuttle(Schema inputSchema, RelNode inputNode, Queue<String> sug
/**
 * Visits an input reference and appends the input-schema field it points at to the schema
 * being assembled.
 *
 * <p>The lookup, name resolution and append are done by {@code appendRexInputRefField}. They
 * must happen exactly once per reference: repeating them inline would add the same field twice
 * and consume two entries from the suggested-field-name queue.
 *
 * @param rexInputRef the input reference being visited
 * @return the node returned by the superclass visit
 */
@Override
public RexNode visitInputRef(RexInputRef rexInputRef) {
  RexNode rexNode = super.visitInputRef(rexInputRef);
  // Single append through the shared helper; no inline copy of the same logic.
  appendRexInputRefField(rexInputRef);
  return rexNode;
}

Expand Down Expand Up @@ -442,6 +436,22 @@ public RexNode visitCall(RexCall rexCall) {
* For SqlUserDefinedFunction and SqlOperator RexCall, no need to handle it recursively
* and only return type of udf or sql operator is relevant
*/

/**
* If the return type of RexCall is based on the ordinal of its input argument
* and the corresponding input argument refers to a field from the input schema,
* use the field's schema as is.
*/
if (rexCall.getOperator().getReturnTypeInference() instanceof OrdinalReturnTypeInferenceV2) {
int index = ((OrdinalReturnTypeInferenceV2) rexCall.getOperator().getReturnTypeInference()).getOrdinal();
RexNode operand = rexCall.operands.get(index);

if (operand instanceof RexInputRef) {
appendRexInputRefField((RexInputRef) operand);
return rexCall;
}
}

RelDataType fieldType = rexCall.getType();
boolean isNullable = SchemaUtilities.isFieldNullable(rexCall, inputSchema);

Expand Down Expand Up @@ -545,6 +555,15 @@ public RexNode visitPatternFieldRef(RexPatternFieldRef rexPatternFieldRef) {
return super.visitPatternFieldRef(rexPatternFieldRef);
}

/**
 * Appends the input-schema field referenced by {@code rexInputRef} to {@code fieldAssembler}.
 *
 * <p>The field is looked up in {@code inputSchema} by the reference's index. Its output name is
 * whatever {@code SchemaUtilities.getFieldName} returns for the field's current name and the next
 * entry polled from {@code suggestedFieldNames}.
 *
 * @param rexInputRef reference whose index selects the field in {@code inputSchema}
 */
private void appendRexInputRefField(RexInputRef rexInputRef) {
  final Schema.Field referencedField = inputSchema.getFields().get(rexInputRef.getIndex());
  // The field lookup happens before the queue is polled, so a bad index does not consume a name.
  final String resolvedName = SchemaUtilities.getFieldName(referencedField.name(), suggestedFieldNames.poll());
  SchemaUtilities.appendField(resolvedName, referencedField, fieldAssembler);
}

private void appendField(RelDataType fieldType, boolean isNullable, String doc) {
String fieldName = SchemaUtilities.getFieldName("", suggestedFieldNames.poll());
SchemaUtilities.appendField(fieldName, fieldType, doc, fieldAssembler, isNullable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ private static void initializeTables() {
String baseComplexUnionTypeSchema = loadSchema("base-complex-union-type.avsc");
String baseNestedUnionSchema = loadSchema("base-nested-union.avsc");
String baseComplexLowercase = loadSchema("base-complex-lowercase.avsc");
String baseComplexNonNullable = loadSchema("base-complex-non-nullable.avsc");
String baseComplexNullableWithDefaults = loadSchema("base-complex-nullable-with-defaults.avsc");
String basePrimitive = loadSchema("base-primitive.avsc");
String baseComplexNestedStructSameName = loadSchema("base-complex-nested-struct-same-name.avsc");
Expand All @@ -121,6 +122,7 @@ private static void initializeTables() {
executeCreateTableWithPartitionFieldSchemaQuery("default", "basecomplexfieldschema", baseComplexFieldSchema);
executeCreateTableWithPartitionQuery("default", "basenestedcomplex", baseNestedComplexSchema);
executeCreateTableWithPartitionQuery("default", "basecomplexnullablewithdefaults", baseComplexNullableWithDefaults);
executeCreateTableWithPartitionQuery("default", "basecomplexnonnullable", baseComplexNonNullable);

String baseComplexSchemaWithDoc = loadSchema("docTestResources/base-complex-with-doc.avsc");
String baseEnumSchemaWithDoc = loadSchema("docTestResources/base-enum-with-doc.avsc");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1102,5 +1102,19 @@ public void testDivideReturnType() {
Assert.assertEquals(actualSchema.toString(true), TestUtils.loadSchema("testDivideReturnType-expected.avsc"));
}

/**
 * Checks that selecting a column through {@code li_groot_cast_nullability} and selecting the same
 * column directly both produce the schema stored in the expected resource file.
 */
@Test
public void testLiGrootCastNullability() {
  final String udfQuery =
      "SELECT li_groot_cast_nullability(Struct_Col, Struct_Col) AS modCol FROM basecomplexnonnullable";
  final String plainQuery = "SELECT Struct_Col AS modCol FROM basecomplexnonnullable";
  final String expectedResource = "testLiGrootCastNullability-expected.avsc";

  ViewToAvroSchemaConverter converter = ViewToAvroSchemaConverter.create(hiveMetastoreClient);
  Schema udfSchema = converter.toAvroSchema(udfQuery);
  Schema plainSchema = converter.toAvroSchema(plainQuery);

  // Both views are compared against the same expected schema file.
  Assert.assertEquals(udfSchema.toString(true), TestUtils.loadSchema(expectedResource));
  Assert.assertEquals(plainSchema.toString(true), TestUtils.loadSchema(expectedResource));
}

// TODO: add more unit tests
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
{
"type" : "record",
"name" : "basecomplexnonnullable",
"namespace" : "coral.schema.avro.base.complex.nonnullable",
"fields" : [ {
"name" : "modCol",
"type" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "Bool_Field",
"type" : "boolean"
}, {
"name" : "Int_Field",
"type" : "int"
}, {
"name" : "Bigint_Field",
"type" : "long"
}, {
"name" : "Float_Field",
"type" : "float"
}, {
"name" : "Double_Field",
"type" : "double"
}, {
"name" : "Date_String_Field",
"type" : "string"
}, {
"name" : "String_Field",
"type" : "string"
}, {
"name" : "Array_Col",
"type" : {
"type" : "array",
"items" : {
"type" : "record",
"name" : "Struct_col",
"namespace" : "coral.schema.avro.base.complex.nonnullable.basecomplexnonnullable.basecomplexnonnullable",
"fields" : [ {
"name" : "key",
"type" : "string"
}, {
"name" : "value",
"type" : "string"
} ]
}
}
}, {
"name" : "Map_Col",
"type" : {
"type" : "map",
"values" : "string"
}
} ]
}
} ]
}

0 comments on commit 08e43df

Please sign in to comment.