Skip to content

Commit

Permalink
[CALCITE-6690] Refactor the Arrow adapter type system
Browse files Browse the repository at this point in the history
  • Loading branch information
caicancai committed Dec 14, 2024
1 parent 041619f commit 6991ef9
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,84 +17,69 @@
package org.apache.calcite.adapter.arrow;

import org.apache.calcite.adapter.java.JavaTypeFactory;
import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.type.SqlTypeName;

import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;

import java.math.BigDecimal;
import java.sql.Date;
import java.util.List;

import static java.util.Objects.requireNonNull;

/**
* Arrow field type.
*/
enum ArrowFieldType {
INT(Primitive.INT),
BOOLEAN(Primitive.BOOLEAN),
STRING(String.class),
FLOAT(Primitive.FLOAT),
DOUBLE(Primitive.DOUBLE),
DATE(Date.class),
LIST(List.class),
DECIMAL(BigDecimal.class),
LONG(Primitive.LONG),
BYTE(Primitive.BYTE),
SHORT(Primitive.SHORT);

private final Class<?> clazz;

ArrowFieldType(Primitive primitive) {
this(requireNonNull(primitive.boxClass, "boxClass"));
}
public class ArrowFieldType {

ArrowFieldType(Class<?> clazz) {
this.clazz = clazz;
private ArrowFieldType() {
throw new UnsupportedOperationException("Utility class");
}

public RelDataType toType(JavaTypeFactory typeFactory) {
RelDataType javaType = typeFactory.createJavaType(clazz);
RelDataType sqlType = typeFactory.createSqlType(javaType.getSqlTypeName());
public static RelDataType toType(ArrowType arrowType, JavaTypeFactory typeFactory) {
RelDataType sqlType = of(arrowType, typeFactory);
return typeFactory.createTypeWithNullability(sqlType, true);
}

public static ArrowFieldType of(ArrowType arrowType) {
/**
* Converts an Arrow type to a Calcite RelDataType.
*
* @param arrowType the Arrow type to convert
* @param typeFactory the factory to create the Calcite type
* @return the corresponding Calcite RelDataType
*/
private static RelDataType of(ArrowType arrowType, JavaTypeFactory typeFactory) {
switch (arrowType.getTypeID()) {
case Int:
int bitWidth = ((ArrowType.Int) arrowType).getBitWidth();
switch (bitWidth) {
case 64:
return LONG;
return typeFactory.createSqlType(SqlTypeName.BIGINT);
case 32:
return INT;
return typeFactory.createSqlType(SqlTypeName.INTEGER);
case 16:
return SHORT;
return typeFactory.createSqlType(SqlTypeName.SMALLINT);
case 8:
return BYTE;
return typeFactory.createSqlType(SqlTypeName.TINYINT);
default:
throw new IllegalArgumentException("Unsupported Int bit width: " + bitWidth);
}
case Bool:
return BOOLEAN;
return typeFactory.createSqlType(SqlTypeName.BOOLEAN);
case Utf8:
return STRING;
return typeFactory.createSqlType(SqlTypeName.VARCHAR);
case FloatingPoint:
FloatingPointPrecision precision = ((ArrowType.FloatingPoint) arrowType).getPrecision();
switch (precision) {
case SINGLE:
return FLOAT;
return typeFactory.createSqlType(SqlTypeName.REAL);
case DOUBLE:
return DOUBLE;
return typeFactory.createSqlType(SqlTypeName.DOUBLE);
default:
throw new IllegalArgumentException("Unsupported Floating point precision: " + precision);
}
case Date:
return DATE;
return typeFactory.createSqlType(SqlTypeName.DATE);
case Decimal:
return DECIMAL;
return typeFactory.createSqlType(SqlTypeName.DECIMAL,
((ArrowType.Decimal) arrowType).getPrecision(),
((ArrowType.Decimal) arrowType).getScale());
default:
throw new IllegalArgumentException("Unsupported type: " + arrowType);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ private static RelDataType deduceRowType(Schema schema,
final RelDataTypeFactory.Builder builder = typeFactory.builder();
for (Field field : schema.getFields()) {
builder.add(field.getName(),
ArrowFieldType.of(field.getType()).toType(typeFactory));
ArrowFieldType.toType(field.getType(), typeFactory));
}
return builder.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -732,8 +732,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
@Test void testFilteredAgg() {
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP";
String plan = "PLAN=EnumerableAggregate(group=[{}], SALESSUM=[SUM($0) FILTER $1])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], SAL=[$t5], $f1=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], SAL=[$t5], $f1=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=2500.00\n";
Expand All @@ -750,8 +750,8 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
String sql = "select SUM(SAL) FILTER (WHERE COMM > 400) as SALESSUM from EMP group by EMPNO";
String plan = "PLAN=EnumerableCalc(expr#0..1=[{inputs}], SALESSUM=[$t1])\n"
+ " EnumerableAggregate(group=[{0}], SALESSUM=[SUM($1) FILTER $2])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[400:DECIMAL(19, 0)], expr#9=[>($t6, $t8)], "
+ "expr#10=[IS TRUE($t9)], EMPNO=[$t0], SAL=[$t5], $f2=[$t10])\n"
+ " EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t6):DECIMAL(10, 2)], expr#9=[400.00:DECIMAL(10, 2)], "
+ "expr#10=[>($t8, $t9)], expr#11=[IS TRUE($t10)], EMPNO=[$t0], SAL=[$t5], $f2=[$t11])\n"
+ " ArrowToEnumerableConverter\n"
+ " ArrowTableScan(table=[[ARROW, EMP]], fields=[[0, 1, 2, 3, 4, 5, 6, 7]])\n\n";
String result = "SALESSUM=1250.00\nSALESSUM=null\n";
Expand Down Expand Up @@ -971,7 +971,7 @@ static void initializeArrowState(@TempDir Path sharedTempDir)
+ "where \"decimalField\" = 1.00";
String plan = "PLAN=ArrowToEnumerableConverter\n"
+ " ArrowProject(decimalField=[$8])\n"
+ " ArrowFilter(condition=[=($8, 1)])\n"
+ " ArrowFilter(condition=[=($8, 1.00)])\n"
+ " ArrowTableScan(table=[[ARROW, ARROWDATATYPE]], fields=[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])\n\n";
String result = "decimalField=1.00\n";

Expand Down

0 comments on commit 6991ef9

Please sign in to comment.