
Commit

Merge pull request #88 from tensil-ai/peter/sc-488/support-mobilenet-with-unknown-batch-size

Support MobileNet with unknown batch size
petrohi authored Aug 17, 2022
2 parents 07e35d6 + a23d801 commit d3baeb6
Showing 3 changed files with 152 additions and 1 deletion.
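
Editorial note: when MobileNetV2 is exported to ONNX without a fixed batch size, the target shape of its final Reshape is not a stored initializer; it is typically assembled in the graph from the output of a Shape node via Gather, Unsqueeze, and Concat. The diff below adds constant folding for exactly those pieces, so the shape can be resolved once the compiler pins the batch size. The sketch that follows is not code from this repository; it only illustrates that folding on plain Scala collections, with made-up values (a batch size of 3 and a 1280-channel pooled feature map).

// Illustration only (not Tensil compiler code): folding the dynamic-batch
// Reshape pattern (Gather -> Unsqueeze -> Concat) once the batch size is known.
object DynamicBatchFoldSketch extends App {
  // Hypothetical output of an ONNX Shape node after the batch is pinned to 3.
  val dataShape: Vector[Long] = Vector(3L, 1280L, 1L, 1L)

  // Gather(axis = 0) with a scalar index selects one dimension of the shape.
  def gather0(shape: Vector[Long], index: Long): Long = shape(index.toInt)

  // Unsqueeze(axes = [0]) turns the selected scalar back into a 1-element vector.
  def unsqueeze0(scalar: Long): Vector[Long] = Vector(scalar)

  // Concat(axis = 0) with a constant [-1] yields the Reshape target,
  // where -1 means "infer the remaining extent".
  val batch = gather0(dataShape, 0L)
  val reshapeTarget = unsqueeze0(batch) ++ Vector(-1L)

  println(reshapeTarget) // Vector(3, -1)
}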
tools/src/tensil/tools/compiler/MemoryManager.scala (3 additions, 0 deletions)
@@ -85,6 +85,9 @@ class MemoryManager(
   def hasPendingFloatConst(name: String) =
     pendingFloatConsts.get(name).isDefined
 
+  def hasPendingLongConst(name: String) =
+    pendingLongConsts.get(name).isDefined
+
   def getPendingIntConst(name: String) = pendingIntConsts(name)
   def getPendingLongConst(name: String) = pendingLongConsts(name)
   def getPendingFloatConst(name: String) = pendingFloatConsts(name)
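
The MemoryManager change simply adds a Long counterpart to the existing hasPendingFloatConst query, so the frontend can ask whether a named integer constant has already been folded. A reduced sketch of that bookkeeping pattern, using a hypothetical PendingConstsSketch class and a simplified TensorDataSketch stand-in rather than the real MemoryManager and TensorData, might look like this:

import scala.collection.mutable

// Simplified stand-in for TensorData; only the pieces needed for the sketch.
final case class TensorDataSketch[T](shape: Seq[Int], data: Seq[T])

// Sketch of pending-constant bookkeeping (hypothetical, not the real class).
class PendingConstsSketch {
  private val pendingLongConsts =
    mutable.Map.empty[String, TensorDataSketch[Long]]

  def addPendingLongConst(name: String, t: TensorDataSketch[Long]): Unit =
    pendingLongConsts(name) = t

  // Mirrors the new helper: lets emitters check for an already-folded
  // Long constant before falling back to graph initializers.
  def hasPendingLongConst(name: String): Boolean =
    pendingLongConsts.get(name).isDefined

  def getPendingLongConst(name: String): TensorDataSketch[Long] =
    pendingLongConsts(name)
}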
tools/src/tensil/tools/compiler/OnnxFrontend.scala (96 additions, 1 deletion)
@@ -449,6 +449,14 @@ class OnnxFrontend(
             emitGlobalPool(_, nodeProto),
             emitters
           )
+        case "Gather" =>
+          rewriteSimple(remainingProtos, emitGather(_, nodeProto), emitters)
+        case "Unsqueeze" =>
+          rewriteSimple(
+            remainingProtos,
+            emitUnsqueeze(_, nodeProto),
+            emitters
+          )
         case op =>
           throw new CompilerException(
             s"Unsupported op ${op} (${nodeProto.name.get})"
@@ -917,6 +925,69 @@ class OnnxFrontend(
     )
   }
 
+  private def emitGather(
+      context: EmitContext,
+      gatherProto: NodeProto
+  ): Unit = {
+    val axisAttr = getAttr(gatherProto, "axis").get
+
+    require(axisAttr.`type`.get.isInt)
+
+    val axis = axisAttr.i.get
+
+    val data =
+      context.mm
+        .getPendingLongConst(gatherProto.input(0))
+        .asInstanceOf[TensorData[Long]]
+
+    val indices = context.mm
+      .getPendingLongConst(gatherProto.input(1))
+      .asInstanceOf[TensorData[Long]]
+
+    if (axis != 0 || data.shape.size != 1 || indices.shape.size != 0)
+      throw new CompilerException("Only 1D gather is supported");
+
+    if (indices.as1D(0) < 0 || indices.as1D(0) >= data.shape(0))
+      throw new CompilerException("Gather index is outside of data shape");
+
+    context.mm.addPendingConst(
+      gatherProto.output(0),
+      new TensorData(
+        Shape(),
+        Seq(data.as1D(indices.as1D(0).toInt)),
+        org.tensorflow.framework.types.DataType.DT_INT64
+      )
+    )
+  }
+
+  private def emitUnsqueeze(
+      context: EmitContext,
+      unsqueezeProto: NodeProto
+  ): Unit = {
+    val axesAttr = getAttr(unsqueezeProto, "axes").get
+
+    require(axesAttr.`type`.get.isInts)
+
+    val axes = axesAttr.ints
+
+    val data =
+      context.mm
+        .getPendingLongConst(unsqueezeProto.input(0))
+        .asInstanceOf[TensorData[Long]]
+
+    if (axes.size != 1 || axes(0) != 0 || data.shape.size != 0)
+      throw new CompilerException("Only scalar unsqueeze is supported");
+
+    context.mm.addPendingConst(
+      unsqueezeProto.output(0),
+      new TensorData(
+        Shape(1),
+        data.as1D,
+        org.tensorflow.framework.types.DataType.DT_INT64
+      )
+    )
+  }
+
   private def emitConstant(
       context: EmitContext,
       constantProto: NodeProto
@@ -1058,7 +1129,10 @@ class OnnxFrontend(
       context: EmitContext,
       reshapeProto: NodeProto
   ): Unit = {
-    val shape = getTensorData(tensorProtos(reshapeProto.input(1)))
+    val shapeInputName = reshapeProto.input(1)
+    val shape = (if (tensorProtos.contains(shapeInputName))
+                   getTensorData(tensorProtos(shapeInputName))
+                 else context.mm.getPendingLongConst(shapeInputName))
       .asInstanceOf[TensorData[Long]]
       .as1D
       .map(_.toInt)
@@ -1439,6 +1513,27 @@ class OnnxFrontend(
           org.tensorflow.framework.types.DataType.DT_FLOAT
         )
       )
+    } else if (
+      concatProto.input.forall(name =>
+        context.mm.hasPendingLongConst(name) || tensorProtos.contains(name)
+      )
+    ) {
+      val output = concatProto.input
+        .map(name =>
+          (if (context.mm.hasPendingLongConst(name))
+             context.mm.getPendingLongConst(name)
+           else getTensorData(tensorProtos(name))).as1D
+        )
+        .flatten
+
+      context.mm.addPendingConst(
+        concatProto.output(0),
+        new TensorData(
+          Shape(output.size),
+          output,
+          org.tensorflow.framework.types.DataType.DT_INT64
+        )
+      )
     } else {
 
       if (axis != 1)
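
The Reshape and Concat changes above share one idea: a shape-like input may now come either from the graph's initializers (tensorProtos) or from a constant the frontend has already folded into the MemoryManager. A minimal sketch of that two-source lookup, with Map-backed stand-ins for both stores (hypothetical names, not the frontend's actual plumbing):

// Sketch only: resolving a constant 1-D Long tensor from either of two sources.
object ConstLookupSketch extends App {
  // Stand-in for ONNX initializers keyed by input name.
  val initializers: Map[String, Vector[Long]] =
    Map("reshape_shape_init" -> Vector(1L, -1L))

  // Stand-in for constants the frontend folded earlier (e.g. via Gather/Concat).
  val pendingLongConsts: Map[String, Vector[Long]] =
    Map("reshape_shape_folded" -> Vector(3L, -1L))

  // Use the folded constant when present, otherwise the initializer;
  // a non-constant input is an error at compile time.
  def lookupLongConst(name: String): Vector[Long] =
    pendingLongConsts
      .get(name)
      .orElse(initializers.get(name))
      .getOrElse(sys.error(s"$name is not a compile-time constant"))

  println(lookupLongConst("reshape_shape_folded").map(_.toInt)) // Vector(3, -1)
}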
tools/test/src/tools/CompilerSpec.scala (53 additions, 0 deletions)
@@ -2768,6 +2768,59 @@ class CompilerSpec extends AnyFlatSpec {
     )
   }
 
+  it should "Compile ONNX float MobileNetV2 with input batch of 3" taggedAs (Slow) in {
+    val name = "mobilenetv2_float_onnx"
+    val traceContext = new ExecutiveTraceContext()
+    val options = CompilerOptions(
+      arch = MobileNetFloat32Architecture,
+      inputShapes = CompilerInputShapes.mkWithBatchSize(3),
+      printSummary = true
+    )
+
+    Compiler.compile(
+      name,
+      s"${Models}/mobilenetv2.onnx",
+      List("output"),
+      options,
+      traceContext
+    )
+
+    EmulatorHelper.test(
+      name,
+      inputBatchSize = options.inputShapes.batchSize,
+      traceContext = traceContext
+    )
+  }
+
+  it should "Compile ONNX fixed18bp10 MobileNetV2 with input batch of 3" taggedAs (Slow) in {
+    val name = "mobilenetv2_fixed18bp10_onnx"
+    val traceContext = new ExecutiveTraceContext()
+    val options = CompilerOptions(
+      arch = MobileNetFp18bp10Architecture,
+      inputShapes = CompilerInputShapes.mkWithBatchSize(3),
+      printSummary = true,
+      printLayersSummary = true,
+      printGraph = true,
+      tracepointConditions = List(
+        TracepointCondition(MemoryTag.DRAM0, "output")
+      )
+    )
+
+    Compiler.compile(
+      name,
+      s"${Models}/mobilenetv2.onnx",
+      List("output"),
+      options,
+      traceContext
+    )
+
+    EmulatorHelper.test(
+      name,
+      inputBatchSize = options.inputShapes.batchSize,
+      traceContext = traceContext
+    )
+  }
+
   val SpeechCommandsFp16bp8Architecture = Architecture.mkWithDefaults(
     dataType = ArchitectureDataType.FP16BP8,
     arraySize = 8,
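
Both new specs rely on CompilerInputShapes.mkWithBatchSize(3) to replace the model's unspecified batch dimension with a concrete value before compiling. Purely as an illustration of that idea (the helper's real implementation is not shown in this diff), pinning an unknown leading dimension could look like:

// Illustration only: substituting a concrete batch size for an unknown
// leading dimension, where None marks a dimension the model leaves open.
object BatchSizeSketch extends App {
  def withBatchSize(shape: Seq[Option[Int]], batch: Int): Seq[Int] =
    shape.zipWithIndex.map {
      case (Some(d), _) => d
      case (None, 0)    => batch // the unknown batch dimension
      case (None, i)    => sys.error(s"dimension $i must be known")
    }

  // MobileNetV2 input: unknown batch, then 3 x 224 x 224.
  println(withBatchSize(Seq(None, Some(3), Some(224), Some(224)), 3))
  // List(3, 3, 224, 224)
}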
