Skip to content

Commit

Permalink
[3/6] Arm(R) Ethos(TM)-U NPU TIR compiler with conv2d support (apache…
Browse files Browse the repository at this point in the history
…#8806)

* Arm(R) Ethos(TM)-U NPU TIR compiler with conv2d support

This commit adds the lowering passes necessary to lower
an NPU Relay module down to a TIR module that can be
compiled for the NPU. Conv2d is supported as the first
NPU operator. An intermediate TE stage between Relay and
TIR allows support for scheduling the operators.

Co-authored-by: Manupa Karunaratne <[email protected]>

* Fix Conv2D TIR type sensitivity

Change-Id: I3741f9dd8bb5952590ff8c586f6b96e5c3a03795

* Arm(R) Ethos(TM)-U NPU TIR passes and TE for Conv2D

*fixing tests

Change-Id: Id4a4c80f72ce29b98fc8b3954a1413c1c7fda500

* Fix import guards for tests

Change-Id: Iaee06017bd125d3040ce42182c4ccdb80d7fc946

* Fix typing failures with ignores

Change-Id: I81513f112a42b93cfdd3bcaf8e8852dd60ffe9e9

* Remove unused import

Change-Id: I6596b62ab56e4ca8b31ef08293686f53f38454d2

* Reintroduce get_target_accel_type

Change-Id: I0aaf83fe0204c0db435692e9b92dee6e9d6997fe

Co-authored-by: Manupa Karunaratne <[email protected]>
  • Loading branch information
2 people authored and ylc committed Jan 13, 2022
1 parent 80de207 commit 7402fc8
Show file tree
Hide file tree
Showing 27 changed files with 4,144 additions and 2 deletions.
1 change: 1 addition & 0 deletions python/tvm/relay/backend/contrib/ethosu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
from . import preprocess
from . import errors
from . import vela_api
from . import tir_to_cs_translator
from .util import partition_for_ethosu
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def conv2d_compute(
).astype(ifm.dtype)
* weight[cc, rh, rw, rc].astype(ifm.dtype)
# This is a trick to load 10 elements of the scale_bias at once, not accurate maths
+ (scale_bias[cc, 0] * scale_bias[cc, 9]),
+ (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
axis=[rh, rw, rc],
),
name="ethosu_conv2d",
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/relay/backend/contrib/ethosu/te/dma.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def _pad(*indices):
not_zero.append(indices[i] < tensor.shape[i] + pad_before[i])
if not_zero:
not_zero = tvm.tir.all(*not_zero)
return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8"))
return tvm.tir.if_then_else(
not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype)
)
return tensor(*index_tuple)

