Skip to content

Commit

Permalink
[CMSIS-NN] code generator for softmax (#8833)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashutosh-arm authored Sep 7, 2021
1 parent 6dea994 commit 1dc1707
Show file tree
Hide file tree
Showing 14 changed files with 700 additions and 15 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF)
tvm_option(USE_TF_TVMDSOOP "Build with TensorFlow TVMDSOOp" OFF)
tvm_option(USE_FALLBACK_STL_MAP "Use TVM's POD compatible Map" OFF)
tvm_option(USE_ETHOSN "Build with Arm Ethos-N" OFF)
tvm_option(USE_CMSISNN "Build with Arm CMSIS-NN" OFF)
tvm_option(INDEX_DEFAULT_I64 "Defaults the index datatype to int64" ON)
tvm_option(USE_LIBBACKTRACE "Build libbacktrace to supply linenumbers on stack traces" AUTO)
tvm_option(BUILD_STATIC_RUNTIME "Build static version of libtvm_runtime" OFF)
Expand Down Expand Up @@ -401,6 +402,7 @@ include(cmake/modules/ROCM.cmake)
include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/contrib/EthosN.cmake)
include(cmake/modules/contrib/CMSISNN.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
Expand Down
22 changes: 22 additions & 0 deletions cmake/modules/contrib/CMSISNN.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Compiles the CMSIS-NN relay backend into the TVM compiler when enabled
# via -DUSE_CMSISNN=ON (declared in the top-level CMakeLists.txt).
if(USE_CMSISNN)
message(STATUS "Build with CMSIS-NN support")
# Pick up every codegen source under the CMSIS-NN contrib backend.
file(GLOB RELAY_CONTRIB_CMSISNN_SRCS src/relay/backend/contrib/cmsisnn/*.cc)
# These go into the compiler source list (not the runtime).
list(APPEND COMPILER_SRCS ${RELAY_CONTRIB_CMSISNN_SRCS})
endif(USE_CMSISNN)
5 changes: 5 additions & 0 deletions python/tvm/driver/tvmc/composite_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from tvm.relay.op.contrib.arm_compute_lib import partition_for_arm_compute_lib
from tvm.relay.op.contrib.ethosn import partition_for_ethosn
from tvm.relay.op.contrib.cmsisnn import partition_for_cmsisnn
from tvm.relay.op.contrib.bnns import partition_for_bnns
from tvm.relay.op.contrib.vitis_ai import partition_for_vitis_ai

Expand All @@ -49,6 +50,10 @@
"config_key": None,
"pass_pipeline": partition_for_arm_compute_lib,
},
"cmsis-nn": {
"config_key": None,
"pass_pipeline": partition_for_cmsisnn,
},
"ethos-n77": {
"config_key": "relay.ext.ethos-n.options",
"pass_pipeline": partition_for_ethosn,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
# under the License.
"""Backend codegen modules for relay."""
from . import compile_engine
from .contrib import cmsisnn
18 changes: 18 additions & 0 deletions python/tvm/relay/backend/contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""external backend codegen modules for relay."""
from . import cmsisnn
18 changes: 18 additions & 0 deletions python/tvm/relay/backend/contrib/cmsisnn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""CMSIS-NN codegen modules for relay."""
from . import codegen
134 changes: 134 additions & 0 deletions python/tvm/relay/backend/contrib/cmsisnn/codegen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Codegen for CMSIS-NN"""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprVisitor


class GenerateTIR(ExprVisitor):
    """Generates TIR module containing TIR primfuncs corresponding to the Relay operators.
    Note: Relay operator to primfunc mapping may not be 1:1.
    """

    def __init__(self, name):
        super().__init__()
        # Global symbol under which the generated primfunc is emitted.
        self.name = name
        # IRModule built by emit_softmax_tir; stays None if no supported
        # operator sequence is found while visiting.
        self.tir_mod = None
        # Input scale read from qnn.dequantize; overwritten by
        # is_quantized_softmax before emit_softmax_tir uses it.
        self.scale = 1.0 / 256

    def call_contains_op(self, call, op_name):
        """Returns True iff `call` is a relay Call to the operator `op_name`."""
        # Guard on relay.Call first: an argument of a call can be a Var or
        # Constant, which has no `.op` attribute (the previous version
        # raised AttributeError in that case instead of returning False).
        return (
            isinstance(call, relay.Call)
            and isinstance(call.op, tvm.ir.op.Op)
            and call.op.name == op_name
        )

    def is_quantized_softmax(self, call):
        """Checks for the following relay sequence
        a = qnn.dequantize(in, scale, zero_point)
        b = nn.softmax(a)
        c = qnn.quantize(b, scale, zero_point)
        """
        if not self.call_contains_op(call, "qnn.quantize"):
            return False
        softmax_call = call.args[0]
        if not self.call_contains_op(softmax_call, "nn.softmax"):
            return False
        dequantize_call = softmax_call.args[0]
        if not self.call_contains_op(dequantize_call, "qnn.dequantize"):
            return False
        # Remember the dequantize input scale; it is forwarded to the
        # arm_softmax_s8 extern call in emit_softmax_tir.
        self.scale = dequantize_call.args[1].data.numpy().item(0)
        return True

    def emit_softmax_tir(self, call):
        """Generates TIR extern_call for softmax"""
        shape = call.checked_type.shape  # NHWC
        dtype = call.checked_type.dtype
        ir_builder = tvm.tir.ir_builder.create()
        in_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)
        out_buf = tvm.tir.decl_buffer(shape=shape, dtype=dtype)

        # The extern call views the tensor as a 2D matrix: softmax runs over
        # the trailing dimension, all leading dimensions fold into rows.
        trailing_dim = len(shape) - 1
        num_rows = 1
        for dim in range(trailing_dim):
            num_rows *= shape[dim]
        row_size = shape[trailing_dim]
        ir_builder.emit(
            tvm.tir.call_extern(
                dtype,
                "arm_softmax_s8",
                in_buf.data,
                num_rows,
                row_size,
                self.scale,
                out_buf.data,
            )
        )
        prim_func = tvm.tir.PrimFunc([in_buf, out_buf], ir_builder.get())
        prim_func = prim_func.with_attr("global_symbol", self.name)
        prim_func = prim_func.with_attr("tir.noalias", True)
        self.tir_mod = tvm.IRModule({self.name: prim_func})

    def visit_call(self, call):
        """Iterates over the relay operators within relay external function"""
        # Visit operands first (post-order) so nested patterns are reached.
        super().visit_call(call)
        if self.is_quantized_softmax(call):
            self.emit_softmax_tir(call)

    def generate_tir(self, func):
        """Visits `func` and returns the built TIR module (None if no
        supported pattern was found)."""
        self.visit(func)
        return self.tir_mod


def relay_to_tir(name, func):
    """Lower a Relay function to TIR for the CMSIS-NN target.
    The Relay function should only contain operations supported
    by the CMSIS-NN target. This is enforced by the graph partitioner
    for CMSIS-NN.
    Parameters
    ----------
    name: str
        Name of the external relay function
    func : tvm.relay.Function
        The Relay function to lower.
    Returns
    -------
    mod : tvm.IRModule
        The lowered TIR module.
    """
    generator = GenerateTIR(name)
    return generator.generate_tir(func)


@tvm.register_func("relay.ext.cmsisnn")
def cmsisnn_compiler(relay_func):
    """It compiles Relay's external function into equivalent TIR
    and subsequently converts that into 'c' code. During the 'c'
    code generation, it embeds CMSIS-NN APIs for the corresponding
    operators.
    """
    # Wrap the external function in its own module and type-infer it.
    ir_mod = tvm.IRModule()
    ir_mod["main"] = relay_func
    ir_mod = relay.transform.InferType()(ir_mod)

    # Lower to TIR under the external function's global symbol.
    external_symbol = relay_func.attrs["global_symbol"]
    lowered_tir = relay_to_tir(external_symbol, ir_mod["main"])

    # Hand the TIR module to the CMSIS-NN runtime module factory.
    create_runtime_module = tvm._ffi.get_global_func("runtime.module.cmsisnn.create")
    return create_runtime_module(lowered_tir)
8 changes: 6 additions & 2 deletions python/tvm/relay/op/contrib/cmsisnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,15 @@ def softmax_pattern():

def check_quantized_softmax(extract):
    """Check if softmax is supported by CMSIS-NN.

    `extract` is the matched qnn.quantize(nn.softmax(qnn.dequantize(x)))
    expression. Supported only when both ends are int8 and the output
    quantization is fixed at scale 1/256, zero point -128 (presumably the
    parameters arm_softmax_s8 is hard-wired for — confirm against CMSIS-NN).
    """
    dequantize_call = extract.args[0].args[0]
    # args[1]/args[2] of qnn.quantize are the output scale and zero point
    # constants.
    scale = extract.args[1].data.numpy().item(0)
    zero_point = extract.args[2].data.numpy().item(0)

    # check for dtypes of quantize and dequantize
    return (
        (scale == 1.0 / 256 and zero_point == -128)
        and extract.attrs.out_dtype == "int8"
        and dequantize_call.args[0].checked_type.dtype == "int8"
    )

return [
Expand Down
Loading

0 comments on commit 1dc1707

Please sign in to comment.