diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndex.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndex.scala index 990ba73b5..04d551dc1 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndex.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndex.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.{Column, DataFrame, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.functions.input_file_name import org.apache.spark.sql.hyperspace.utils.StructTypeUtils @@ -151,7 +152,7 @@ case class DataSkippingIndex( // by combining individual index predicates with And. // True is returned if there are no index predicates for the source predicate node. def toIndexPred(sourcePred: Expression): Expression = { - predMap.get(sourcePred).map(_.reduceLeft(And)).getOrElse(Literal.TrueLiteral) + predMap.get(sourcePred).map(_.reduceLeft(And)).getOrElse(TrueLiteral) } // Compose an index predicate visiting the source predicate tree recursively. @@ -168,15 +169,15 @@ case class DataSkippingIndex( // This is a trimmed down version of the BooleanSimplification rule. // It's just enough to determine whether the index is applicable or not. val optimizePredicate: PartialFunction[Expression, Expression] = { - case And(Literal.TrueLiteral, right) => right - case And(left, Literal.TrueLiteral) => left - case Or(Literal.TrueLiteral, _) => Literal.TrueLiteral - case Or(_, Literal.TrueLiteral) => Literal.TrueLiteral + case And(TrueLiteral, right) => right + case And(left, TrueLiteral) => left + case Or(TrueLiteral, _) => TrueLiteral + case Or(_, TrueLiteral) => TrueLiteral } val optimizedIndexPredicate = indexPredicate.transformUp(optimizePredicate) // Return None if the index predicate is True - meaning no conversion can be done. - if (optimizedIndexPredicate == Literal.TrueLiteral) { + if (optimizedIndexPredicate == TrueLiteral) { None } else { Some(optimizedIndexPredicate) diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAgg.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAgg.scala new file mode 100644 index 000000000..95afffd49 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAgg.scala @@ -0,0 +1,83 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
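As a quick illustration of the trimmed-down simplification rule in `DataSkippingIndex` above: once every convertible leaf is replaced and the `TrueLiteral` placeholders are folded away, an index predicate that degenerates to `True` means the index cannot help. A REPL-style sketch (the attribute name is made up):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.types.IntegerType

// One Or branch had no index predicate, so its leaf became TrueLiteral;
// folding the tree bottom-up shows the whole Or is unusable for skipping.
val a = AttributeReference("A", IntegerType)()
val indexPred = Or(And(TrueLiteral, GreaterThan(a, Literal(1))), TrueLiteral)
val simplified = indexPred.transformUp {
  case And(TrueLiteral, right) => right
  case And(left, TrueLiteral) => left
  case Or(TrueLiteral, _) => TrueLiteral
  case Or(_, TrueLiteral) => TrueLiteral
}
assert(simplified == TrueLiteral) // convert() would return None here
```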
+ */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.sketch.BloomFilter + +/** + * Aggregation function that collects elements in a bloom filter. + */ +private[dataskipping] case class BloomFilterAgg( + child: Expression, + expectedNumItems: Long, // expected number of distinct elements + fpp: Double, // target false positive probability + override val mutableAggBufferOffset: Int = 0, + override val inputAggBufferOffset: Int = 0) + extends TypedImperativeAggregate[BloomFilter] { + + override def prettyName: String = "bloom_filter" + + override def dataType: DataType = bloomFilterEncoder.dataType + + override def nullable: Boolean = false + + override def children: Seq[Expression] = Seq(child) + + override def createAggregationBuffer(): BloomFilter = { + BloomFilter.create(expectedNumItems, fpp) + } + + override def update(buffer: BloomFilter, input: InternalRow): BloomFilter = { + val value = child.eval(input) + if (value != null) { + BloomFilterUtils.put(buffer, value, child.dataType) + } + buffer + } + + override def merge(buffer: BloomFilter, input: BloomFilter): BloomFilter = { + buffer.mergeInPlace(input) + buffer + } + + override def eval(buffer: BloomFilter): Any = bloomFilterEncoder.encode(buffer) + + override def serialize(buffer: BloomFilter): Array[Byte] = { + val out = new ByteArrayOutputStream() + buffer.writeTo(out) + out.toByteArray + } + + override def deserialize(bytes: Array[Byte]): BloomFilter = { + val in = new ByteArrayInputStream(bytes) + BloomFilter.readFrom(in) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): BloomFilterAgg = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): BloomFilterAgg = + copy(inputAggBufferOffset = newOffset) + + private def bloomFilterEncoder = BloomFilterEncoderProvider.defaultEncoder +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoder.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoder.scala new file mode 100644 index 000000000..3c956a8d3 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoder.scala @@ -0,0 +1,42 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.sketch.BloomFilter + +/** + * Defines how [[BloomFilter]] should be represented in the Spark DataFrame. 
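`BloomFilterAgg` plugs into Spark as a regular `TypedImperativeAggregate`, so it can be wrapped in a `Column` like any other aggregate. A usage sketch mirroring what `BloomFilterSketch.aggregateFunctions` does at index-creation time (the path is hypothetical, and since the class is `private[dataskipping]` this only compiles from inside that package):

```scala
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.{col, input_file_name}

import com.microsoft.hyperspace.index.dataskipping.expressions.BloomFilterAgg

// Build one bloom filter per source file for column A.
val spark = SparkSession.builder.getOrCreate()
val df = spark.read.parquet("/path/to/source") // hypothetical path

val bfAgg = new Column(
  BloomFilterAgg(col("A").expr, expectedNumItems = 20, fpp = 0.01).toAggregateExpression())
val perFileFilters = df.groupBy(input_file_name()).agg(bfAgg)
```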
+ */ +trait BloomFilterEncoder { + + /** + * Returns the data type of the value in the DataFrame representing [[BloomFilter]]. + */ + def dataType: DataType + + /** + * Returns a value representing the given [[BloomFilter]] + * that can be put in the [[InternalRow]]. + */ + def encode(bf: BloomFilter): Any + + /** + * Returns a [[BloomFilter]] from the value in the DataFrame. + */ + def decode(value: Any): BloomFilter +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoderProvider.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoderProvider.scala new file mode 100644 index 000000000..2f65624db --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterEncoderProvider.scala @@ -0,0 +1,30 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +/** + * Provides the default implementation of [[BloomFilterEncoder]]. + */ +object BloomFilterEncoderProvider { + + /** + * Returns the default encoder. + * + * It should return a singleton object declared as "object". + */ + def defaultEncoder: BloomFilterEncoder = FastBloomFilterEncoder +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContain.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContain.scala new file mode 100644 index 000000000..6bc7f23d8 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContain.scala @@ -0,0 +1,71 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ + +/** + * Returns true if the bloom filter (left) might contain the value (right). + * + * If the value (right) is null, null is returned. + * + * Preconditions (unchecked): + * - The bloom filter must not be null. 
+ */ +private[dataskipping] case class BloomFilterMightContain(left: Expression, right: Expression) + extends BinaryExpression + with Predicate { + + override def prettyName: String = "bloom_filter_might_contain" + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = right.eval(input) + if (value != null) { + val bfData = left.eval(input) + val bf = BloomFilterEncoderProvider.defaultEncoder.decode(bfData) + return BloomFilterUtils.mightContain(bf, value, right.dataType) + } + null + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) + val bloomFilterEncoder = + BloomFilterEncoderProvider.defaultEncoder.getClass.getCanonicalName.stripSuffix("$") + val bf = s"$bloomFilterEncoder.decode(${leftGen.value})" + val result = BloomFilterUtils.mightContainCodegen(bf, rightGen.value, right.dataType) + val resultCode = + s""" + |if (!(${rightGen.isNull})) { + | ${leftGen.code} + | ${ev.isNull} = false; + | ${ev.value} = $result; + |} + """.stripMargin + ev.copy(code = code""" + ${rightGen.code} + boolean ${ev.isNull} = true; + boolean ${ev.value} = false; + $resultCode""") + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAny.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAny.scala new file mode 100644 index 000000000..4b6c4189d --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAny.scala @@ -0,0 +1,85 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Predicate, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.types.DataType +import org.apache.spark.util.sketch.BloomFilter + +/** + * Returns true if the bloom filter (child) might contain one of the values. + * + * Preconditions (unchecked): + * - The bloom filter must not be null. + * - The values must be an array without nulls. + * - If the element type can be represented as a primitive type in Scala, + * then the array must be an array of the primitive type. 
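Interpreted-path sketch of `BloomFilterMightContain`, mainly to show the SQL-style null semantics (constructing the filter as a literal is for testing only):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType
import org.apache.spark.util.sketch.BloomFilter

import com.microsoft.hyperspace.index.dataskipping.expressions._

val encoder = BloomFilterEncoderProvider.defaultEncoder
val bf = BloomFilter.create(10, 0.01)
bf.putLong(1L)
val bfLit = Literal(encoder.encode(bf), encoder.dataType)

assert(BloomFilterMightContain(bfLit, Literal(1L)).eval(InternalRow.empty) == true)
// A null probe yields null (unknown), not false:
assert(BloomFilterMightContain(bfLit, Literal(null, LongType)).eval(InternalRow.empty) == null)
```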
+ */
+private[dataskipping] case class BloomFilterMightContainAny(
+    child: Expression,
+    values: Any,
+    elementType: DataType)
+  extends UnaryExpression
+  with Predicate {
+
+  override def prettyName: String = "bloom_filter_might_contain_any"
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Boolean = {
+    val bfData = child.eval(input)
+    val bf = BloomFilterEncoderProvider.defaultEncoder.decode(bfData)
+    values
+      .asInstanceOf[Array[_]]
+      .exists(BloomFilterUtils.mightContain(bf, _, elementType))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val bloomFilterEncoder =
+      BloomFilterEncoderProvider.defaultEncoder.getClass.getCanonicalName.stripSuffix("$")
+    val bf = ctx.freshName("bf")
+    val bfType = classOf[BloomFilter].getCanonicalName
+    val javaType = CodeGenerator.javaType(elementType)
+    val arrayType = if (values.isInstanceOf[Array[Any]]) "java.lang.Object[]" else s"$javaType[]"
+    val valuesRef = ctx.addReferenceObj("values", values, arrayType)
+    val valuesArray = ctx.freshName("values")
+    val i = ctx.freshName("i")
+    val mightContain =
+      BloomFilterUtils.mightContainCodegen(bf, s"($javaType) $valuesArray[$i]", elementType)
+    val resultCode =
+      s"""
+         |$bfType $bf = $bloomFilterEncoder.decode(${childGen.value});
+         |$arrayType $valuesArray = $valuesRef;
+         |for (int $i = 0; $i < $valuesArray.length; $i++) {
+         |  if ($mightContain) {
+         |    ${ev.value} = true;
+         |    break;
+         |  }
+         |}
+       """.stripMargin
+    ev.copy(
+      code = code"""
+        ${childGen.code}
+        boolean ${ev.value} = false;
+        $resultCode""",
+      isNull = FalseLiteral)
+  }
+}
diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtils.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtils.scala
new file mode 100644
index 000000000..2490b638a
--- /dev/null
+++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtils.scala
@@ -0,0 +1,62 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.expressions
+
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.sketch.BloomFilter
+
+import com.microsoft.hyperspace.HyperspaceException
+
+// TODO: Support more types.
+// Currently we are relying on org.apache.spark.util.sketch.BloomFilter and
+// supported types are restricted by the implementation. To support more types
+// without changing the underlying implementation, we can convert Spark values
+// to and from byte arrays.
+private[dataskipping] object BloomFilterUtils { + def put(bf: BloomFilter, value: Any, dataType: DataType): Boolean = + dataType match { + case LongType => bf.putLong(value.asInstanceOf[Long]) + case IntegerType => bf.putLong(value.asInstanceOf[Int]) + case ByteType => bf.putLong(value.asInstanceOf[Byte]) + case ShortType => bf.putLong(value.asInstanceOf[Short]) + case StringType => bf.putBinary(value.asInstanceOf[UTF8String].getBytes) + case BinaryType => bf.putBinary(value.asInstanceOf[Array[Byte]]) + case _ => throw HyperspaceException(s"BloomFilter does not support ${dataType}") + } + + def mightContain(bf: BloomFilter, value: Any, dataType: DataType): Boolean = { + dataType match { + case LongType => bf.mightContainLong(value.asInstanceOf[Long]) + case IntegerType => bf.mightContainLong(value.asInstanceOf[Int]) + case ByteType => bf.mightContainLong(value.asInstanceOf[Byte]) + case ShortType => bf.mightContainLong(value.asInstanceOf[Short]) + case StringType => bf.mightContainBinary(value.asInstanceOf[UTF8String].getBytes) + case BinaryType => bf.mightContainBinary(value.asInstanceOf[Array[Byte]]) + case _ => throw HyperspaceException(s"BloomFilter does not support ${dataType}") + } + } + + def mightContainCodegen(bf: String, value: String, dataType: DataType): String = { + dataType match { + case LongType | IntegerType | ByteType | ShortType => s"$bf.mightContainLong($value)" + case StringType => s"$bf.mightContainBinary(($value).getBytes())" + case BinaryType => s"$bf.mightContainBinary($value)" + case _ => throw HyperspaceException(s"BloomFilter does not support ${dataType}") + } + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoder.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoder.scala new file mode 100644 index 000000000..36b222896 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoder.scala @@ -0,0 +1,60 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{ArrayType, IntegerType, LongType, StructField, StructType} +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.dataskipping.util.ReflectionHelper + +/** + * A [[BloomFilterEncoder]] implementation that avoids copying arrays. 
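Interpreted-path sketch of the dispatch above; values use Spark's internal representations (`UTF8String` for strings), and integral types are widened and hashed as longs, so `putLong`/`mightContainLong` agree across Byte/Short/Int/Long:

```scala
import org.apache.spark.sql.types.{IntegerType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.sketch.BloomFilter

import com.microsoft.hyperspace.index.dataskipping.expressions.BloomFilterUtils

val bf = BloomFilter.create(100, 0.01)
BloomFilterUtils.put(bf, 42, IntegerType)
BloomFilterUtils.put(bf, UTF8String.fromString("foo"), StringType)
assert(BloomFilterUtils.mightContain(bf, 42, IntegerType))
assert(BloomFilterUtils.mightContain(bf, UTF8String.fromString("foo"), StringType))
```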
+ */ +object FastBloomFilterEncoder extends BloomFilterEncoder with ReflectionHelper { + override val dataType: StructType = StructType( + StructField("numHashFunctions", IntegerType, nullable = false) :: + StructField("bitCount", LongType, nullable = false) :: + StructField("data", ArrayType(LongType, containsNull = false), nullable = false) :: Nil) + + override def encode(bf: BloomFilter): InternalRow = { + val bloomFilterImplClass = bf.getClass + val bits = get(bloomFilterImplClass, "bits", bf) + val bitArrayClass = bits.getClass + InternalRow( + getInt(bloomFilterImplClass, "numHashFunctions", bf), + getLong(bitArrayClass, "bitCount", bits), + ArrayData.toArrayData(get(bitArrayClass, "data", bits).asInstanceOf[Array[Long]])) + } + + override def decode(value: Any): BloomFilter = { + val struct = value.asInstanceOf[InternalRow] + val numHashFunctions = struct.getInt(0) + val bitCount = struct.getLong(1) + val data = struct.getArray(2).toLongArray() + + val bf = BloomFilter.create(1) + val bloomFilterImplClass = bf.getClass + val bits = get(bloomFilterImplClass, "bits", bf) + val bitArrayClass = bits.getClass + setInt(bloomFilterImplClass, "numHashFunctions", bf, numHashFunctions) + setLong(bitArrayClass, "bitCount", bits, bitCount) + set(bitArrayClass, "data", bits, data) + bf + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContains.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContains.scala new file mode 100644 index 000000000..02877e7e6 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContains.scala @@ -0,0 +1,91 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Predicate} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils} + +/** + * Returns true if the sorted array (left) contains the value (right). + * + * If the value (right) is null, null is returned. + * + * Preconditions (unchecked): + * - The array must not be null. + * - Elements in the array must be in ascending order. + * - The array must not contain null elements. + * - The array must not contain duplicate elements. 
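`FastBloomFilterEncoder` reaches into the private `numHashFunctions`, `bitCount`, and `data` fields of Spark's `BloomFilterImpl`/`BitArray` via `ReflectionHelper` (introduced later in this diff), so it is coupled to the bundled Spark version. A round-trip sanity sketch:

```scala
import org.apache.spark.util.sketch.BloomFilter

import com.microsoft.hyperspace.index.dataskipping.expressions.FastBloomFilterEncoder

// The restored filter must answer membership queries like the original.
val bf = BloomFilter.create(1000, 0.01)
(1L to 50L).foreach(bf.putLong)
val restored = FastBloomFilterEncoder.decode(FastBloomFilterEncoder.encode(bf))
assert((1L to 50L).forall(restored.mightContainLong))
```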
+ */ +private[dataskipping] case class SortedArrayContains(left: Expression, right: Expression) + extends BinaryExpression + with Predicate { + + override def prettyName: String = "sorted_array_contains" + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + val value = right.eval(input) + if (value != null) { + val arr = left.eval(input).asInstanceOf[ArrayData] + val dt = right.dataType + val n = arr.numElements() + if (n > 0 && + ordering.lteq(arr.get(0, dt), value) && + ordering.lteq(value, arr.get(n - 1, dt))) { + val (found, _) = SortedArrayUtils.binarySearch(arr, dt, ordering, 0, n, value) + if (found) return true + } + return false + } + null + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val arr = leftGen.value + val rightGen = right.genCode(ctx) + val value = rightGen.value + val dt = right.dataType + val n = ctx.freshName("n") + val binarySearch = SortedArrayUtils.binarySearchCodeGen(ctx, dt) + val resultCode = + s""" + |if (!(${rightGen.isNull})) { + | ${leftGen.code} + | ${ev.isNull} = false; + | int $n = $arr.numElements(); + | if ($n > 0 && + | !(${ctx.genGreater(dt, CodeGenerator.getValue(arr, dt, "0"), value)}) && + | !(${ctx.genGreater(dt, value, CodeGenerator.getValue(arr, dt, s"$n - 1"))})) { + | ${ev.value} = $binarySearch($arr, 0, $n, $value).found(); + | } + |} + """.stripMargin + ev.copy(code = code""" + ${rightGen.code} + boolean ${ev.isNull} = true; + boolean ${ev.value} = false; + $resultCode""") + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(right.dataType) +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAny.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAny.scala new file mode 100644 index 000000000..ab3a14cb4 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAny.scala @@ -0,0 +1,132 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, Predicate, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.util.{ArrayData, TypeUtils} +import org.apache.spark.sql.types.DataType + +/** + * Returns true if the sorted array (child) contains any of the values. + * + * If either array is empty, false is returned. + * + * Preconditions (unchecked): + * - Both arrays must not be null. + * - Elements in the arrays must be in ascending order. + * - The left array should not contain duplicate elements. 
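Interpreted-path sketch of `SortedArrayContains`; the left side must already satisfy the preconditions above (sorted, null-free, duplicate-free), since the expression only does a range check plus a binary search:

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

import com.microsoft.hyperspace.index.dataskipping.expressions.SortedArrayContains

val arr = Literal(
  ArrayData.toArrayData(Array(1, 3, 5, 7)),
  ArrayType(IntegerType, containsNull = false))
assert(SortedArrayContains(arr, Literal(5)).eval(InternalRow.empty) == true)
assert(SortedArrayContains(arr, Literal(4)).eval(InternalRow.empty) == false)
```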
+ * - The arrays must not contain null elements. + * + * If the element type can be represented as a primitive type in Scala, + * then the right array must be an array of the primitive type. + */ +private[dataskipping] case class SortedArrayContainsAny( + child: Expression, + values: Any, + elementType: DataType) + extends UnaryExpression + with Predicate { + + override def prettyName: String = "sorted_array_contains_any" + + override def nullable: Boolean = false + + override def eval(input: InternalRow): Boolean = { + val arr1 = child.eval(input).asInstanceOf[ArrayData] + val arr2 = values.asInstanceOf[Array[_]] + val dt = elementType + val n = arr1.numElements() + val m = arr2.length + if (n > 0 && m > 0 && + ordering.lteq(arr1.get(0, dt), arr2(m - 1)) && + ordering.lteq(arr2(0), arr1.get(n - 1, dt))) { + var i = 0 + var j = 0 + do { + val v = arr1.get(i, dt) + while (j < m && ordering.lt(arr2(j), v)) j += 1 + if (j == m) return false + val u = arr2(j) + j += 1 + val (found, k) = SortedArrayUtils.binarySearch(arr1, dt, ordering, i, n, u) + if (found) return true + if (k == n) return false + i = k + } while (j < m) + } + false + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val arr1 = childGen.value + val arr2 = ctx.freshName("values") + val dt = elementType + val javaType = CodeGenerator.javaType(dt) + val arrayType = if (values.isInstanceOf[Array[Any]]) "java.lang.Object[]" else s"$javaType[]" + val valuesRef = ctx.addReferenceObj("values", values, arrayType) + val n = ctx.freshName("n") + val m = ctx.freshName("m") + val i = ctx.freshName("i") + val j = ctx.freshName("j") + val v = ctx.freshName("v") + val u = ctx.freshName("u") + val result = ctx.freshName("result") + val binarySearchResultType = + SortedArrayUtils.BinarySearchResult.getClass.getCanonicalName.stripSuffix("$") + val binarySearch = SortedArrayUtils.binarySearchCodeGen(ctx, dt) + import CodeGenerator.getValue + val resultCode = + s""" + |int $n = $arr1.numElements(); + |int $m = $arr2.length; + |if ($n > 0 && $m > 0 && + | !(${ctx.genGreater(dt, getValue(arr1, dt, "0"), s"(($javaType) $arr2[$m - 1])")}) && + | !(${ctx.genGreater(dt, s"(($javaType)$arr2[0])", getValue(arr1, dt, s"$n - 1"))})) { + | int $i = 0; + | int $j = 0; + | do { + | $javaType $v = ${getValue(arr1, dt, i)}; + | while ($j < $m && ${ctx.genGreater(dt, v, s"(($javaType) $arr2[$j])")}) $j += 1; + | if ($j == $m) break; + | $javaType $u = ($javaType) $arr2[$j]; + | $j += 1; + | $binarySearchResultType $result = $binarySearch($arr1, $i, $n, $u); + | if ($result.found()) { + | ${ev.value} = true; + | break; + | } + | if ($result.index() == $n) break; + | $i = $result.index(); + | } while ($j < $m); + |} + """.stripMargin + ev.copy( + code = code""" + ${childGen.code} + $arrayType $arr2 = $valuesRef; + boolean ${ev.value} = false; + $resultCode""", + isNull = FalseLiteral) + } + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayLowerBound.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayLowerBound.scala index 48ed15ff9..9199f7882 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayLowerBound.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayLowerBound.scala @@ -46,9 +46,9 @@ private[dataskipping] case class 
SortedArrayLowerBound(left: Expression, right: override def nullable: Boolean = true override def eval(input: InternalRow): Any = { - val arr = left.eval(input).asInstanceOf[ArrayData] val value = right.eval(input) if (value != null) { + val arr = left.eval(input).asInstanceOf[ArrayData] val dt = right.dataType val n = arr.numElements() if (n > 0) { @@ -77,6 +77,7 @@ private[dataskipping] case class SortedArrayLowerBound(left: Expression, right: val resultCode = s""" |if (!(${rightGen.isNull})) { + | ${leftGen.code} | int $n = $arr.numElements(); | if ($n > 0) { | if (!(${ctx.genGreater(dt, value, firstValueInArr)})) { @@ -90,7 +91,6 @@ private[dataskipping] case class SortedArrayLowerBound(left: Expression, right: |} """.stripMargin ev.copy(code = code""" - ${leftGen.code} ${rightGen.code} boolean ${ev.isNull} = true; int ${ev.value} = 0; diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoder.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoder.scala new file mode 100644 index 000000000..78a768b7a --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoder.scala @@ -0,0 +1,40 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import org.apache.spark.sql.types.BinaryType +import org.apache.spark.util.sketch.BloomFilter + +/** + * A [[BloomFilterEncoder]] implementation based on byte array streams. 
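`SortedArrayContainsAny` merges through both sorted sequences, advancing the binary-search window instead of restarting it, so probing stays roughly O(m log n). A sketch (for `IntegerType` the probe values must be a primitive `Array[Int]`, per the preconditions):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.{ArrayType, IntegerType}

import com.microsoft.hyperspace.index.dataskipping.expressions.SortedArrayContainsAny

val fileValues = Literal(
  ArrayData.toArrayData(Array(10, 20, 30)),
  ArrayType(IntegerType, containsNull = false))
assert(SortedArrayContainsAny(fileValues, Array(5, 20, 99), IntegerType)
  .eval(InternalRow.empty))
assert(!SortedArrayContainsAny(fileValues, Array(5, 15, 99), IntegerType)
  .eval(InternalRow.empty))
```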
+ */ +object StreamBloomFilterEncoder extends BloomFilterEncoder { + val dataType: BinaryType = BinaryType + + def encode(bf: BloomFilter): Array[Byte] = { + val out = new ByteArrayOutputStream() + bf.writeTo(out) + out.toByteArray + } + + def decode(value: Any): BloomFilter = { + val in = new ByteArrayInputStream(value.asInstanceOf[Array[Byte]]) + BloomFilter.readFrom(in) + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/extractors.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/extractors.scala index f097fc97d..3bf458d17 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/extractors.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/expressions/extractors.scala @@ -19,8 +19,13 @@ package com.microsoft.hyperspace.index.dataskipping.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.BooleanType -case class EqualToExtractor(left: ExpressionExtractor, right: ExpressionExtractor) { - def unapply(p: Expression): Option[(Expression, Expression)] = +trait BinaryExpressionExtractor { + def unapply(p: Expression): Option[(Expression, Expression)] +} + +case class EqualToExtractor(left: ExpressionExtractor, right: ExpressionExtractor) + extends BinaryExpressionExtractor { + override def unapply(p: Expression): Option[(Expression, Expression)] = p match { case EqualTo(left(l), right(r)) => Some((l, r)) case EqualTo(right(r), left(l)) => Some((l, r)) @@ -28,8 +33,9 @@ case class EqualToExtractor(left: ExpressionExtractor, right: ExpressionExtracto } } -case class EqualNullSafeExtractor(left: ExpressionExtractor, right: ExpressionExtractor) { - def unapply(p: Expression): Option[(Expression, Expression)] = +case class EqualNullSafeExtractor(left: ExpressionExtractor, right: ExpressionExtractor) + extends BinaryExpressionExtractor { + override def unapply(p: Expression): Option[(Expression, Expression)] = p match { case EqualNullSafe(left(l), right(r)) => Some((l, r)) case EqualNullSafe(right(r), left(l)) => Some((l, r)) @@ -37,8 +43,9 @@ case class EqualNullSafeExtractor(left: ExpressionExtractor, right: ExpressionEx } } -case class LessThanExtractor(left: ExpressionExtractor, right: ExpressionExtractor) { - def unapply(p: Expression): Option[(Expression, Expression)] = +case class LessThanExtractor(left: ExpressionExtractor, right: ExpressionExtractor) + extends BinaryExpressionExtractor { + override def unapply(p: Expression): Option[(Expression, Expression)] = p match { case LessThan(left(l), right(r)) => Some((l, r)) case GreaterThan(right(r), left(l)) => Some((l, r)) @@ -46,8 +53,9 @@ case class LessThanExtractor(left: ExpressionExtractor, right: ExpressionExtract } } -case class LessThanOrEqualExtractor(left: ExpressionExtractor, right: ExpressionExtractor) { - def unapply(p: Expression): Option[(Expression, Expression)] = +case class LessThanOrEqualExtractor(left: ExpressionExtractor, right: ExpressionExtractor) + extends BinaryExpressionExtractor { + override def unapply(p: Expression): Option[(Expression, Expression)] = p match { case LessThanOrEqual(left(l), right(r)) => Some((l, r)) case GreaterThanOrEqual(right(r), left(l)) => Some((l, r)) @@ -55,6 +63,28 @@ case class LessThanOrEqualExtractor(left: ExpressionExtractor, right: Expression } } +case class SwitchLeftRight(extractor: BinaryExpressionExtractor) + extends BinaryExpressionExtractor { + override def unapply(p: Expression): Option[(Expression, Expression)] = { + 
extractor.unapply(p) match {
+      case Some((a, b)) => Some((b, a))
+      case None => None
+    }
+  }
+}
+
+object GreaterThanExtractor {
+  def apply(left: ExpressionExtractor, right: ExpressionExtractor): BinaryExpressionExtractor = {
+    SwitchLeftRight(LessThanExtractor(right, left))
+  }
+}
+
+object GreaterThanOrEqualExtractor {
+  def apply(left: ExpressionExtractor, right: ExpressionExtractor): BinaryExpressionExtractor = {
+    SwitchLeftRight(LessThanOrEqualExtractor(right, left))
+  }
+}
+
 case class IsNullExtractor(expr: ExpressionExtractor) {
   def unapply(p: Expression): Option[Expression] =
     p match {
diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketch.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketch.scala
new file mode 100644
index 000000000..d9bcbaad6
--- /dev/null
+++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketch.scala
@@ -0,0 +1,86 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.sketches
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.DataType
+
+import com.microsoft.hyperspace.index.dataskipping.expressions._
+import com.microsoft.hyperspace.index.dataskipping.util.ArrayUtils.toArray
+
+/**
+ * Sketch based on a bloom filter for a given expression.
+ *
+ * Being a probabilistic structure, it is more efficient in terms of the index
+ * data size than [[ValueListSketch]] if the number of distinct values for the
+ * expression is large, but can be less efficient in terms of query optimization
+ * than [[ValueListSketch]] due to false positives.
+ *
+ * Users can specify the target false positive rate and the expected number of
+ * distinct values per file. These variables determine the size of the bloom
+ * filters and thus the size of the index data.
+ *
+ * @param expr Expression this sketch is based on
+ * @param fpp Target false positive rate
+ * @param expectedDistinctCountPerFile Expected number of distinct values per file
+ * @param dataType Optional data type to specify the expected data type of the
+ *                 expression. If not specified, it is deduced automatically.
+ *                 If the actual data type of the expression is different from this,
+ *                 an error is thrown. Users are recommended to leave this
+ *                 parameter as None.
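Usage sketch for the extractor composition above. `AnyExpr` is a made-up, permissive extractor (real callers pass `NormalizedExprExtractor` for the indexed expression and a value extractor), assuming `ExpressionExtractor` exposes `unapply(Expression): Option[Expression]`:

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GreaterThan, LessThan, Literal}
import org.apache.spark.sql.types.IntegerType

import com.microsoft.hyperspace.index.dataskipping.expressions._

object AnyExpr extends ExpressionExtractor {
  override def unapply(e: Expression): Option[Expression] = Some(e)
}

val Gt = GreaterThanExtractor(AnyExpr, AnyExpr)
val a = AttributeReference("A", IntegerType)()
// Both `A > 5` and `5 < A` normalize to the same (expr, value) pair.
assert(Gt.unapply(GreaterThan(a, Literal(5))) == Some((a, Literal(5))))
assert(Gt.unapply(LessThan(Literal(5), a)) == Some((a, Literal(5))))
```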
+ */ +case class BloomFilterSketch( + override val expr: String, + fpp: Double, + expectedDistinctCountPerFile: Long, + override val dataType: Option[DataType] = None) + extends SingleExprSketch[BloomFilterSketch](expr, dataType) { + override def name: String = "BloomFilter" + + override def toString: String = s"$name($expr, $fpp, $expectedDistinctCountPerFile)" + + override def withNewExpression(newExpr: (String, Option[DataType])): BloomFilterSketch = { + copy(expr = newExpr._1, dataType = newExpr._2) + } + + override def aggregateFunctions: Seq[Expression] = { + BloomFilterAgg(parsedExpr, expectedDistinctCountPerFile, fpp).toAggregateExpression :: Nil + } + + override def convertPredicate( + predicate: Expression, + resolvedExprs: Seq[Expression], + sketchValues: Seq[Expression], + nameMap: Map[ExprId, String], + valueExtractor: ExpressionExtractor): Option[Expression] = { + val bf = sketchValues.head + val resolvedExpr = resolvedExprs.head + val dataType = resolvedExpr.dataType + val exprExtractor = NormalizedExprExtractor(resolvedExpr, nameMap) + val ExprEqualTo = EqualToExtractor(exprExtractor, valueExtractor) + val ExprEqualNullSafe = EqualNullSafeExtractor(exprExtractor, valueExtractor) + val ExprIn = InExtractor(exprExtractor, valueExtractor) + val ExprInSet = InSetExtractor(exprExtractor) + Option(predicate).collect { + case ExprEqualTo(_, v) => BloomFilterMightContain(bf, v) + case ExprEqualNullSafe(_, v) => Or(IsNull(v), BloomFilterMightContain(bf, v)) + case ExprIn(_, vs) => vs.map(BloomFilterMightContain(bf, _)).reduceLeft(Or) + case ExprInSet(_, vs) => + BloomFilterMightContainAny(bf, toArray(vs.filter(_ != null).toSeq, dataType), dataType) + } + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/MinMaxSketch.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/MinMaxSketch.scala index 6d8a143dc..dad22d3e7 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/MinMaxSketch.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/MinMaxSketch.scala @@ -42,8 +42,9 @@ case class MinMaxSketch(override val expr: String, override val dataType: Option copy(expr = newExpr._1, dataType = newExpr._2) } - override def aggregateFunctions: Seq[Expression] = - Min(parsedExpr).toAggregateExpression() :: Max(parsedExpr).toAggregateExpression() :: Nil + override def aggregateFunctions: Seq[Expression] = { + Seq(Min(parsedExpr), Max(parsedExpr)).map(_.toAggregateExpression) + } override def convertPredicate( predicate: Expression, @@ -68,33 +69,32 @@ case class MinMaxSketch(override val expr: String, override val dataType: Option val ExprEqualNullSafe = EqualNullSafeExtractor(exprExtractor, valueExtractor) val ExprLessThan = LessThanExtractor(exprExtractor, valueExtractor) val ExprLessThanOrEqualTo = LessThanOrEqualExtractor(exprExtractor, valueExtractor) - val ExprGreaterThan = LessThanExtractor(valueExtractor, exprExtractor) - val ExprGreaterThanOrEqualTo = LessThanOrEqualExtractor(valueExtractor, exprExtractor) + val ExprGreaterThan = GreaterThanExtractor(exprExtractor, valueExtractor) + val ExprGreaterThanOrEqualTo = GreaterThanOrEqualExtractor(exprExtractor, valueExtractor) val ExprIn = InExtractor(exprExtractor, valueExtractor) val ExprInSet = InSetExtractor(exprExtractor) - Option(predicate) - .collect { - case ExprIsTrue(_) => max - case ExprIsFalse(_) => Not(min) - case ExprIsNotNull(_) => IsNotNull(min) - case ExprEqualTo(_, v) => And(LessThanOrEqual(min, v), 
GreaterThanOrEqual(max, v)) - case ExprEqualNullSafe(_, v) => - Or(IsNull(v), And(LessThanOrEqual(min, v), GreaterThanOrEqual(max, v))) - case ExprLessThan(_, v) => LessThan(min, v) - case ExprLessThanOrEqualTo(_, v) => LessThanOrEqual(min, v) - case ExprGreaterThan(v, _) => GreaterThan(max, v) - case ExprGreaterThanOrEqualTo(v, _) => GreaterThanOrEqual(max, v) - case ExprIn(_, vs) => - vs.map(v => And(LessThanOrEqual(min, v), GreaterThanOrEqual(max, v))).reduceLeft(Or) - case ExprInSet(_, vs) => - val sortedValues = Literal( - ArrayData.toArrayData( - ArrayUtils.toArray( - vs.filter(_ != null).toArray.sorted(TypeUtils.getInterpretedOrdering(dataType)), - dataType)), - ArrayType(dataType, containsNull = false)) - LessThanOrEqual(ElementAt(sortedValues, SortedArrayLowerBound(sortedValues, min)), max) - // TODO: StartsWith, Like with constant prefix - } + Option(predicate).collect { + case ExprIsTrue(_) => max + case ExprIsFalse(_) => Not(min) + case ExprIsNotNull(_) => IsNotNull(min) + case ExprEqualTo(_, v) => And(LessThanOrEqual(min, v), GreaterThanOrEqual(max, v)) + case ExprEqualNullSafe(_, v) => + Or(IsNull(v), And(LessThanOrEqual(min, v), GreaterThanOrEqual(max, v))) + case ExprLessThan(_, v) => LessThan(min, v) + case ExprLessThanOrEqualTo(_, v) => LessThanOrEqual(min, v) + case ExprGreaterThan(_, v) => GreaterThan(max, v) + case ExprGreaterThanOrEqualTo(_, v) => GreaterThanOrEqual(max, v) + case ExprIn(_, vs) => + vs.map(v => And(LessThanOrEqual(min, v), GreaterThanOrEqual(max, v))).reduceLeft(Or) + case ExprInSet(_, vs) => + val sortedValues = Literal( + ArrayData.toArrayData( + ArrayUtils.toArray( + vs.filter(_ != null).toArray.sorted(TypeUtils.getInterpretedOrdering(dataType)), + dataType)), + ArrayType(dataType, containsNull = false)) + LessThanOrEqual(ElementAt(sortedValues, SortedArrayLowerBound(sortedValues, min)), max) + // TODO: StartsWith, Like with constant prefix + } } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/PartitionSketch.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/PartitionSketch.scala index 71db379ac..abe43de9d 100644 --- a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/PartitionSketch.scala +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/PartitionSketch.scala @@ -50,7 +50,7 @@ private[dataskipping] case class PartitionSketch( override def aggregateFunctions: Seq[Expression] = { val parser = SparkSession.getActiveSession.get.sessionState.sqlParser exprStrings.map { e => - FirstNullSafe(parser.parseExpression(e)).toAggregateExpression() + FirstNullSafe(parser.parseExpression(e)).toAggregateExpression } } diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketch.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketch.scala new file mode 100644 index 000000000..fa06f8663 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketch.scala @@ -0,0 +1,103 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
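To make the MinMax conversion concrete: for a source predicate like `A = 5` or `A > 5` against the per-file `min(A)`/`max(A)` sketch columns, the collect block above produces the index predicates below (the attribute names are hypothetical stand-ins for the resolved sketch output columns):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.IntegerType

val min = AttributeReference("MinMax_A__0", IntegerType)()
val max = AttributeReference("MinMax_A__1", IntegerType)()

// `A = 5`: keep a file only if 5 can fall within its [min, max] range.
val eqPred = And(LessThanOrEqual(min, Literal(5)), GreaterThanOrEqual(max, Literal(5)))
// `A > 5`: keep a file only if its max exceeds 5.
val gtPred = GreaterThan(max, Literal(5))
```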
* You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.sketches
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.{ArrayType, DataType}
+
+import com.microsoft.hyperspace.index.dataskipping.expressions._
+import com.microsoft.hyperspace.index.dataskipping.util.ArrayUtils
+
+/**
+ * Sketch based on distinct values for a given expression.
+ *
+ * This is not really a sketch, as it stores all distinct values for a given
+ * expression. It can be useful when the number of distinct values is expected to
+ * be small and each file tends to store only a subset of the values.
+ */
+case class ValueListSketch(
+    override val expr: String,
+    override val dataType: Option[DataType] = None)
+  extends SingleExprSketch[ValueListSketch](expr, dataType) {
+  override def name: String = "ValueList"
+
+  override def withNewExpression(newExpr: (String, Option[DataType])): ValueListSketch = {
+    copy(expr = newExpr._1, dataType = newExpr._2)
+  }
+
+  override def aggregateFunctions: Seq[Expression] =
+    new ArraySort(CollectSet(parsedExpr).toAggregateExpression()) :: Nil
+
+  override def convertPredicate(
+      predicate: Expression,
+      resolvedExprs: Seq[Expression],
+      sketchValues: Seq[Expression],
+      nameMap: Map[ExprId, String],
+      valueExtractor: ExpressionExtractor): Option[Expression] = {
+    val valueList = sketchValues.head
+    val min = ElementAt(valueList, Literal(1))
+    val max = ElementAt(valueList, Literal(-1))
+    // TODO: Consider shared sketches
+    // HasNullSketch as described in MinMaxSketch.convertPredicate
+    // can be useful for ValueListSketch too, as it can be used
+    // to optimize Not(EqualTo) as well as IsNull.
+ val resolvedExpr = resolvedExprs.head + val dataType = resolvedExpr.dataType + val exprExtractor = NormalizedExprExtractor(resolvedExpr, nameMap) + val ExprIsTrue = IsTrueExtractor(exprExtractor) + val ExprIsFalse = IsFalseExtractor(exprExtractor) + val ExprIsNotNull = IsNotNullExtractor(exprExtractor) + val ExprEqualTo = EqualToExtractor(exprExtractor, valueExtractor) + val ExprEqualNullSafe = EqualNullSafeExtractor(exprExtractor, valueExtractor) + val ExprLessThan = LessThanExtractor(exprExtractor, valueExtractor) + val ExprLessThanOrEqualTo = LessThanOrEqualExtractor(exprExtractor, valueExtractor) + val ExprGreaterThan = GreaterThanExtractor(exprExtractor, valueExtractor) + val ExprGreaterThanOrEqualTo = GreaterThanOrEqualExtractor(exprExtractor, valueExtractor) + val ExprIn = InExtractor(exprExtractor, valueExtractor) + val ExprInSet = InSetExtractor(exprExtractor) + def Empty(arr: Expression) = EqualTo(Size(arr), Literal(0)) + Option(predicate).collect { + case ExprIsTrue(_) => ArrayContains(valueList, Literal(true)) + case ExprIsFalse(_) => ArrayContains(valueList, Literal(false)) + case ExprIsNotNull(_) => Not(Empty(valueList)) + case ExprEqualTo(_, v) => SortedArrayContains(valueList, v) + case ExprEqualNullSafe(_, v) => Or(IsNull(v), SortedArrayContains(valueList, v)) + case Not(ExprEqualTo(_, v)) => + And( + IsNotNull(v), + Or( + GreaterThan(Size(valueList), Literal(1)), + Not(EqualTo(ElementAt(valueList, Literal(1)), v)))) + case ExprLessThan(_, v) => LessThan(min, v) + case ExprLessThanOrEqualTo(_, v) => LessThanOrEqual(min, v) + case ExprGreaterThan(_, v) => GreaterThan(max, v) + case ExprGreaterThanOrEqualTo(_, v) => GreaterThanOrEqual(max, v) + case ExprIn(_, vs) => + vs.map(v => SortedArrayContains(valueList, v)).reduceLeft(Or) + case ExprInSet(_, vs) => + SortedArrayContainsAny( + valueList, + ArrayUtils.toArray( + vs.filter(_ != null).toArray.sorted(TypeUtils.getInterpretedOrdering(dataType)), + dataType), + dataType) + // TODO: StartsWith, Like with constant prefix + } + } +} diff --git a/src/main/scala/com/microsoft/hyperspace/index/dataskipping/util/ReflectionHelper.scala b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/util/ReflectionHelper.scala new file mode 100644 index 000000000..aecc5b9c4 --- /dev/null +++ b/src/main/scala/com/microsoft/hyperspace/index/dataskipping/util/ReflectionHelper.scala @@ -0,0 +1,51 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
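End-to-end sketch of the new sketch type through the public API (the path is hypothetical; the source must be a file-based relation so each file gets its own value list):

```scala
import org.apache.spark.sql.SparkSession

import com.microsoft.hyperspace.Hyperspace
import com.microsoft.hyperspace.index.dataskipping.DataSkippingIndexConfig
import com.microsoft.hyperspace.index.dataskipping.sketches.ValueListSketch

val spark = SparkSession.builder.getOrCreate()
val hs = new Hyperspace(spark)
val df = spark.read.parquet("/path/to/source")
// Queries like df.filter("A = 1") can then skip files whose value list
// provably does not contain 1.
hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
```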
+ */ + +package com.microsoft.hyperspace.index.dataskipping.util + +import java.lang.reflect.Field + +trait ReflectionHelper { + def getAccesibleDeclaredField(clazz: Class[_], name: String): Field = { + val field = clazz.getDeclaredField(name) + field.setAccessible(true) + field + } + + def get(clazz: Class[_], fieldName: String, obj: Any): Any = { + getAccesibleDeclaredField(clazz, fieldName).get(obj) + } + + def getInt(clazz: Class[_], fieldName: String, obj: Any): Int = { + getAccesibleDeclaredField(clazz, fieldName).getInt(obj) + } + + def getLong(clazz: Class[_], fieldName: String, obj: Any): Long = { + getAccesibleDeclaredField(clazz, fieldName).getLong(obj) + } + + def set(clazz: Class[_], fieldName: String, obj: Any, value: Any): Unit = { + getAccesibleDeclaredField(clazz, fieldName).set(obj, value) + } + + def setInt(clazz: Class[_], fieldName: String, obj: Any, value: Int): Unit = { + getAccesibleDeclaredField(clazz, fieldName).setInt(obj, value) + } + + def setLong(clazz: Class[_], fieldName: String, obj: Any, value: Long): Unit = { + getAccesibleDeclaredField(clazz, fieldName).setLong(obj, value) + } +} diff --git a/src/test/scala-spark2/com/microsoft/hyperspace/util/SparkTestShims.scala b/src/test/scala-spark2/com/microsoft/hyperspace/util/SparkTestShims.scala index 532634a3b..fe04d7f58 100644 --- a/src/test/scala-spark2/com/microsoft/hyperspace/util/SparkTestShims.scala +++ b/src/test/scala-spark2/com/microsoft/hyperspace/util/SparkTestShims.scala @@ -16,6 +16,8 @@ package com.microsoft.hyperspace.util +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.execution.command.ExplainCommand @@ -32,4 +34,8 @@ object SparkTestShims { ExplainCommand(logicalPlan, extended = false) } } + + def fromRow[T](encoder: ExpressionEncoder[T], row: InternalRow): T = { + encoder.fromRow(row) + } } diff --git a/src/test/scala-spark3/com/microsoft/hyperspace/util/SparkTestShims.scala b/src/test/scala-spark3/com/microsoft/hyperspace/util/SparkTestShims.scala index 556b167a2..9e475f54e 100644 --- a/src/test/scala-spark3/com/microsoft/hyperspace/util/SparkTestShims.scala +++ b/src/test/scala-spark3/com/microsoft/hyperspace/util/SparkTestShims.scala @@ -16,6 +16,8 @@ package com.microsoft.hyperspace.util +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.execution.SimpleMode @@ -33,4 +35,8 @@ object SparkTestShims { ExplainCommand(logicalPlan, SimpleMode) } } + + def fromRow[T](encoder: ExpressionEncoder[T], row: InternalRow): T = { + encoder.createDeserializer().apply(row) + } } diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/BloomFilterTestUtils.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/BloomFilterTestUtils.scala new file mode 100644 index 000000000..ca1507c01 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/BloomFilterTestUtils.scala @@ -0,0 +1,40 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.dataskipping.expressions.BloomFilterEncoderProvider +import com.microsoft.hyperspace.util.SparkTestShims + +trait BloomFilterTestUtils { + def encodeExternal(bf: BloomFilter): Any = { + val bloomFilterEncoder = BloomFilterEncoderProvider.defaultEncoder + val data = bloomFilterEncoder.encode(bf) + val dataType = bloomFilterEncoder.dataType + dataType match { + case st: StructType => + SparkTestShims.fromRow(RowEncoder(st).resolveAndBind(), data.asInstanceOf[InternalRow]) + case _ => + val encoder = RowEncoder(StructType(StructField("x", dataType) :: Nil)).resolveAndBind() + SparkTestShims.fromRow(encoder, InternalRow(data)).get(0) + } + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexConfigTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexConfigTest.scala index 9596310fc..7bf8bbabc 100644 --- a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexConfigTest.scala +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexConfigTest.scala @@ -17,14 +17,15 @@ package com.microsoft.hyperspace.index.dataskipping import org.apache.hadoop.fs.Path -import org.apache.spark.sql.functions.{input_file_name, max, min} +import org.apache.spark.sql.functions.{array_sort, collect_set, input_file_name, max, min} import org.apache.spark.sql.types.{IntegerType, LongType, StringType} +import org.apache.spark.util.sketch.BloomFilter import com.microsoft.hyperspace.HyperspaceException import com.microsoft.hyperspace.index.IndexConstants import com.microsoft.hyperspace.index.dataskipping.sketches._ -class DataSkippingIndexConfigTest extends DataSkippingSuite { +class DataSkippingIndexConfigTest extends DataSkippingSuite with BloomFilterTestUtils { test("indexName returns the index name.") { val indexConfig = DataSkippingIndexConfig("myIndex", MinMaxSketch("A")) assert(indexConfig.indexName === "myIndex") @@ -85,6 +86,43 @@ class DataSkippingIndexConfigTest extends DataSkippingSuite { checkAnswer(indexData, withFileId(expectedSketchValues)) } + test("createIndex works correctly with a ValueListSketch.") { + val sourceData = + createSourceData(spark.range(100).selectExpr("cast(id / 10 as int) as A").toDF) + val indexConfig = DataSkippingIndexConfig("MyIndex", ValueListSketch("A")) + val (index, indexData) = indexConfig.createIndex(ctx, sourceData, Map()) + assert(index.sketches === Seq(ValueListSketch("A", Some(IntegerType)))) + val expectedSketchValues = sourceData + .groupBy(input_file_name().as(fileNameCol)) + .agg(array_sort(collect_set("A"))) + checkAnswer(indexData, withFileId(expectedSketchValues)) + assert(indexData.columns === Seq(IndexConstants.DATA_FILE_NAME_ID, "ValueList_A__0")) + } + + test("createIndex works correctly with a 
BloomFilterSketch.") {
+    val sourceData = createSourceData(spark.range(100).toDF("A"))
+    val indexConfig = DataSkippingIndexConfig("MyIndex", BloomFilterSketch("A", 0.001, 20))
+    val (index, indexData) = indexConfig.createIndex(ctx, sourceData, Map())
+    assert(index.sketches === Seq(BloomFilterSketch("A", 0.001, 20, Some(LongType))))
+    val valuesAndBloomFilters = indexData
+      .collect()
+      .map { row =>
+        val fileId = row.getAs[Long](IndexConstants.DATA_FILE_NAME_ID)
+        val filePath = fileIdTracker.getIdToFileMapping().toMap.apply(fileId)
+        val values = spark.read.parquet(filePath).collect().toSeq.map(_.getLong(0))
+        val bfData = row.getAs[Any]("BloomFilter_A__0.001__20__0")
+        (values, bfData)
+      }
+    valuesAndBloomFilters.foreach {
+      case (values, bfData) =>
+        val bf = BloomFilter.create(20, 0.001)
+        values.foreach(bf.put)
+        assert(bfData === encodeExternal(bf))
+    }
+    assert(
+      indexData.columns === Seq(IndexConstants.DATA_FILE_NAME_ID, "BloomFilter_A__0.001__20__0"))
+  }
+
   test("createIndex resolves column names and data types.") {
     val sourceData = createSourceData(spark.range(10).toDF("Foo"))
     val indexConfig = DataSkippingIndexConfig("MyIndex", MinMaxSketch("foO"))
diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexIntegrationTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexIntegrationTest.scala
index 11133a7cf..95246d5c1 100644
--- a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexIntegrationTest.scala
+++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexIntegrationTest.scala
@@ -27,7 +27,6 @@ import com.microsoft.hyperspace.index.IndexConstants
 import com.microsoft.hyperspace.index.covering.CoveringIndexConfig
 import com.microsoft.hyperspace.index.dataskipping.sketches._
 import com.microsoft.hyperspace.index.plans.logical.IndexHadoopFsRelation
-import com.microsoft.hyperspace.shim.ExtractFileSourceScanExecRelation
 
 class DataSkippingIndexIntegrationTest extends DataSkippingSuite with IcebergTestUtils {
   import spark.implicits._
@@ -182,6 +181,178 @@ class DataSkippingIndexIntegrationTest extends DataSkippingSuite with IcebergTes
     checkIndexApplied(query, numParallelism + 1)
   }
 
+  test("BloomFilter index is applied for a filter query (EqualTo).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 10)))
+        def query: DataFrame = df.filter("A = 1")
+        checkIndexApplied(query, 1)
+      }
+    }
+  }
+
+  test(
+    "BloomFilter index is applied for a filter query (EqualTo) " +
+      "where some source data files have only null values.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(Seq[Integer](1, 2, 3, null, 5, null, 7, 8, 9, null).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 10)))
+        def query: DataFrame = df.filter("A = 1")
+        checkIndexApplied(query, 1)
+      }
+    }
+  }
+
+  test("BloomFilter index is applied for a filter query (In).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 10)))
+        def query: DataFrame = df.filter("A in (1, 11, 19)")
+        checkIndexApplied(query, 2)
+      }
+    }
+  }
+
+  test("BloomFilter index supports string type.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(('a' to 'z').map(_.toString).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 10)))
+        def query: DataFrame = df.filter("A = 'a'")
+        checkIndexApplied(query, 1)
+      }
+    }
+  }
+
+  test("BloomFilter index does not support double type.") {
+    val df = createSourceData((0 until 10).map(_.toDouble).toDF("A"))
+    val ex = intercept[SparkException](
+      hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 10))))
+    assert(ex.getCause().getMessage().contains("BloomFilter does not support DoubleType"))
+  }
+
+  test("ValueList index is applied for a filter query (EqualTo).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A = 1")
+        checkIndexApplied(query, 1)
+      }
+    }
+  }
+
+  test("ValueList index is applied for a filter query (Not(EqualTo)).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(10).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A != 1")
+        checkIndexApplied(query, 9)
+      }
+    }
+  }
+
+  test(
+    "ValueList index is applied for a filter query (EqualTo) " +
+      "where some source data files have only null values.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(Seq[Integer](1, 2, 3, null, 5, null, 7, 8, 9, null).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A = 1")
+        checkIndexApplied(query, 1)
+      }
+    }
+  }
+
+  test("ValueList index is applied for a filter query (multiple EqualTo's).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A = 1 or A = 12 or A = 20")
+        checkIndexApplied(query, 3)
+      }
+    }
+  }
+
+  test("ValueList index is applied for a filter query (In).") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A in (20, 30, 10, 20)")
+        checkIndexApplied(query, 3)
+      }
+    }
+  }
+
+  test("ValueList index is applied for a filter query (In) - string type.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(Seq.range(0, 100).map(n => s"foo$n").toDF("A"))
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("A")))
+        def query: DataFrame = df.filter("A in ('foo31', 'foo12', 'foo1')")
+        checkIndexApplied(query, 3)
+      }
+    }
+  }
+
+  test("ValueList index is applied for a filter query with UDF returning boolean.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).toDF("A"))
+        spark.udf.register("F", (a: Int) => a < 15)
+        hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("F(A)")))
+        def query: DataFrame = df.filter("F(A)")
+        checkIndexApplied(query, 2)
+      }
+    }
+  }
+
+  test(
+    "ValueList index is applied for a filter query with UDF " +
+      "taking two arguments and returning boolean.") {
+    withAndWithoutCodegen {
+      withIndex("myind") {
+        val df = createSourceData(spark.range(100).selectExpr("id as A", "id * 2 as B"))
+        spark.udf.register("F", (a: Int, b: Int) => a < 15 || b > 190)
+        hs.createIndex(df, 
DataSkippingIndexConfig("myind", ValueListSketch("F(A, B)"))) + def query: DataFrame = df.filter("F(A, B)") + checkIndexApplied(query, 3) + } + } + } + + test( + "ValueList index is applied for a filter query with UDF " + + "taking binary and returning boolean.") { + withAndWithoutCodegen { + withIndex("myind") { + val df = createSourceData( + Seq( + Array[Byte](0, 0, 0, 0), + Array[Byte](0, 1, 0, 1), + Array[Byte](1, 2, 3, 4), + Array[Byte](5, 6, 7, 8), + Array[Byte](32, 32, 32, 32), + Array[Byte](64, 64, 64, 64), + Array[Byte](1, 1, 1, 1), + Array[Byte](-128, -128, -128, -128), + Array[Byte](127, 127, 127, 127), + Array[Byte](-1, 1, 0, 0)).toDF("A")) + spark.udf.register("F", (a: Array[Byte]) => a.sum == 0) + hs.createIndex(df, DataSkippingIndexConfig("myind", ValueListSketch("F(A)"))) + def query: DataFrame = df.filter("F(A)") + checkIndexApplied(query, 4) + } + } + } + test( "DataSkippingIndex works correctly for CSV where the same source data files can be " + "interpreted differently.") { @@ -275,6 +446,20 @@ class DataSkippingIndexIntegrationTest extends DataSkippingSuite with IcebergTes } } + test( + "BloomFilter index can be applied without refresh when source files are deleted " + + "if hybrid scan is enabled.") { + withSQLConf( + IndexConstants.INDEX_HYBRID_SCAN_ENABLED -> "true", + IndexConstants.INDEX_HYBRID_SCAN_DELETED_RATIO_THRESHOLD -> "1") { + val df = createSourceData(spark.range(100).toDF("A")) + hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.001, 10))) + deleteFile(listFiles(dataPath()).filter(isParquet).head.getPath) + def query: DataFrame = spark.read.parquet(dataPath().toString).filter("A in (25, 50, 75)") + checkIndexApplied(query, 3) + } + } + test("Empty source data does not cause an error.") { val df = createSourceData(spark.range(0).toDF("A")) hs.createIndex(df, DataSkippingIndexConfig("myind", MinMaxSketch("A"))) diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexTest.scala index 414e3c786..db96141f0 100644 --- a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexTest.scala +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/DataSkippingIndexTest.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.{IntegerType, StructType} import com.microsoft.hyperspace.HyperspaceException import com.microsoft.hyperspace.index.{Content, FileInfo, Index, IndexConstants} -import com.microsoft.hyperspace.index.dataskipping.sketches.MinMaxSketch +import com.microsoft.hyperspace.index.dataskipping.sketches.{MinMaxSketch, ValueListSketch} import com.microsoft.hyperspace.util.JsonUtils class DataSkippingIndexTest extends DataSkippingSuite { @@ -49,11 +49,27 @@ class DataSkippingIndexTest extends DataSkippingSuite { assert(index.indexedColumns === Seq("A", "B")) } + test("indexedColumns returns indexed columns of sketches (mixed sketch types).") { + val index = DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("B")), emptyStructType) + assert(index.indexedColumns === Seq("A", "B")) + } + + test("indexedColumns returns indexed columns without duplicates.") { + val index = + DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("A")), emptyStructType) + assert(index.indexedColumns === Seq("A")) + } + test("referencedColumns returns indexed columns of sketches.") { val index = DataSkippingIndex(Seq(MinMaxSketch("A"), MinMaxSketch("B")), emptyStructType) 
assert(index.referencedColumns === Seq("A", "B")) } + test("referencedColumns returns indexed columns of sketches (mixed sketch types).") { + val index = DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("B")), emptyStructType) + assert(index.referencedColumns === Seq("A", "B")) + } + test( "withNewProperties returns a new index which copies the original index except the " + "properties.") { @@ -68,11 +84,22 @@ class DataSkippingIndexTest extends DataSkippingSuite { assert(index.statistics() === Map("sketches" -> "MinMax(A), MinMax(B)")) } + test("statistics returns a string-formatted list of sketches (mixed sketch types).") { + val index = DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("B")), emptyStructType) + assert(index.statistics() === Map("sketches" -> "MinMax(A), ValueList(B)")) + } + test("canHandleDeletedFiles returns true.") { val index = DataSkippingIndex(Seq(MinMaxSketch("A")), emptyStructType) assert(index.canHandleDeletedFiles === true) } + test("Two indexes are equal if they have the same set of sketches.") { + val index1 = DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("B")), emptyStructType) + val index2 = DataSkippingIndex(Seq(ValueListSketch("B"), MinMaxSketch("A")), emptyStructType) + assert(index1 === index2) + } + test("write writes the index data in a Parquet format.") { val sourceData = createSourceData(spark.range(100).toDF("A")) val indexConfig = DataSkippingIndexConfig("myIndex", MinMaxSketch("A")) @@ -253,11 +280,25 @@ class DataSkippingIndexTest extends DataSkippingSuite { assert(ds1.hashCode === ds2.hashCode) } + test("Indexes are equal if they have the same sketches and data types (mixed sketch types).") { + val ds1 = DataSkippingIndex(Seq(MinMaxSketch("A"), ValueListSketch("B")), emptyStructType) + val ds2 = DataSkippingIndex(Seq(ValueListSketch("B"), MinMaxSketch("A")), emptyStructType) + assert(ds1 === ds2) + assert(ds1.hashCode === ds2.hashCode) + } + test("Indexes are not equal to objects which are not indexes.") { val ds = DataSkippingIndex(Seq(MinMaxSketch("A")), emptyStructType) assert(ds !== "ds") } + test("Indexes are not equal if they don't have the same sketches.") { + val ds1 = DataSkippingIndex(Seq(MinMaxSketch("A")), emptyStructType) + val ds2 = DataSkippingIndex(Seq(ValueListSketch("A")), emptyStructType) + assert(ds1 !== ds2) + assert(ds1.hashCode !== ds2.hashCode) + } + test("Index can be serialized.") { val ds = DataSkippingIndex( Seq(MinMaxSketch("A", Some(IntegerType))), diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAggTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAggTest.scala new file mode 100644 index 000000000..3c38dd627 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterAggTest.scala @@ -0,0 +1,51 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.Column +import org.apache.spark.sql.functions.col +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.HyperspaceSuite +import com.microsoft.hyperspace.index.dataskipping.BloomFilterTestUtils + +class BloomFilterAggTest extends HyperspaceSuite with BloomFilterTestUtils { + import spark.implicits._ + + test("BloomFilterAgg computes BloomFilter correctly.") { + val n = 10000 + val m = 3000 + val fpp = 0.01 + + val agg = new Column(BloomFilterAgg(col("a").expr, m, fpp).toAggregateExpression()) + val df = spark + .range(n) + .toDF("a") + .filter(col("a") % 3 === 0) + .union(Seq[Integer](null).toDF("a")) + .agg(agg) + val bfData = df.collect()(0).getAs[Any](0) + + val expectedBf = BloomFilter.create(m, fpp) + for (i <- 0 until n) { + if (i % 3 == 0) { + expectedBf.put(i.toLong) + } + } + assert(bfData === encodeExternal(expectedBf)) + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAnyTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAnyTest.scala new file mode 100644 index 000000000..3162a282c --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainAnyTest.scala @@ -0,0 +1,65 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.expressions
+
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+import com.microsoft.hyperspace.index.HyperspaceSuite
+import com.microsoft.hyperspace.index.dataskipping.util.ArrayUtils.toArray
+
+class BloomFilterMightContainAnyTest extends HyperspaceSuite {
+  def test(values: Seq[Any], dataType: DataType): Unit = {
+    val bf = BloomFilter.create(values.length, 0.01)
+    // Populate the filter before encoding it; values are evaluated as
+    // literals so that, e.g., strings are inserted as UTF8String, matching
+    // the representation the expression sees at evaluation time.
+    values.foreach(v => BloomFilterUtils.put(bf, Literal.create(v, dataType).eval(), dataType))
+    val bfData = Literal(
+      BloomFilterEncoderProvider.defaultEncoder.encode(bf),
+      BloomFilterEncoderProvider.defaultEncoder.dataType)
+    for (k <- 1 to 3) {
+      values.grouped(k).foreach { vs =>
+        val groupValues = vs.map(Literal.create(_, dataType).eval())
+        val valuesArray = toArray(groupValues, dataType)
+        assert(
+          BloomFilterMightContainAny(bfData, valuesArray, dataType).eval() ===
+            groupValues.exists(BloomFilterUtils.mightContain(bf, _, dataType)))
+      }
+    }
+  }
+
+  test("BloomFilterMightContainAny works correctly for an int array.") {
+    test((0 until 1000).map(_ * 2), IntegerType)
+  }
+
+  test("BloomFilterMightContainAny works correctly for a long array.") {
+    test((0L until 1000L).map(_ * 2), LongType)
+  }
+
+  test("BloomFilterMightContainAny works correctly for a byte array.") {
+    test(Seq(0, 1, 3, 7, 15, 31, 63, 127).map(_.toByte), ByteType)
+  }
+
+  test("BloomFilterMightContainAny works correctly for a short array.") {
+    test(Seq(1, 3, 5, 7, 9).map(_.toShort), ShortType)
+  }
+
+  test("BloomFilterMightContainAny works correctly for a string array.") {
+    test(Seq("hello", "world", "foo", "bar"), StringType)
+  }
+
+  test("BloomFilterMightContainAny works correctly for a binary array.") {
+    test(Seq(Array[Byte](1, 2), Array[Byte](3, 4)), BinaryType)
+  }
+}
diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainTest.scala
new file mode 100644
index 000000000..7e354163f
--- /dev/null
+++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterMightContainTest.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.expressions
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.sketch.BloomFilter
+
+import com.microsoft.hyperspace.index.HyperspaceSuite
+
+class BloomFilterMightContainTest extends HyperspaceSuite with ExpressionEvalHelper {
+  def test(values: Seq[Any], dataType: DataType): Unit = {
+    val bf = BloomFilter.create(values.length, 0.01)
+    // Populate the filter so that positive lookups are exercised.
+    values.foreach(bf.put)
+    val bfData = Literal(
+      BloomFilterEncoderProvider.defaultEncoder.encode(bf),
+      BloomFilterEncoderProvider.defaultEncoder.dataType)
+    values.foreach { v =>
+      val lit = Literal.create(v, dataType)
+      checkEvaluation(BloomFilterMightContain(bfData, lit), bf.mightContain(v))
+    }
+  }
+
+  test("BloomFilterMightContain works correctly for an int array.") {
+    test((0 until 50).map(_ * 2), IntegerType)
+  }
+
+  test("BloomFilterMightContain works correctly for a long array.") {
+    test((0L until 50L).map(_ * 2), LongType)
+  }
+
+  test("BloomFilterMightContain works correctly for a byte array.") {
+    test(Seq(0, 1, 3, 7, 15, 31, 63, 127).map(_.toByte), ByteType)
+  }
+
+  test("BloomFilterMightContain works correctly for a short array.") {
+    test(Seq(1, 3, 5, 7, 9).map(_.toShort), ShortType)
+  }
+
+  test("BloomFilterMightContain works correctly for a string array.") {
+    test(Seq("hello", "world", "foo", "bar"), StringType)
+  }
+
+  test("BloomFilterMightContain works correctly for a binary array.") {
+    test(Seq(Array[Byte](1, 2), Array[Byte](3, 4)), BinaryType)
+  }
+
+  test("BloomFilterMightContain returns null if the value is null.") {
+    val bf = BloomFilter.create(10, 0.01)
+    val bfData = Literal(
+      BloomFilterEncoderProvider.defaultEncoder.encode(bf),
+      BloomFilterEncoderProvider.defaultEncoder.dataType)
+    checkEvaluation(BloomFilterMightContain(bfData, Literal(null, IntegerType)), null)
+  }
+}
diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtilsTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtilsTest.scala
new file mode 100644
index 000000000..f67850262
--- /dev/null
+++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/BloomFilterUtilsTest.scala
@@ -0,0 +1,149 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.HyperspaceException +import com.microsoft.hyperspace.index.HyperspaceSuite + +class BloomFilterUtilsTest extends HyperspaceSuite { + + def testPut(value: Any, dataType: DataType): Unit = { + val bf = BloomFilter.create(100, 0.01) + BloomFilterUtils.put(bf, value, dataType) + val expected = BloomFilter.create(100, 0.01) + expected.put(value) + assert(bf === expected) + } + + test("put: long") { + testPut(10L, LongType) + } + + test("put: int") { + testPut(10, IntegerType) + } + + test("put: byte") { + testPut(10.toByte, ByteType) + } + + test("put: short") { + testPut(10.toShort, ShortType) + } + + test("put: string") { + val value = UTF8String.fromString("hello") + val bf = BloomFilter.create(100, 0.01) + BloomFilterUtils.put(bf, value, StringType) + val expected = BloomFilter.create(100, 0.01) + expected.put(value.getBytes) + assert(bf === expected) + } + + test("put: binary") { + testPut(Array[Byte](1, 2, 3, 4), BinaryType) + } + + test("put throws an exception for unsupported types.") { + val ex = intercept[HyperspaceException](testPut(3.14, DoubleType)) + assert(ex.msg.contains("BloomFilter does not support DoubleType")) + } + + def testMightContain(value: Any, value2: Any, dataType: DataType): Unit = { + val bf = BloomFilter.create(100, 0.01) + BloomFilterUtils.put(bf, value, dataType) + assert(BloomFilterUtils.mightContain(bf, value, dataType) === bf.mightContain(value)) + assert(BloomFilterUtils.mightContain(bf, value2, dataType) === bf.mightContain(value2)) + } + + test("mightContain: int") { + testMightContain(1, 0, IntegerType) + } + + test("mightContain: long") { + testMightContain(1L, 0L, LongType) + } + + test("mightContain: byte") { + testMightContain(1.toByte, 0.toByte, ByteType) + } + + test("mightContain: short") { + testMightContain(1.toShort, 0.toShort, ShortType) + } + + test("mightContain: string") { + val value = UTF8String.fromString("hello") + val value2 = UTF8String.fromString("world") + val bf = BloomFilter.create(100, 0.01) + BloomFilterUtils.put(bf, value, StringType) + assert( + BloomFilterUtils.mightContain(bf, value, StringType) === bf.mightContain(value.getBytes)) + assert( + BloomFilterUtils.mightContain(bf, value2, StringType) === bf.mightContain(value2.getBytes)) + } + + test("mightContain: binary") { + testMightContain(Array[Byte](1, 2), Array[Byte](3, 4), BinaryType) + } + + test("mightContain throws an exception for unsupported types.") { + val bf = BloomFilter.create(100, 0.01) + val ex = intercept[HyperspaceException](BloomFilterUtils.mightContain(bf, 3.14, DoubleType)) + assert(ex.msg.contains("BloomFilter does not support DoubleType")) + } + + test("mightContainCodegen: int") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", IntegerType) + assert(code === "fb.mightContainLong(vl)") + } + + test("mightContainCodegen: long") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", LongType) + assert(code === "fb.mightContainLong(vl)") + } + + test("mightContainCodegen: byte") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", ByteType) + assert(code === "fb.mightContainLong(vl)") + } + + test("mightContainCodegen: short") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", ShortType) + assert(code === "fb.mightContainLong(vl)") + } + + test("mightContainCodegen: 
string") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", StringType) + assert(code === "fb.mightContainBinary((vl).getBytes())") + } + + test("mightContainCodegen: binary") { + val code = BloomFilterUtils.mightContainCodegen("fb", "vl", BinaryType) + assert(code === "fb.mightContainBinary(vl)") + } + + test("mightContainCodegen throws an exception for unsupported types.") { + val ex = + intercept[HyperspaceException](BloomFilterUtils.mightContainCodegen("fb", "vl", DoubleType)) + assert(ex.msg.contains("BloomFilter does not support DoubleType")) + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoderTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoderTest.scala new file mode 100644 index 000000000..08a030dd6 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/FastBloomFilterEncoderTest.scala @@ -0,0 +1,38 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.HyperspaceSuite + +class FastBloomFilterEncoderTest extends HyperspaceSuite { + test("encode and decode restores empty bloom filter.") { + val bf = BloomFilter.create(100, 0.01) + val data = FastBloomFilterEncoder.encode(bf) + val bf2 = FastBloomFilterEncoder.decode(data) + assert(bf2 === bf) + } + + test("encode and decode restores the original bloom filter.") { + val bf = BloomFilter.create(100, 0.01) + bf.put(42) + val data = FastBloomFilterEncoder.encode(bf) + val bf2 = FastBloomFilterEncoder.decode(data) + assert(bf2 === bf) + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAnyTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAnyTest.scala new file mode 100644 index 000000000..c76c55a34 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsAnyTest.scala @@ -0,0 +1,86 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.expressions
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+
+import com.microsoft.hyperspace.index.HyperspaceSuite
+import com.microsoft.hyperspace.index.dataskipping.ArrayTestUtils
+import com.microsoft.hyperspace.index.dataskipping.util.ArrayUtils.toArray
+
+class SortedArrayContainsAnyTest
+    extends HyperspaceSuite
+    with ArrayTestUtils
+    with ExpressionEvalHelper {
+  def test(arr1: Expression, arr2: Expression, expected: Boolean): Unit = {
+    val elementType = arr2.dataType.asInstanceOf[ArrayType].elementType
+    checkEvaluation(
+      SortedArrayContainsAny(
+        arr1,
+        toArray(
+          arr2.asInstanceOf[Literal].value.asInstanceOf[ArrayData].toObjectArray(elementType),
+          elementType),
+        elementType),
+      expected)
+  }
+
+  test("SortedArrayContainsAny returns true if two arrays intersect.") {
+    val array1 = createArray(Seq.range(0, 100000).map(_ * 2), IntegerType)
+    val array2 = createArray(Seq(0), IntegerType)
+    val array3 = createArray(Seq(2), IntegerType)
+    val array4 = createArray(Seq(199998), IntegerType)
+    val array5 = createArray(Seq(2, 4, 5), IntegerType)
+    val array6 = createArray(Seq(1, 3, 199998), IntegerType)
+    val array7 = createArray(Seq(-1, 100000), IntegerType)
+    val array8 = createArray(Seq(100000, 200001), IntegerType)
+    test(array1, array2, true)
+    test(array1, array3, true)
+    test(array1, array4, true)
+    test(array1, array5, true)
+    test(array1, array6, true)
+    test(array1, array7, true)
+    test(array1, array8, true)
+    test(array3, array5, true)
+    test(array4, array6, true)
+    test(array7, array8, true)
+  }
+
+  test("SortedArrayContainsAny returns false if two arrays don't intersect.") {
+    val array1 = createArray(Seq.range(0, 100000).map(_ * 2), IntegerType)
+    val array2 = createArray(Seq(), IntegerType)
+    val array3 = createArray(Seq(-1), IntegerType)
+    val array4 = createArray(Seq(1), IntegerType)
+    val array5 = createArray(Seq(200001), IntegerType)
+    val array6 = createArray(Seq(1, 3, 199999), IntegerType)
+    val array7 = createArray(Seq(-1, 100001), IntegerType)
+    val array8 = createArray(Seq(49999, 100001), IntegerType)
+    val array9 = createArray(Seq(-3, 1, 1), IntegerType)
+    test(array1, array2, false)
+    test(array1, array3, false)
+    test(array1, array4, false)
+    test(array1, array5, false)
+    test(array1, array6, false)
+    test(array1, array7, false)
+    test(array1, array8, false)
+    test(array1, array9, false)
+    test(array2, array3, false)
+    test(array3, array4, false)
+    test(array5, array6, false)
+    test(array6, array7, false)
+  }
+}
diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsTest.scala
new file mode 100644
index 000000000..bbd8546a3
--- /dev/null
+++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/SortedArrayContainsTest.scala
@@ -0,0 +1,85 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.expressions
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+
+import com.microsoft.hyperspace.index.HyperspaceSuite
+import com.microsoft.hyperspace.index.dataskipping.ArrayTestUtils
+
+class SortedArrayContainsTest
+    extends HyperspaceSuite
+    with ArrayTestUtils
+    with ExpressionEvalHelper {
+  test("SortedArrayContains works correctly for an empty array.") {
+    val array = createArray(Nil, IntegerType)
+    checkEvaluation(SortedArrayContains(array, Literal(0, IntegerType)), false)
+  }
+
+  test("SortedArrayContains works correctly for an array of size 1.") {
+    val array = createArray(Seq(1), IntegerType)
+    checkEvaluation(SortedArrayContains(array, Literal(0, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1, IntegerType)), true)
+    checkEvaluation(SortedArrayContains(array, Literal(2, IntegerType)), false)
+  }
+
+  test("SortedArrayContains works correctly for an array of size 2.") {
+    val array = createArray(Seq(1, 3), IntegerType)
+    checkEvaluation(SortedArrayContains(array, Literal(0, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1, IntegerType)), true)
+    checkEvaluation(SortedArrayContains(array, Literal(2, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(3, IntegerType)), true)
+    checkEvaluation(SortedArrayContains(array, Literal(4, IntegerType)), false)
+  }
+
+  test("SortedArrayContains works correctly for an int array.") {
+    val values = Seq.range(0, 50).map(_ * 2)
+    val array = createArray(values, IntegerType)
+    values.foreach(v =>
+      checkEvaluation(SortedArrayContains(array, Literal(v, IntegerType)), true))
+    checkEvaluation(SortedArrayContains(array, Literal(-10, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(49, IntegerType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1000, IntegerType)), false)
+  }
+
+  test("SortedArrayContains works correctly for a long array.") {
+    val values = Seq.range(0L, 50L).map(_ * 2)
+    val array = createArray(values, LongType)
+    values.foreach(v => checkEvaluation(SortedArrayContains(array, Literal(v, LongType)), true))
+    checkEvaluation(SortedArrayContains(array, Literal(-10L, LongType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1L, LongType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(49L, LongType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal(1000L, LongType)), false)
+  }
+
+  test("SortedArrayContains works correctly for a string array.") {
+    val values = Seq("hello", "world", "foo", "bar", "footrix").sorted
+    val array = createArray(values, StringType)
+    values.foreach(v =>
+      checkEvaluation(SortedArrayContains(array, Literal.create(v, StringType)), true))
+    checkEvaluation(SortedArrayContains(array, Literal.create("abc", StringType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal.create("fooo", StringType)), false)
+    checkEvaluation(SortedArrayContains(array, Literal.create("zoo", StringType)), false)
+  }
+
+  test("SortedArrayContains returns null if the value is null.") {
+    val array = createArray(Seq(1), IntegerType)
+    checkEvaluation(SortedArrayContains(array, Literal(null, IntegerType)), null)
+  }
+}
diff --git 
a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoderTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoderTest.scala new file mode 100644 index 000000000..b549285ca --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/expressions/StreamBloomFilterEncoderTest.scala @@ -0,0 +1,38 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.expressions + +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.HyperspaceSuite + +class StreamBloomFilterEncoderTest extends HyperspaceSuite { + test("encode and decode restores empty bloom filter.") { + val bf = BloomFilter.create(100, 0.01) + val data = StreamBloomFilterEncoder.encode(bf) + val bf2 = StreamBloomFilterEncoder.decode(data) + assert(bf2 === bf) + } + + test("encode and decode restores the original bloom filter.") { + val bf = BloomFilter.create(100, 0.01) + bf.put(42) + val data = StreamBloomFilterEncoder.encode(bf) + val bf2 = StreamBloomFilterEncoder.decode(data) + assert(bf2 === bf) + } +} diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/rules/ApplyDataSkippingIndexTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/rules/ApplyDataSkippingIndexTest.scala index 98dca4f18..7b9d19c47 100644 --- a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/rules/ApplyDataSkippingIndexTest.scala +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/rules/ApplyDataSkippingIndexTest.scala @@ -106,6 +106,9 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { 17, null, 19, 20).toDF("A")), "source [A:Int] with nulls") + def dataIS: SourceData = + SourceData(() => createSourceData(spark.range(10).toDF("A")), "source [A:Int] small") + def dataIIP: SourceData = SourceData( () => @@ -217,6 +220,7 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { Param(dataI, "A is not null", MinMaxSketch("A"), 10), Param(dataI, "!(A is not null)", MinMaxSketch("A"), 10), Param(dataI, "A <=> 10", MinMaxSketch("A"), 1), + Param(dataI, "A <=> 10", ValueListSketch("A"), 1), Param(dataI, "10 <=> A", MinMaxSketch("A"), 1), Param(dataI, "A <=> null", MinMaxSketch("A"), 10), Param(dataI, "A <25", MinMaxSketch("A"), 3), @@ -237,9 +241,43 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { Param(dataI, "!(A < 20)", MinMaxSketch("A"), 8), Param(dataI, "not (A not in (1, 2, 3))", MinMaxSketch("A"), 1), Param(dataS, "A < 'foo'", MinMaxSketch("A"), 1), + Param(dataS, "A in ('foo1', 'foo9')", ValueListSketch("A"), 2), + Param(dataS, "A in ('foo1', 'foo5', 'foo9')", BloomFilterSketch("A", 0.01, 10), 3), + Param( + dataS, + "A in ('foo1','goo1','hoo1','i1','j','k','l','m','n','o','p')", + BloomFilterSketch("A", 0.01, 10), + 1), + Param(dataI, "A = 10", ValueListSketch("A"), 1), + Param(dataI, "10 = A", 
ValueListSketch("a"), 1), + Param(dataIS, "A != 5", ValueListSketch("A"), 9), + Param(dataIS, "5 != A", ValueListSketch("A"), 9), + Param(dataIN, "a!=9", ValueListSketch("a"), 6), + Param(dataIN, "9 != A", ValueListSketch("A"), 6), + Param(dataI, "A != 5", ValueListSketch("A"), 10), + Param(dataI, "A < 34", ValueListSketch("A"), 4), + Param(dataI, "34 > A", ValueListSketch("A"), 4), + Param(dataIN, "A < 9", ValueListSketch("a"), 2), + Param(dataIN, "9 > A", ValueListSketch("A"), 2), + Param(dataI, "A = 10", BloomFilterSketch("A", 0.01, 10), 1), + Param(dataI, "A <=> 20", BloomFilterSketch("A", 0.01, 10), 1), + Param(dataI, "A <=> null", BloomFilterSketch("A", 0.01, 10), 10), + Param(dataI, "A in (2, 3, 5, 7, 11, 13, 17, 19)", BloomFilterSketch("A", 0.001, 10), 2), + Param( + dataI, + "A in (0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20)", + BloomFilterSketch("A", 0.001, 10), + 3), + Param( + dataIN, + "A in (0,1,10,100,1000,10000,100000,1000000,-1,-2,-3,-4,-5,-6,-7,-8,null)", + BloomFilterSketch("A", 0.001, 10), + 1), + Param(dataI, "A != 10", BloomFilterSketch("A", 0.001, 10), 10), Param(dataI, "a = 10", MinMaxSketch("A"), 1), Param(dataI, "A = 10", MinMaxSketch("a"), 1), Param(dataI, "A in (1, 2, 3, null, 10)", MinMaxSketch("A"), 2), + Param(dataI, "A in (2, 3, 10, 99)", ValueListSketch("a"), 3), Param(dataI, "A in (10,9,8,7,6,5,4,3,2,1,50,49,48,47,46,45)", MinMaxSketch("A"), 4), Param(dataS, "A in ('foo1', 'foo5', 'foo9')", MinMaxSketch("A"), 3), Param( @@ -255,6 +293,21 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { "A in (x'00',x'01',x'02',x'03',x'04',x'05',x'06',x'07',x'08',x'09',x'0a',x'20202020')", MinMaxSketch("A"), 1), + Param(dataI, "A in (10,9,8,7,6,5,4,3,2,1,50,49,48,47,46,45)", ValueListSketch("A"), 4), + Param(dataS, "A in ('foo1', 'foo5', 'foo9')", ValueListSketch("A"), 3), + Param( + dataS, + "A in ('foo1','a','b','c','d','e','f','g','h','i','j','k')", + ValueListSketch("A"), + 1), + Param(dataD, "A in (1,2,3,15,16,17)", ValueListSketch("A"), 2), + Param(dataD, "A in (1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16)", ValueListSketch("A"), 2), + Param(dataB, "A in (x'00000000', x'0001', x'0002', x'05060708')", ValueListSketch("A"), 2), + Param( + dataB, + "A in (x'00',x'01',x'02',x'03',x'04',x'05',x'06',x'07',x'08',x'09',x'0a',x'20202020')", + ValueListSketch("A"), + 1), Param(dataI, "A BETWEEN 27 AND 51", MinMaxSketch("A"), 4), Param(dataI, "IF(A=1,2,3)=2", MinMaxSketch("A"), 10), Param(dataII, "A = 10 OR B = 50", Seq(MinMaxSketch("A"), MinMaxSketch("B")), 2), @@ -265,6 +318,11 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { Param(dataII, "A < 30 and B > 20", MinMaxSketch("A"), 3), Param(dataII, "A < 30 and b > 40", Seq(MinMaxSketch("a"), MinMaxSketch("B")), 1), Param(dataII, "A = 10 and B = 90", Seq(MinMaxSketch("A"), MinMaxSketch("B")), 0), + Param( + dataII, + "A < 31 and B in (1, 2, 11, 12, 21, 22)", + Seq(MinMaxSketch("A"), BloomFilterSketch("B", 0.001, 10)), + 2), Param(dataIN, "A is not null", MinMaxSketch("A"), 7), Param(dataIN, "!(A <=> null)", MinMaxSketch("A"), 7), Param(dataIN, "A = 2", MinMaxSketch("A"), 1), @@ -299,6 +357,12 @@ class ApplyDataSkippingIndexTest extends DataSkippingSuite { MinMaxSketch("is_less_than_23(A)"), 8, () => spark.udf.register("is_less_than_23", (a: Int) => a < 23)), + Param( + dataI, + "!is_less_than_23(A)", + ValueListSketch("is_less_than_23(A)"), + 8, + () => spark.udf.register("is_less_than_23", (a: Int) => a < 23)), Param( dataII, "A < 50 and F(A,B) < 20", @@ -316,7 +380,13 @@ class 
ApplyDataSkippingIndexTest extends DataSkippingSuite { "IF(A IS NULL,NULL,F(A))=2", MinMaxSketch("A"), 10, - () => spark.udf.register("F", (a: Int) => a * 2))).foreach { + () => spark.udf.register("F", (a: Int) => a * 2)), + Param( + dataB, + "F(A)", + ValueListSketch("f(A)"), + 4, + () => spark.udf.register("F", (a: Array[Byte]) => a.sum == 0))).foreach { case Param(sourceData, filter, sketches, numExpectedFiles, setup) => test( s"applyIndex works as expected for ${sourceData.description}: " + diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketchTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketchTest.scala new file mode 100644 index 000000000..876003f75 --- /dev/null +++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/BloomFilterSketchTest.scala @@ -0,0 +1,162 @@ +/* + * Copyright (2021) The Hyperspace Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.microsoft.hyperspace.index.dataskipping.sketches + +import org.apache.spark.sql.{Column, QueryTest} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.sketch.BloomFilter + +import com.microsoft.hyperspace.index.HyperspaceSuite +import com.microsoft.hyperspace.index.dataskipping.BloomFilterTestUtils +import com.microsoft.hyperspace.index.dataskipping.expressions._ + +class BloomFilterSketchTest extends QueryTest with HyperspaceSuite with BloomFilterTestUtils { + import spark.implicits._ + + val valueExtractor = AttrValueExtractor(Map.empty) + + test("indexedColumns returns the indexed column.") { + val sketch = BloomFilterSketch("A", 0.01, 100) + assert(sketch.indexedColumns === Seq("A")) + } + + test("referencedColumns returns the indexed column.") { + val sketch = BloomFilterSketch("A", 0.01, 100) + assert(sketch.referencedColumns === Seq("A")) + } + + test( + "aggregateFunctions returns an aggregation function that collects values in a bloom filter.") { + val sketch = BloomFilterSketch("A", 0.01, 100) + val aggrs = sketch.aggregateFunctions.map(new Column(_)) + assert(aggrs.length === 1) + val data = Seq(1, -1, 10, 2, 4, 2, 0, 10) + val bf = BloomFilter.create(100, 0.01) + data.foreach(bf.put) + val bfData = data.toDF("A").select(aggrs.head).collect()(0).getAs[Any](0) + assert(bfData === encodeExternal(bf)) + } + + test("toString returns a reasonable string.") { + val sketch = BloomFilterSketch("A", 0.01, 100) + assert(sketch.toString === "BloomFilter(A, 0.01, 100)") + } + + test("Two sketches are equal if their columns are equal.") { + assert(BloomFilterSketch("A", 0.01, 100) === BloomFilterSketch("A", 0.001, 1000)) + assert(BloomFilterSketch("A", 0.01, 100) !== BloomFilterSketch("a", 0.01, 100)) + assert(BloomFilterSketch("b", 0.01, 100) !== BloomFilterSketch("B", 0.01, 100)) + 
assert(BloomFilterSketch("B", 0.01, 100) === BloomFilterSketch("B", 0.001, 1000))
+  }
+
+  test("hashCode is reasonably implemented.") {
+    assert(
+      BloomFilterSketch("A", 0.01, 100).hashCode === BloomFilterSketch("A", 0.001, 1000).hashCode)
+    assert(
+      BloomFilterSketch("A", 0.01, 100).hashCode !== BloomFilterSketch("a", 0.001, 1000).hashCode)
+  }
+
+  test("convertPredicate converts EqualTo.") {
+    val sketch = BloomFilterSketch("A", 0.01, 100)
+    val predicate = EqualTo(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("bf"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(BloomFilterMightContain(sketchValues(0), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts EqualTo - string type.") {
+    val sketch = BloomFilterSketch("A", 0.01, 100)
+    val predicate =
+      EqualTo(AttributeReference("A", StringType)(ExprId(0)), Literal.create("hello", StringType))
+    val sketchValues = Seq(UnresolvedAttribute("bf"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", StringType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected =
+      Some(BloomFilterMightContain(sketchValues(0), Literal.create("hello", StringType)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts In.") {
+    val sketch = BloomFilterSketch("A", 0.01, 100)
+    val predicate =
+      In(AttributeReference("A", IntegerType)(ExprId(0)), Seq(Literal(42), Literal(23)))
+    val sketchValues = Seq(UnresolvedAttribute("bf"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(
+      Or(
+        BloomFilterMightContain(sketchValues(0), Literal(42)),
+        BloomFilterMightContain(sketchValues(0), Literal(23))))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts In - string type.") {
+    val sketch = BloomFilterSketch("A", 0.01, 100)
+    val predicate =
+      In(
+        AttributeReference("A", StringType)(ExprId(0)),
+        Seq(Literal.create("hello", StringType), Literal.create("world", StringType)))
+    val sketchValues = Seq(UnresolvedAttribute("bf"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", StringType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(
+      Or(
+        BloomFilterMightContain(sketchValues(0), Literal.create("hello", StringType)),
+        BloomFilterMightContain(sketchValues(0), Literal.create("world", StringType))))
+    assert(result === expected)
+  }
+
+  test("convertPredicate does not convert Not(EqualTo(<attr>, <value>)).") {
+    val sketch = BloomFilterSketch("A", 0.01, 100)
+    val predicate = Not(EqualTo(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42)))
+    val sketchValues = Seq(UnresolvedAttribute("bf"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = None
+    assert(result === expected)
+  }
+}
diff --git a/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketchTest.scala b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketchTest.scala
new file mode 100644
index 000000000..484fa0af1
--- /dev/null
+++ b/src/test/scala/com/microsoft/hyperspace/index/dataskipping/sketches/ValueListSketchTest.scala
@@ -0,0 +1,267 @@
+/*
+ * Copyright (2021) The Hyperspace Project Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.microsoft.hyperspace.index.dataskipping.sketches
+
+import org.apache.spark.sql.{Column, QueryTest}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+import com.microsoft.hyperspace.index.HyperspaceSuite
+import com.microsoft.hyperspace.index.dataskipping.expressions._
+
+class ValueListSketchTest extends QueryTest with HyperspaceSuite {
+  import spark.implicits._
+
+  val valueExtractor = AttrValueExtractor(Map.empty)
+
+  test("indexedColumns returns the indexed column.") {
+    val sketch = ValueListSketch("A")
+    assert(sketch.indexedColumns === Seq("A"))
+  }
+
+  test("referencedColumns returns the indexed column.") {
+    val sketch = ValueListSketch("A")
+    assert(sketch.referencedColumns === Seq("A"))
+  }
+
+  test("aggregateFunctions returns an aggregation function that collects all unique values.") {
+    val sketch = ValueListSketch("A")
+    val aggrs = sketch.aggregateFunctions.map(new Column(_))
+    val data = Seq(1, -1, 10, 2, 4, 2, 0, 10).toDF("A")
+    checkAnswer(data.select(aggrs: _*), Seq(Array(-1, 0, 1, 2, 4, 10)).toDF)
+  }
+
+  test("toString returns a reasonable string.") {
+    val sketch = ValueListSketch("A")
+    assert(sketch.toString === "ValueList(A)")
+  }
+
+  test("Two sketches are equal if their columns are equal.") {
+    assert(ValueListSketch("A") === ValueListSketch("A"))
+    assert(ValueListSketch("A") !== ValueListSketch("a"))
+    assert(ValueListSketch("b") !== ValueListSketch("B"))
+    assert(ValueListSketch("B") === ValueListSketch("B"))
+  }
+
+  test("hashCode is reasonably implemented.") {
+    assert(ValueListSketch("A").hashCode === ValueListSketch("A").hashCode)
+    assert(ValueListSketch("A").hashCode !== ValueListSketch("a").hashCode)
+  }
+
+  test("convertPredicate converts EqualTo(<attr>, <value>).") {
+    val sketch = ValueListSketch("A")
+    val predicate = EqualTo(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(SortedArrayContains(sketchValues(0), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts EqualTo(<value>, <attr>).") {
+    val sketch = ValueListSketch("A")
+    val predicate = EqualTo(Literal(42), AttributeReference("A", IntegerType)(ExprId(0)))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(SortedArrayContains(sketchValues(0), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts EqualTo(<attr>, <value>) - string type.") {
+    val sketch = ValueListSketch("A")
+    val predicate =
+      EqualTo(AttributeReference("A", StringType)(ExprId(0)), Literal.create("hello", StringType))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", StringType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(SortedArrayContains(sketchValues(0), Literal.create("hello", StringType)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts EqualTo(<attr>, <value>) - double type.") {
+    val sketch = ValueListSketch("A")
+    val predicate =
+      EqualTo(AttributeReference("A", DoubleType)(ExprId(0)), Literal(3.14, DoubleType))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", DoubleType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(SortedArrayContains(sketchValues(0), Literal(3.14, DoubleType)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts Not(EqualTo(<attr>, <value>)).") {
+    val sketch = ValueListSketch("A")
+    val predicate = Not(EqualTo(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42)))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(
+      And(
+        IsNotNull(Literal(42)),
+        Or(
+          GreaterThan(Size(sketchValues(0)), Literal(1)),
+          Not(EqualTo(ElementAt(sketchValues(0), Literal(1)), Literal(42))))))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts LessThan.") {
+    val sketch = ValueListSketch("A")
+    val predicate = LessThan(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(LessThan(ElementAt(sketchValues(0), Literal(1)), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts LessThan - string type.") {
+    val sketch = ValueListSketch("A")
+    val predicate = LessThan(
+      AttributeReference("A", StringType)(ExprId(0)),
+      Literal.create("hello", StringType))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", StringType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected =
+      Some(LessThan(ElementAt(sketchValues(0), Literal(1)), Literal.create("hello", StringType)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts LessThanOrEqual.") {
+    val sketch = ValueListSketch("A")
+    val predicate = LessThanOrEqual(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(LessThanOrEqual(ElementAt(sketchValues(0), Literal(1)), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts GreaterThan.") {
+    val sketch = ValueListSketch("A")
+    val predicate = GreaterThan(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(GreaterThan(ElementAt(sketchValues(0), Literal(-1)), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts GreaterThanOrEqual.") {
+    val sketch = ValueListSketch("A")
+    val predicate =
+      GreaterThanOrEqual(AttributeReference("A", IntegerType)(ExprId(0)), Literal(42))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(GreaterThanOrEqual(ElementAt(sketchValues(0), Literal(-1)), Literal(42)))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts In.") {
+    val sketch = ValueListSketch("A")
+    val predicate =
+      In(AttributeReference("A", IntegerType)(ExprId(0)), Seq(Literal(42), Literal(23)))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", IntegerType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(
+      Or(
+        SortedArrayContains(sketchValues(0), Literal(42)),
+        SortedArrayContains(sketchValues(0), Literal(23))))
+    assert(result === expected)
+  }
+
+  test("convertPredicate converts In - string type.") {
+    val sketch = ValueListSketch("A")
+    val predicate =
+      In(
+        AttributeReference("A", StringType)(ExprId(0)),
+        Seq(Literal.create("world", StringType), Literal.create("hello", StringType)))
+    val sketchValues = Seq(UnresolvedAttribute("valueList"))
+    val nameMap = Map(ExprId(0) -> "A")
+    val result = sketch.convertPredicate(
+      predicate,
+      Seq(AttributeReference("A", StringType)(ExpressionUtils.nullExprId)),
+      sketchValues,
+      nameMap,
+      valueExtractor)
+    val expected = Some(
+      Or(
+        SortedArrayContains(sketchValues(0), Literal.create("world")),
+        SortedArrayContains(sketchValues(0), Literal.create("hello"))))
+    assert(result === expected)
+  }
+}
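
Usage note (not part of the patch): the tests above exercise the new ValueListSketch and BloomFilterSketch through the same public API already used for MinMaxSketch. A minimal end-to-end sketch of that flow, assuming a running SparkSession `spark`, a hypothetical source path `/data/src`, and the standard Hyperspace setup; the object and path names here are illustrative only:

import org.apache.spark.sql.SparkSession

import com.microsoft.hyperspace._
import com.microsoft.hyperspace.index.dataskipping.DataSkippingIndexConfig
import com.microsoft.hyperspace.index.dataskipping.sketches.BloomFilterSketch

object BloomFilterIndexExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("example").getOrCreate()
    val hs = new Hyperspace(spark)
    val df = spark.read.parquet("/data/src") // hypothetical source path

    // Sketch arguments follow the tests above: column, fpp = 0.01,
    // expected number of distinct items = 100.
    hs.createIndex(df, DataSkippingIndexConfig("myind", BloomFilterSketch("A", 0.01, 100)))

    // With Hyperspace enabled, EqualTo/In filters on A read only the files
    // whose per-file bloom filter might contain the searched values.
    spark.enableHyperspace
    spark.read.parquet("/data/src").filter("A in (1, 11, 19)").show()
  }
}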