return _pad
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Arm(R) Ethos(TM)-U NPU TIR codegen modules."""
199 changes: 199 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator
from tvm.driver.build_module import get_binds

from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
from .scheduler import schedule


def lower_ethosu(sch, args, const_dict, name="main"):
"""Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.
The resulting TIR module will contain a single function
that comprises of a sequence of tir.extern_calls to NPU
operations.
Parameters
----------
sch : tvm.te.Schedule
The schedule to be lowered.
args : Union[list of tvm.te.Tensor, TEGraph]
The input/output tensors.
const_dict : dict of int to numpy.ndarray
The constant dictionary.
name : str, optional
The name of the lowered primitive function.
Returns
-------
mod : tvm.IRModule
The lowered TIR module.
const_dict : dict of int to numpy.ndarray
The modified constant dictionary.
"""
if not isinstance(args, list):
args = list(args.inputs) + list(args.outputs)
# config setup
curr_pass_ctx = tvm.ir.transform.PassContext.current()
curr_cfg = dict()
for key, value in curr_pass_ctx.config.items():
curr_cfg[key] = value
tir_compiler_cfg = {
"tir.LoopPartition": {
"partition_const_loop": True,
"no_unroll_loop_with_extent_one": True,
},
"tir.UnrollLoop": {"auto_max_depth": -1},
}
# Merge two configs
curr_cfg = {**curr_cfg, **tir_compiler_cfg}

sch = sch.normalize()
bounds = tvm.te.schedule.InferBound(sch)
stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)

compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
binds, arg_list = get_binds(args, compact, None)
func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)

func = func.with_attr("global_symbol", name)
func = func.with_attr("tir.noalias", True)
mod = tvm.IRModule({name: func})
with tvm.transform.PassContext(config=curr_cfg):
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.StorageFlatten(64)(mod)
mod = tvm.tir.transform.UnrollLoop()(mod)
mod = tvm.tir.transform.LoopPartition()(mod)
mod = RemoveZeroStores()(mod)
mod = tvm.tir.transform.Simplify()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod = ReplaceOperators()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
mod, const_dict = EncodeConstants(const_dict)(mod)
mod = tvm.tir.transform.StorageRewrite()(mod)
mod = tvm.tir.transform.RemoveNoOp()(mod)
return mod, const_dict


def lower_to_te(prim_func):
"""Lower a Relay primitive function to a Tensor Expression graph.
Parameters
----------
prim_func : tvm.relay.Function
The Relay function to lowerethosu_runtime([]).
Returns
-------
out : TEGraph
The lowered Tensor Expression graph.
"""
f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE")
return f(prim_func)


class ExtractConstants(ExprMutator):
"""The actual mutator pass to extract the constants from a function and replace them with
Vars so the function can be lowered to a TE graph. Additionally returns all the values of
the constants extracted."""

def __init__(self):
super().__init__()
self.constants = []

def visit_constant(self, const):
if isinstance(const.checked_type, relay.ty.TensorType):
if const.checked_type.concrete_shape != ():
self.constants.append(const.data.asnumpy())
name = "p" + str(len(self.constants))
return relay.var(type_annotation=const.checked_type, name_hint=name)

return const

def visit_function(self, fn):
new_body = self.visit(fn.body)
new_params = list(relay.analysis.free_vars(new_body))
return relay.Function(new_params, new_body)

def extract_constants(self, func):
new_func = self.visit(func)
return new_func, self.constants


def extract_constants(func):
"""Extract the constants from a function and replace them with
Vars so the function can be lowered to a TE graph. Additionally
returns all the values of the constants extracted.
Parameters
----------
func : tvm.relay.Function
The Relay function from which to extract constants.
Returns
-------
new_func : tvm.relay.Function
The Relay function with constants replaced by vars.
const_dict : dict of int to numpy.ndarray
A dict of the extracted constants keyed by their param index.
"""
const_dict = {}
params = len(func.params)
new_func, consts = ExtractConstants().extract_constants(func)
for i, const in enumerate(consts):
const_dict[params + i] = const

new_func = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(new_func))["main"]
return new_func, const_dict


def lower_to_tir(func, cascader=None):
"""Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target.
The Relay function should only contain operations supported
by the NPU.
Parameters
----------
func : tvm.relay.Function
The Relay function to lower.
cascader : Callable
An optional cascading function,
Returns
-------
mod : tvm.IRModule
The lowered TIR module.
consts : dict of int to numpy.ndarray
A dict of the extracted constants keyed by their param index.
"""
func, consts = extract_constants(func)
mod = tvm.IRModule.from_expr(func)
func = relay.transform.InferType()(mod)["main"]
te_graph = lower_to_te(func)
s = schedule(te_graph, consts, cascader)
mod, consts = lower_ethosu(s, te_graph, consts)
return mod, consts
106 changes: 106 additions & 0 deletions python/tvm/relay/backend/contrib/ethosu/tir/convolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""Extract information from the convolution operators in TIR."""
import tvm
from ..vela_api import SCALE_BIAS_LENGTH
from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores
from .dma import get_ifm_params, get_ofm_params
from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution


def get_conv2d_params(stmt, producers, consumers):
"""Get the parameters necessary to construct a call_extern for a 2D convolution.
Parameters
----------
stmt : tvm.tir.AttrStmt
The outermost attribute statement of a convolution loop nest.
producers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that produces their values.
consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt
A dictionary to associate pointers with the loop nest
that consumes their values.
Returns
-------
Serial2DConvolution
The parameters needed to construct a 2D convolution.
output_pointer : tvm.tir.Var
The output pointer of the convolution operation.
replace_pointer : tvm.tir.Var
The output pointer of the DMA write operation, which is to replace
the convolution output pointer.
"""
attrs, body = get_op_attrs(stmt)
_, _, _, _, _, inner = get_outer_loops(body, "NHWC")
rh = inner
rw = rh.body
rc = rw.body
# loads = [output, input, weights, scale_bias, scale_bias]
loads = get_loads(rc.body)
# stores = [output]
stores = get_stores(rc.body)
input_pointer = loads[1].buffer_var
output_pointer = stores[0].buffer_var
# Get feature map info
serial_ifm, serial_padding = get_ifm_params(input_pointer, producers)
serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers)
# Get kernel info
serial_kernel = SerialKernel(
width=int(rw.extent),
height=int(rh.extent),
stride_w=int(attrs["stride_w"]),
stride_h=int(attrs["stride_h"]),
dilation_w=int(attrs["dilation_w"]),
dilation_h=int(attrs["dilation_h"]),
)
# Get scale_bias info
scale_bias_load = loads[3]
scale_bias_base = get_base_address(scale_bias_load.index)
serial_scale_bias = SerialAddressRange(
address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base),
length=SCALE_BIAS_LENGTH * serial_ofm[3],
)
# Get weight info
weight_load = loads[2]
weight_base = get_base_address(weight_load.index)
serial_weight = SerialAddressRange(
address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base),
length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent,
)
# Get activation info
serial_activation = SerialActivation(
op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"]
)
return (
Serial2DConvolution(
ifm=serial_ifm,
ofm=serial_ofm,
kernel=serial_kernel,
weight=serial_weight,
weight_zero_point=attrs["weight_zero_point"],
scale_bias=serial_scale_bias,
padding=serial_padding,
activation=serial_activation,
upscale="NONE",
),
output_pointer,
replace_pointer,
)
Loading

0 comments on commit 7402fc8

Please sign in to comment.