diff --git a/tools/src/tensil/tools/compiler/MemoryManager.scala b/tools/src/tensil/tools/compiler/MemoryManager.scala
index 7bfbf31..36a5569 100644
--- a/tools/src/tensil/tools/compiler/MemoryManager.scala
+++ b/tools/src/tensil/tools/compiler/MemoryManager.scala
@@ -85,6 +85,9 @@ class MemoryManager(
   def hasPendingFloatConst(name: String) =
     pendingFloatConsts.get(name).isDefined

+  def hasPendingLongConst(name: String) =
+    pendingLongConsts.get(name).isDefined
+
   def getPendingIntConst(name: String)   = pendingIntConsts(name)
   def getPendingLongConst(name: String)  = pendingLongConsts(name)
   def getPendingFloatConst(name: String) = pendingFloatConsts(name)
diff --git a/tools/src/tensil/tools/compiler/OnnxFrontend.scala b/tools/src/tensil/tools/compiler/OnnxFrontend.scala
index ee83fcb..462b0f8 100644
--- a/tools/src/tensil/tools/compiler/OnnxFrontend.scala
+++ b/tools/src/tensil/tools/compiler/OnnxFrontend.scala
@@ -449,6 +449,14 @@ class OnnxFrontend(
           emitGlobalPool(_, nodeProto),
           emitters
         )
+      case "Gather" =>
+        rewriteSimple(remainingProtos, emitGather(_, nodeProto), emitters)
+      case "Unsqueeze" =>
+        rewriteSimple(
+          remainingProtos,
+          emitUnsqueeze(_, nodeProto),
+          emitters
+        )
       case op =>
         throw new CompilerException(
           s"Unsupported op ${op} (${nodeProto.name.get})"
@@ -917,6 +925,69 @@ class OnnxFrontend(
     )
   }

+  private def emitGather(
+      context: EmitContext,
+      gatherProto: NodeProto
+  ): Unit = {
+    val axisAttr = getAttr(gatherProto, "axis").get
+
+    require(axisAttr.`type`.get.isInt)
+
+    val axis = axisAttr.i.get
+
+    val data =
+      context.mm
+        .getPendingLongConst(gatherProto.input(0))
+        .asInstanceOf[TensorData[Long]]
+
+    val indices = context.mm
+      .getPendingLongConst(gatherProto.input(1))
+      .asInstanceOf[TensorData[Long]]
+
+    if (axis != 0 || data.shape.size != 1 || indices.shape.size != 0)
+      throw new CompilerException("Only 1D gather is supported")
+
+    if (indices.as1D(0) < 0 || indices.as1D(0) >= data.shape(0))
+      throw new CompilerException("Gather index is outside of data shape")
+
+    context.mm.addPendingConst(
+      gatherProto.output(0),
+      new TensorData(
+        Shape(),
+        Seq(data.as1D(indices.as1D(0).toInt)),
+        org.tensorflow.framework.types.DataType.DT_INT64
+      )
+    )
+  }
+
+  private def emitUnsqueeze(
+      context: EmitContext,
+      unsqueezeProto: NodeProto
+  ): Unit = {
+    val axesAttr = getAttr(unsqueezeProto, "axes").get
+
+    require(axesAttr.`type`.get.isInts)
+
+    val axes = axesAttr.ints
+
+    val data =
+      context.mm
+        .getPendingLongConst(unsqueezeProto.input(0))
+        .asInstanceOf[TensorData[Long]]
+
+    if (axes.size != 1 || axes(0) != 0 || data.shape.size != 0)
+      throw new CompilerException("Only scalar unsqueeze is supported")
+
+    context.mm.addPendingConst(
+      unsqueezeProto.output(0),
+      new TensorData(
+        Shape(1),
+        data.as1D,
+        org.tensorflow.framework.types.DataType.DT_INT64
+      )
+    )
+  }
+
   private def emitConstant(
       context: EmitContext,
       constantProto: NodeProto
@@ -1058,7 +1129,10 @@ class OnnxFrontend(
       context: EmitContext,
       reshapeProto: NodeProto
   ): Unit = {
-    val shape = getTensorData(tensorProtos(reshapeProto.input(1)))
+    val shapeInputName = reshapeProto.input(1)
+    val shape = (if (tensorProtos.contains(shapeInputName))
+                   getTensorData(tensorProtos(shapeInputName))
+                 else context.mm.getPendingLongConst(shapeInputName))
       .asInstanceOf[TensorData[Long]]
       .as1D
       .map(_.toInt)
@@ -1439,6 +1513,27 @@ class OnnxFrontend(
          org.tensorflow.framework.types.DataType.DT_FLOAT
        )
      )
+    } else if (
+      concatProto.input.forall(name =>
+        context.mm.hasPendingLongConst(name) || tensorProtos.contains(name)
+      )
+    ) {
+      val output = concatProto.input
+        .map(name =>
+          (if (context.mm.hasPendingLongConst(name))
+             context.mm.getPendingLongConst(name)
+           else getTensorData(tensorProtos(name))).as1D
+        )
+        .flatten
+
+      context.mm.addPendingConst(
+        concatProto.output(0),
+        new TensorData(
+          Shape(output.size),
+          output,
+          org.tensorflow.framework.types.DataType.DT_INT64
+        )
+      )
     } else {

       if (axis != 1)
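Note on the OnnxFrontend changes above: emitGather, emitUnsqueeze, and the new Concat and Reshape branches all operate on pending INT64 constants, so a Gather -> Unsqueeze -> Concat chain feeding a Reshape target shape is folded at compile time rather than compiled to hardware instructions. Below is a minimal sketch of that folding flow; FoldedTensor and the fold* helpers are illustrative stand-ins for the compiler's TensorData[Long] pending constants, not Tensil API, and the example shape is an assumption.

    // Illustrative sketch only: FoldedTensor and the fold* helpers are
    // stand-ins for the compiler's pending long constants, not actual
    // Tensil API.
    object ShapeFoldSketch extends App {
      // A folded INT64 constant: a shape plus flattened contents.
      case class FoldedTensor(shape: Seq[Int], data: Seq[Long])

      // Gather(axis = 0): 1D data, scalar index, scalar result,
      // mirroring the checks in emitGather.
      def foldGather(data: FoldedTensor, index: Long): FoldedTensor = {
        require(data.shape.size == 1, "Only 1D gather is supported")
        require(
          index >= 0 && index < data.shape(0),
          "Gather index is outside of data shape"
        )
        FoldedTensor(Seq.empty, Seq(data.data(index.toInt)))
      }

      // Unsqueeze(axes = [0]): scalar in, 1-element vector out,
      // mirroring emitUnsqueeze.
      def foldUnsqueeze(data: FoldedTensor): FoldedTensor = {
        require(data.shape.isEmpty, "Only scalar unsqueeze is supported")
        FoldedTensor(Seq(1), data.data)
      }

      // Concat of 1D INT64 constants is flat concatenation, as in the
      // new emitConcat branch.
      def foldConcat(inputs: Seq[FoldedTensor]): FoldedTensor = {
        val flat = inputs.flatMap(_.data)
        FoldedTensor(Seq(flat.size), flat)
      }

      // Assumed example: an activation shape constant (3, 1280, 1, 1)
      // for a batch of 3.
      val shapeConst = FoldedTensor(Seq(4), Seq(3L, 1280L, 1L, 1L))

      val batch  = foldUnsqueeze(foldGather(shapeConst, index = 0)) // [3]
      val target = foldConcat(Seq(batch, FoldedTensor(Seq(1), Seq(-1L))))

      // Reshape can now read [3, -1] as a compile-time constant.
      println(target) // FoldedTensor(List(2),List(3, -1))
    }

This is why emitReshape now falls back to getPendingLongConst when the shape input is not a tensor initializer: the target shape may arrive as a folded constant produced by the chain above.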
diff --git a/tools/test/src/tools/CompilerSpec.scala b/tools/test/src/tools/CompilerSpec.scala
index 9f182f4..9a21bed 100644
--- a/tools/test/src/tools/CompilerSpec.scala
+++ b/tools/test/src/tools/CompilerSpec.scala
@@ -2768,6 +2768,59 @@ class CompilerSpec extends AnyFlatSpec {
     )
   }

+  it should "Compile ONNX float MobileNetV2 with input batch of 3" taggedAs (Slow) in {
+    val name = "mobilenetv2_float_onnx"
+    val traceContext = new ExecutiveTraceContext()
+    val options = CompilerOptions(
+      arch = MobileNetFloat32Architecture,
+      inputShapes = CompilerInputShapes.mkWithBatchSize(3),
+      printSummary = true
+    )
+
+    Compiler.compile(
+      name,
+      s"${Models}/mobilenetv2.onnx",
+      List("output"),
+      options,
+      traceContext
+    )
+
+    EmulatorHelper.test(
+      name,
+      inputBatchSize = options.inputShapes.batchSize,
+      traceContext = traceContext
+    )
+  }
+
+  it should "Compile ONNX fixed18bp10 MobileNetV2 with input batch of 3" taggedAs (Slow) in {
+    val name = "mobilenetv2_fixed18bp10_onnx"
+    val traceContext = new ExecutiveTraceContext()
+    val options = CompilerOptions(
+      arch = MobileNetFp18bp10Architecture,
+      inputShapes = CompilerInputShapes.mkWithBatchSize(3),
+      printSummary = true,
+      printLayersSummary = true,
+      printGraph = true,
+      tracepointConditions = List(
+        TracepointCondition(MemoryTag.DRAM0, "output")
+      )
+    )
+
+    Compiler.compile(
+      name,
+      s"${Models}/mobilenetv2.onnx",
+      List("output"),
+      options,
+      traceContext
+    )
+
+    EmulatorHelper.test(
+      name,
+      inputBatchSize = options.inputShapes.batchSize,
+      traceContext = traceContext
+    )
+  }
+
   val SpeechCommandsFp16bp8Architecture = Architecture.mkWithDefaults(
     dataType = ArchitectureDataType.FP16BP8,
     arraySize = 8,
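The two batch-of-3 tests exercise the new folding path end to end, in both float and fixed-point architectures, and also verify that a non-default batch size flows through the folded Reshape target. For reference, a small illustrative check of the reshape arithmetic involved; the 1280-channel pooled feature size is an assumption about this particular MobileNetV2 export, not a value read from the model.

    // Illustrative only: resolves the -1 in a folded target shape the
    // way ONNX Reshape semantics require. Shapes are assumptions about
    // the MobileNetV2 export.
    object ReshapeArithmeticSketch extends App {
      val inputShape  = Seq(3, 1280, 1, 1) // pooled activations, batch of 3
      val targetShape = Seq(3L, -1L)       // folded Concat output

      val elements = inputShape.map(_.toLong).product // 3840
      val resolved = targetShape.map {
        case -1L => elements / targetShape.filter(_ != -1L).product
        case d   => d
      }

      println(resolved) // List(3, 1280)
    }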