diff --git a/python/tvm/relay/backend/contrib/ethosu/__init__.py b/python/tvm/relay/backend/contrib/ethosu/__init__.py index f5c595462e73..2b424ebb5dec 100644 --- a/python/tvm/relay/backend/contrib/ethosu/__init__.py +++ b/python/tvm/relay/backend/contrib/ethosu/__init__.py @@ -20,4 +20,5 @@ from . import preprocess from . import errors from . import vela_api +from . import tir_to_cs_translator from .util import partition_for_ethosu diff --git a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py index 40015ac296a6..26f7ea979219 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/convolution.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/convolution.py @@ -140,7 +140,7 @@ def conv2d_compute( ).astype(ifm.dtype) * weight[cc, rh, rw, rc].astype(ifm.dtype) # This is a trick to load 10 elements of the scale_bias at once, not accurate maths - + (scale_bias[cc, 0] * scale_bias[cc, 9]), + + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype), axis=[rh, rw, rc], ), name="ethosu_conv2d", diff --git a/python/tvm/relay/backend/contrib/ethosu/te/dma.py b/python/tvm/relay/backend/contrib/ethosu/te/dma.py index d19c8c56f7c2..bf9a018ea855 100644 --- a/python/tvm/relay/backend/contrib/ethosu/te/dma.py +++ b/python/tvm/relay/backend/contrib/ethosu/te/dma.py @@ -59,7 +59,9 @@ def _pad(*indices): not_zero.append(indices[i] < tensor.shape[i] + pad_before[i]) if not_zero: not_zero = tvm.tir.all(*not_zero) - return tvm.tir.if_then_else(not_zero, tensor(*index_tuple), tvm.tir.const(0, "uint8")) + return tvm.tir.if_then_else( + not_zero, tensor(*index_tuple), tvm.tir.const(0, tensor.dtype) + ) return tensor(*index_tuple) return _pad diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py b/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py new file mode 100644 index 000000000000..cc285e5241cd --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Arm(R) Ethos(TM)-U NPU TIR codegen modules.""" diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py new file mode 100644 index 000000000000..c59a386fefbb --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/compiler.py @@ -0,0 +1,199 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""The integration of Arm(R) Ethos(TM)-U NPU TIR compiler"""
+import tvm
+from tvm import relay
+from tvm.relay.expr_functor import ExprMutator
+from tvm.driver.build_module import get_binds
+
+from .passes import ReplaceOperators, RemoveZeroStores, EncodeConstants
+from .scheduler import schedule
+
+
+def lower_ethosu(sch, args, const_dict, name="main"):
+    """Lower a schedule to TIR for the Arm(R) Ethos(TM)-U NPU target.
+
+    The resulting TIR module will contain a single function
+    that comprises a sequence of tir.extern_calls to NPU
+    operations.
+
+    Parameters
+    ----------
+    sch : tvm.te.Schedule
+        The schedule to be lowered.
+    args : Union[list of tvm.te.Tensor, TEGraph]
+        The input/output tensors.
+    const_dict : dict of int to numpy.ndarray
+        The constant dictionary.
+    name : str, optional
+        The name of the lowered primitive function.
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The lowered TIR module.
+    const_dict : dict of int to numpy.ndarray
+        The modified constant dictionary.
+
+    """
+    if not isinstance(args, list):
+        args = list(args.inputs) + list(args.outputs)
+    # config setup
+    curr_pass_ctx = tvm.ir.transform.PassContext.current()
+    curr_cfg = dict()
+    for key, value in curr_pass_ctx.config.items():
+        curr_cfg[key] = value
+    tir_compiler_cfg = {
+        "tir.LoopPartition": {
+            "partition_const_loop": True,
+            "no_unroll_loop_with_extent_one": True,
+        },
+        "tir.UnrollLoop": {"auto_max_depth": -1},
+    }
+    # Merge two configs
+    curr_cfg = {**curr_cfg, **tir_compiler_cfg}
+
+    sch = sch.normalize()
+    bounds = tvm.te.schedule.InferBound(sch)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds, True)
+
+    compact = tvm.te.schedule.VerifyCompactBuffer(stmt)
+    binds, arg_list = get_binds(args, compact, None)
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc(arg_list, stmt, binds)
+
+    func = func.with_attr("global_symbol", name)
+    func = func.with_attr("tir.noalias", True)
+    mod = tvm.IRModule({name: func})
+    with tvm.transform.PassContext(config=curr_cfg):
+        mod = tvm.tir.transform.Simplify()(mod)
+        mod = tvm.tir.transform.StorageFlatten(64)(mod)
+        mod = tvm.tir.transform.UnrollLoop()(mod)
+        mod = tvm.tir.transform.LoopPartition()(mod)
+        mod = RemoveZeroStores()(mod)
+        mod = tvm.tir.transform.Simplify()(mod)
+        mod = tvm.tir.transform.RemoveNoOp()(mod)
+        mod = ReplaceOperators()(mod)
+        mod = tvm.tir.transform.RemoveNoOp()(mod)
+        mod, const_dict = EncodeConstants(const_dict)(mod)
+        mod = tvm.tir.transform.StorageRewrite()(mod)
+        mod = tvm.tir.transform.RemoveNoOp()(mod)
+    return mod, const_dict
+
+
+def lower_to_te(prim_func):
+    """Lower a Relay primitive function to a Tensor Expression graph.
+
+    Parameters
+    ----------
+    prim_func : tvm.relay.Function
+        The Relay function to lower.
+
+    Returns
+    -------
+    out : TEGraph
+        The lowered Tensor Expression graph.
+ + """ + f = tvm._ffi.get_global_func("relay.backend.contrib.ethosu.LowerToTE") + return f(prim_func) + + +class ExtractConstants(ExprMutator): + """The actual mutator pass to extract the constants from a function and replace them with + Vars so the function can be lowered to a TE graph. Additionally returns all the values of + the constants extracted.""" + + def __init__(self): + super().__init__() + self.constants = [] + + def visit_constant(self, const): + if isinstance(const.checked_type, relay.ty.TensorType): + if const.checked_type.concrete_shape != (): + self.constants.append(const.data.asnumpy()) + name = "p" + str(len(self.constants)) + return relay.var(type_annotation=const.checked_type, name_hint=name) + + return const + + def visit_function(self, fn): + new_body = self.visit(fn.body) + new_params = list(relay.analysis.free_vars(new_body)) + return relay.Function(new_params, new_body) + + def extract_constants(self, func): + new_func = self.visit(func) + return new_func, self.constants + + +def extract_constants(func): + """Extract the constants from a function and replace them with + Vars so the function can be lowered to a TE graph. Additionally + returns all the values of the constants extracted. + + Parameters + ---------- + func : tvm.relay.Function + The Relay function from which to extract constants. + + Returns + ------- + new_func : tvm.relay.Function + The Relay function with constants replaced by vars. + const_dict : dict of int to numpy.ndarray + A dict of the extracted constants keyed by their param index. + + """ + const_dict = {} + params = len(func.params) + new_func, consts = ExtractConstants().extract_constants(func) + for i, const in enumerate(consts): + const_dict[params + i] = const + + new_func = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(new_func))["main"] + return new_func, const_dict + + +def lower_to_tir(func, cascader=None): + """Lower a Relay function to TIR for the Arm(R) Ethos(TM)-U NPU target. + + The Relay function should only contain operations supported + by the NPU. + + Parameters + ---------- + func : tvm.relay.Function + The Relay function to lower. + cascader : Callable + An optional cascading function, + + Returns + ------- + mod : tvm.IRModule + The lowered TIR module. + consts : dict of int to numpy.ndarray + A dict of the extracted constants keyed by their param index. + + """ + func, consts = extract_constants(func) + mod = tvm.IRModule.from_expr(func) + func = relay.transform.InferType()(mod)["main"] + te_graph = lower_to_te(func) + s = schedule(te_graph, consts, cascader) + mod, consts = lower_ethosu(s, te_graph, consts) + return mod, consts diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py new file mode 100644 index 000000000000..33fbdcd2b24f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/convolution.py @@ -0,0 +1,106 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the convolution operators in TIR.""" +import tvm +from ..vela_api import SCALE_BIAS_LENGTH +from .utils import get_outer_loops, get_op_attrs, get_base_address, get_loads, get_stores +from .dma import get_ifm_params, get_ofm_params +from .spec import SerialKernel, SerialAddressRange, SerialActivation, Serial2DConvolution + + +def get_conv2d_params(stmt, producers, consumers): + """Get the parameters necessary to construct a call_extern for a 2D convolution. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convolution loop nest. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + Serial2DConvolution + The parameters needed to construct a 2D convolution. + output_pointer : tvm.tir.Var + The output pointer of the convolution operation. + replace_pointer : tvm.tir.Var + The output pointer of the DMA write operation, which is to replace + the convolution output pointer. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, _, _, inner = get_outer_loops(body, "NHWC") + rh = inner + rw = rh.body + rc = rw.body + # loads = [output, input, weights, scale_bias, scale_bias] + loads = get_loads(rc.body) + # stores = [output] + stores = get_stores(rc.body) + input_pointer = loads[1].buffer_var + output_pointer = stores[0].buffer_var + # Get feature map info + serial_ifm, serial_padding = get_ifm_params(input_pointer, producers) + serial_ofm, replace_pointer = get_ofm_params(output_pointer, consumers) + # Get kernel info + serial_kernel = SerialKernel( + width=int(rw.extent), + height=int(rh.extent), + stride_w=int(attrs["stride_w"]), + stride_h=int(attrs["stride_h"]), + dilation_w=int(attrs["dilation_w"]), + dilation_h=int(attrs["dilation_h"]), + ) + # Get scale_bias info + scale_bias_load = loads[3] + scale_bias_base = get_base_address(scale_bias_load.index) + serial_scale_bias = SerialAddressRange( + address=tvm.tir.Load("uint8", scale_bias_load.buffer_var, scale_bias_base), + length=SCALE_BIAS_LENGTH * serial_ofm[3], + ) + # Get weight info + weight_load = loads[2] + weight_base = get_base_address(weight_load.index) + serial_weight = SerialAddressRange( + address=tvm.tir.Load("uint8", weight_load.buffer_var, weight_base), + length=serial_ofm[3] * serial_kernel[0] * serial_kernel[1] * rc.extent, + ) + # Get activation info + serial_activation = SerialActivation( + op=attrs["activation"], clip_min=attrs["clip_min"], clip_max=attrs["clip_max"] + ) + return ( + Serial2DConvolution( + ifm=serial_ifm, + ofm=serial_ofm, + kernel=serial_kernel, + weight=serial_weight, + weight_zero_point=attrs["weight_zero_point"], + scale_bias=serial_scale_bias, + padding=serial_padding, + activation=serial_activation, + upscale="NONE", + ), + output_pointer, + replace_pointer, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/dma.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py new file mode 100644 index 000000000000..ecd402d63309 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/dma.py @@ -0,0 +1,291 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the DMA operators in TIR.""" +import tvm +from .utils import get_outer_loops, get_base_address, get_strides, get_op_attrs +from .spec import SerialFeatureMap, SerialPadding + + +def get_pad_params(stmt): + """Get the padding parameters from a pad loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a pad loop nest. + + Returns + ------- + pad : SerialPadding + The serializable padding. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + _, body = get_op_attrs(stmt) + n, h, w, c, _, inner = get_outer_loops(body, "NHWC") + output_pointer = inner.buffer_var + pad = SerialPadding(top=0, left=0, bottom=0, right=0) + if isinstance(inner.value, tvm.tir.Call): + input_pointer = inner.value.args[1].buffer_var + else: + input_pointer = inner.value.buffer_var + return pad, input_pointer, output_pointer + + padded_shape = [n.extent, h.extent, w.extent, c.extent] + + def _visit(expr): + if isinstance(expr, tvm.tir.expr.LT): + var = expr.a + val = expr.b + if var == h.loop_var: + pad.bottom = padded_shape[1] - val + else: + pad.right = padded_shape[2] - val + elif isinstance(expr, tvm.tir.expr.LE): + var = expr.b + val = expr.a + if var == h.loop_var: + pad.top = val + else: + pad.left = val + + cond = inner.value.args[0] + tvm.tir.stmt_functor.post_order_visit(cond, _visit) + return ( + pad, + input_pointer, + output_pointer, + ) + + +def get_convert_to_nhwc_params(stmt): + """Get the true number of channels from a convert_to_nhwc loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convert_to_nhwc loop nest. + + Returns + ------- + int + The true number of channels. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + _, body = get_op_attrs(stmt) + _, _, _, c, _, inner = get_outer_loops(body, "NHWC") + output_pointer = inner.buffer_var + input_pointer = inner.value.buffer_var + return c.extent, input_pointer, output_pointer + + +def get_convert_to_nhcwb16_params(stmt): + """Get the true number of channels from a convert_to_nhcwb16 loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a convert_to_nhcwb16 loop nest. + + Returns + ------- + out_channels : int + The true number of channels. 
+ input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, _, _, c, b, inner = get_outer_loops(body, attrs["layout"]) + output_pointer = inner.buffer_var + if isinstance(inner.value, tvm.tir.Call): + cond = inner.value.args[0] + out_channels = cond.b.value + input_pointer = inner.value.args[1].buffer_var + else: + input_pointer = inner.value.buffer_var + out_channels = c.extent * b.extent if attrs["layout"] == "NHCWB16" else c.extent + + return out_channels, input_pointer, output_pointer + + +def get_read_params(stmt): + """Get the feature map parameters from a read loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a read loop nest. + + Returns + ------- + SerialFeatureMap + The serializable feature map. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var + stride_vars = [h.loop_var, w.loop_var, c.loop_var] + strides = get_strides(inner.value.index, stride_vars) + base_address = get_base_address(inner.value.index) + data_type = inner.buffer_var.type_annotation.element_type.dtype + return ( + SerialFeatureMap( + data_type=data_type, + height=h.extent, + width=w.extent, + channels=c.extent, + tile_height_0=h.extent, + tile_height_1=0, + tile_width_0=w.extent, + tile_address_0=tvm.tir.Load(data_type, inner.value.buffer_var, base_address), + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=attrs["scale"], + zero_point=attrs["zero_point"], + layout=attrs["layout"], + stride_h=strides[0], + stride_w=strides[1], + stride_c=strides[2], + ), + input_pointer, + output_pointer, + ) + + +def get_write_params(stmt): + """Get the feature map parameters from a write loop nest. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a write loop nest. + + Returns + ------- + SerialFeatureMap + The serializable feature map. + input_pointer : tvm.tir.Var + The pointer consumed by the operation. + output_pointer : tvm.tir.Var + The pointer produced by the operation. + + """ + attrs, body = get_op_attrs(stmt) + _, h, w, c, _, inner = get_outer_loops(body, attrs["layout"]) + input_pointer = inner.value.buffer_var + output_pointer = inner.buffer_var + stride_vars = [h.loop_var, w.loop_var, c.loop_var] + strides = get_strides(inner.index, stride_vars) + base_address = get_base_address(inner.index) + data_type = inner.buffer_var.type_annotation.element_type.dtype + return ( + SerialFeatureMap( + data_type=data_type, + height=h.extent, + width=w.extent, + channels=c.extent, + tile_height_0=h.extent, + tile_height_1=0, + tile_width_0=w.extent, + tile_address_0=tvm.tir.Load(data_type, inner.buffer_var, base_address), + tile_address_1=0, + tile_address_2=0, + tile_address_3=0, + scale=attrs["scale"], + zero_point=attrs["zero_point"], + layout=attrs["layout"], + stride_h=strides[0], + stride_w=strides[1], + stride_c=strides[2], + ), + input_pointer, + output_pointer, + ) + + +def get_ifm_params(pointer, producers): + """Get the parameters associated with the DMA capabilities for an IFM. + + Parameters + ---------- + pointer : tvm.tir.Var + The pointer that the IFM DMA pipeline produces. 
+ producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + + Returns + ------- + serial_ifm : SerialFeatureMap + The serializable IFM. + serial_padding : SerialPadding + The serializable padding. + + """ + pad = producers[pointer] + serial_padding, input_pointer, _ = get_pad_params(pad) + convert_to_nhwc = producers[input_pointer] + in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc) + read = producers[input_pointer] + serial_ifm, _, _ = get_read_params(read) + serial_ifm.channels = in_channels + return serial_ifm, serial_padding + + +def get_ofm_params(pointer, consumers): + """Get the parameters associated with the DMA capabilities for an OFM. + + Parameters + ---------- + pointer : tvm.tir.Var + The pointer that the OFM DMA pipeline consumes. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + serial_ifm : SerialFeatureMap + The serializable OFM. + output_pointer : tvm.tir.Var + The pointer that the OFM DMA pipeline produces. + + """ + convert_to_nhcwb16 = consumers[pointer] + out_channels, _, output_pointer = get_convert_to_nhcwb16_params(convert_to_nhcwb16) + write = consumers[output_pointer] + serial_ofm, _, output_pointer = get_write_params(write) + serial_ofm.channels = out_channels + return serial_ofm, output_pointer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/passes.py b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py new file mode 100644 index 000000000000..1af44962c141 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/passes.py @@ -0,0 +1,475 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""The TIR passes to be run on Arm(R) Ethos(TM)-U NPU TIR Compiler""" +import numpy as np # type: ignore + +import tvm +from tvm.relay.backend.contrib.ethosu import vela_api +from .convolution import get_conv2d_params +from .transform import get_copy_params +from .utils import get_weights_pointer, get_scale_bias_pointer + + +def RemoveZeroStores(): + """This pass removes stores which just store zero to initialise buffers. 
+
+    We don't codegen these stores and they would otherwise considerably
+    complicate the static traversal of convolution."""
+
+    def _remove_zero_store(stmt):
+        if isinstance(stmt.value, tvm.tir.IntImm) and int(stmt.value) == 0:
+            return tvm.tir.Evaluate(tvm.tir.IntImm("uint8", 0))
+        return stmt
+
+    def _ftransform(f, mod, ctx):
+        return f.with_body(
+            tvm.tir.stmt_functor.ir_transform(f.body, _remove_zero_store, None, ["tir.Store"])
+        )
+
+    return tvm.tir.transform.prim_func_pass(
+        _ftransform, opt_level=0, name="tir.ethosu.remove_zero_stores"
+    )
+
+
+def ReplaceOperators():
+    """Replace operators represented as explicit loop nests with call_externs
+    to NPU operators."""
+    op_map = {
+        "ethosu_conv2d": get_conv2d_params,
+        "ethosu_copy": get_copy_params,
+    }
+    pointer_to_producer = {}
+    pointer_to_consumer = {}
+    replace_output_pointer = {}
+    pointer_to_extents = {}
+
+    def _resolve_pointers(stmt):
+        """This pass determines information about the pointers present in the IR.
+        In particular, it associates pointers with both the operations that
+        produce them and the operations that consume them through the
+        pointer_to_producer and pointer_to_consumer dicts.
+
+        Additionally, it determines the extent (size/shape) of each pointer which
+        is required for the _replace_pointers pass which runs later."""
+        loads = []
+
+        def _get_loads(stmt):
+            if isinstance(stmt, tvm.tir.Load):
+                loads.append(stmt.buffer_var)
+
+        if isinstance(stmt, tvm.tir.Allocate):
+            pointer_to_extents[stmt.buffer_var] = stmt.extents
+            if isinstance(stmt.body[0], tvm.tir.AttrStmt):
+                if stmt.body[0].attr_key == "pragma_op":
+                    pointer_to_producer[stmt.buffer_var] = stmt.body[0]
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.attr_key == "pragma_op":
+                tvm.tir.stmt_functor.post_order_visit(stmt, _get_loads)
+                for load_buffer in loads:
+                    pointer_to_consumer[load_buffer] = stmt
+
+    def _replace_operator(stmt):
+        """Replace operators with call_externs, having derived the parameters
+        from the relevant TIR expressions/statements.
+
+        Note the complexity of this pass is mostly from the concept of 'replace
+        pointers'. A call_extern may in principle require information from several
+        loop nests in TIR (each corresponding to a different TE compute op). For
+        example, a convolution operator will have other TE compute ops before and
+        after corresponding to the input/output DMA functionality. Therefore, when
+        the 'central' convolution op is replaced with a call_extern, the memory
+        from the final DMA output op must be hoisted to the location/scope of
+        the call_extern.
+
+        This is done by replacing the pointer corresponding to the current operation
+        with the 'true' output operator through the replace_output_pointer dict.
+        Because of this, the param_func must provide a replace_pointer if the op
+        isn't the true output but a no_compile op is."""
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            op_name = stmt.value.value
+            if stmt.attr_key == "pragma_op" and op_name in op_map:
+                # Get the parameters for the extern call
+                param_func = op_map[op_name]
+                info, output_pointer, replace_pointer = param_func(
+                    stmt, pointer_to_producer, pointer_to_consumer
+                )
+                if replace_pointer is not None:
+                    replace_output_pointer[output_pointer] = replace_pointer
+                # Make the extern call
+                irb = tvm.tir.ir_builder.create()
+                irb.emit(tvm.tir.call_extern("handle", op_name, *info))
+                return irb.get()
+        return None
+
+    def _remove_no_compile(stmt):
+        """Certain operators are marked as 'no compile' operators.
This means they + should be removed from the IR as they are compiled as part of other operators. + The IFM DMA operations are an example of this, as they don't get compiled + independently but instead get compiled into the operator they're associated with, + e.g. a conv2d. + + There are potentially 3 parts to remove for an operator: the memory scope, the + allocate for its output and the compute nest itself. For the memory scope and + allocate, we can check if the pointer they reference is produced by a 'no compile' + operator. For the compute nest, we can just check the op pragma.""" + if isinstance(stmt, tvm.tir.AttrStmt): + # Remove memory scopes + if stmt.node in pointer_to_producer: + producer_attr = pointer_to_producer[stmt.node] + if ( + producer_attr.attr_key == "pragma_op" + and producer_attr.value.value not in op_map + ): + return stmt.body + + # Remove compute nests + if stmt.attr_key == "pragma_op" and stmt.value.value not in op_map: + return tvm.tir.Evaluate(0) + + if isinstance(stmt, tvm.tir.Allocate): + # Remove allocates + if stmt.buffer_var in pointer_to_producer: + op_attr = pointer_to_producer[stmt.buffer_var] + if op_attr.attr_key == "pragma_op" and op_attr.value.value not in op_map: + return stmt.body + return None + + def _replace_pointers(stmt): + if isinstance(stmt, tvm.tir.AttrStmt): + # If the attribute references a pointer that needs replacing + if stmt.node in replace_output_pointer: + replace_pointer = replace_output_pointer[stmt.node] + # If the pointer doesn't have an extent registered to it, + # this means the pointer is to a Buffer. In this case, we + # just want to delete the memory scope attribute + if replace_pointer not in pointer_to_extents: + return stmt.body + # Otherwise, rewrite the memory scope attribute with the new pointer + return tvm.tir.AttrStmt( + replace_output_pointer[stmt.node], stmt.attr_key, stmt.value, stmt.body + ) + + if isinstance(stmt, tvm.tir.Allocate): + # If the allocate allocates a pointer that needs replacing + if stmt.buffer_var in replace_output_pointer: + replace_pointer = replace_output_pointer[stmt.buffer_var] + # If the pointer doesn't have an extent registered to it, + # this means the pointer is to a Buffer. In this case, we + # just want to delete the allocation statement + if replace_pointer not in pointer_to_extents: + return stmt.body + # Otherwise, rewrite the allocation statement with the new pointer + # and the new extent + replace_type = replace_pointer.type_annotation.element_type.dtype + replace_extents = pointer_to_extents[replace_pointer] + return tvm.tir.Allocate( + replace_pointer, replace_type, replace_extents, stmt.condition, stmt.body + ) + return None + + def _post_transform(stmt): + # Replace operators with call_externs + result = _replace_operator(stmt) + # Remove operators that don't need compiling + result = result or _remove_no_compile(stmt) + # Replace necessary pointers that were removed in the previous step + return result or _replace_pointers(stmt) + + def _ftransform(f, mod, ctx): + tvm.tir.stmt_functor.post_order_visit(f.body, _resolve_pointers) + return f.with_body( + tvm.tir.stmt_functor.ir_transform( + f.body, None, _post_transform, ["tir.AttrStmt", "tir.Allocate"] + ) + ) + + return tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.replace_operators" + ) + + +def DivideConstants(const_dict): + """This pass rewrites the IR and constant dict such that all constant + accesses are at 0 offset and full length (i.e. they read the whole buffer). 
+ + Where necessary, new constants are created in order to ensure the rewrite + can take place. As an example, if a convolution is tiled along the channels + axis, the accesses to the weights will need to be offset. This pass will + create new constants consisting of 'slices' of the weights so each tile + of the compute can access one of these 'slices'. + + The purpose of this pass is to transform the IR into a form we can apply + constant encoding to (which will compress weights and encode biases).""" + buffer_to_const = {} # type: ignore + new_buffers = [] + new_consts = [] + keep_buffers = set() + new_const_dict = {} + + def _visit(stmt): + new_args = [] + for i, arg in enumerate(stmt.args): + if isinstance(arg, tvm.tir.expr.Load): + # If we're trying to load a buffer that maps to a constant + if arg.buffer_var in buffer_to_const: + const = buffer_to_const[arg.buffer_var] + offset = int(arg.index) + # Note by convention the arg after a constant read is the length of the read + length = int(stmt.args[i + 1]) + # If it's anything other than a full read, create a new buffer + if offset != 0 or len(const) != length: + new_consts.append(const[offset : offset + length]) + new_buffer = tvm.tir.decl_buffer((length,), arg.dtype) + new_buffers.append(new_buffer) + new_args.append(tvm.tir.expr.Load(new_buffer.dtype, new_buffer.data, 0)) + continue + keep_buffers.add(arg.buffer_var) + + new_args.append(arg) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + + def _ftransform(f, mod, ctx): + for i, param in enumerate(f.params): + if i in const_dict: + buffer_to_const[param] = const_dict[i].flatten() + buffer_to_const[f.buffer_map[param].data] = const_dict[i].flatten() + + new_body = tvm.tir.stmt_functor.ir_transform(f.body, _visit, None, ["tir.Call"]) + # Both the params and buffer map need updating for the newly introduced buffers + new_params = [] # type: ignore + new_buffer_map = {} + for i, param in enumerate(f.params): + buffer = f.buffer_map[param] + pointer = buffer.data + if pointer in buffer_to_const: + if pointer not in keep_buffers: + continue + new_const_dict[len(new_params)] = const_dict[i] + new_params.append(param) + new_buffer_map[param] = buffer + + for i, new_buffer in enumerate(new_buffers): + handle = tvm.tir.Var("placeholder", "handle") + new_params.append(handle) + new_buffer_map[handle] = new_buffer + new_const_dict[len(new_params) - 1] = new_consts[i] + + new_f = tvm.tir.PrimFunc(new_params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + return new_f + + def _divide_constants(mod): + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.divide_constants" + ) + new_func = transform_func(mod) + return new_func, new_const_dict + + return _divide_constants + + +def EncodeConstants(const_dict): + """the NPU requires that weights are compressed and bias/scales are 'encoded', both + of which are performed by this pass. + + This pass modifies both the constant dict to contain the post-encoding values of the + constants and the IR to adjust buffer types/sizes/accesses so they align with the + encoded constants. Calls to the Vela API are made to perform the actual compression/ + encoding. 
+ + """ + new_const_dict = {} + buffer_to_const = {} + pointer_to_buffer = {} + rewrite_buffer = {} + rewrite_pointer = {} + accel_type = vela_api.get_target_accel_type() # type: ignore + + def _align_scale_bias(tir_extern_call, bias): + """Align the scale_bias to 16 bytes.""" + value_bytes = bytearray() + value_bytes.extend(bias.tobytes()) + # Align to 16 + remainder = (len(value_bytes)) % 16 + if remainder > 0: + value_bytes.extend(bytearray(16 - remainder)) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _encode_weights(tir_extern_call, weights): + """Encode the weights for a TIR extern call.""" + value_bytes = vela_api.encode_weights(tir_extern_call, weights, accel_type) + value = np.frombuffer(value_bytes, dtype="uint8") + return value + + def _new_buffer(old_buffer, new_value): + """Create a new buffer and add the old buffer and its pointer to the + rewriting maps.""" + new_buffer = tvm.tir.decl_buffer((len(new_value),), str(new_value.dtype)) + pointer_to_buffer[new_buffer.data] = new_buffer + rewrite_buffer[old_buffer] = new_buffer + rewrite_pointer[old_buffer.data] = new_buffer.data + buffer_to_const[new_buffer] = new_value + + def _visit_encode_pre(stmt): + if isinstance(stmt, tvm.tir.Call): + # Handle copies as a special-case by propagating the buffer information + # from the read to the write pointer. + if stmt.args[0] == "ethosu_copy": + read_pointer = stmt.args[1].buffer_var + if read_pointer in pointer_to_buffer: + write_pointer = stmt.args[3].buffer_var + # Assert writing to the base of the write_var (pre-StorageRewrite) + assert stmt.args[3].index == 0 + assert stmt.args[1].index == 0 + pointer_to_buffer[write_pointer] = pointer_to_buffer[read_pointer] + else: + # Encode the weights + weights_pointer = get_weights_pointer(stmt) + if weights_pointer is not None: + assert weights_pointer in pointer_to_buffer + weights_buffer = pointer_to_buffer[weights_pointer] + weights_value = buffer_to_const[weights_buffer] + new_weights_value = _encode_weights(stmt, weights_value) + _new_buffer(weights_buffer, new_weights_value) + # Align the scale_bias to 16 bytes + scale_bias_pointer = get_scale_bias_pointer(stmt) + if scale_bias_pointer is not None: + assert scale_bias_pointer in pointer_to_buffer + scale_bias_buffer = pointer_to_buffer[scale_bias_pointer] + scale_bias_value = buffer_to_const[scale_bias_buffer] + new_scale_bias_value = _align_scale_bias(stmt, scale_bias_value) + _new_buffer(scale_bias_buffer, new_scale_bias_value) + + def _visit_encode_post(stmt): + # Because encoding may change the data type (e.g. bias to uint8) and type information + # is stored in pointer vars, it's necessary to rewrite all the pointers which point + # to encoded data. 
+ if isinstance(stmt, tvm.tir.Allocate): + allocate_pointer = stmt.buffer_var + if allocate_pointer in pointer_to_buffer: + buffer = pointer_to_buffer[allocate_pointer] + if buffer in rewrite_buffer: # If the pointer needs rewriting + # Create a new pointer var with the type of the new buffer + new_buffer = rewrite_buffer[buffer] + storage_type = tvm.ir.PrimType(new_buffer.dtype) + new_pointer = tvm.tir.Var( + allocate_pointer.name, + tvm.ir.PointerType(storage_type, buffer.scope()), + allocate_pointer.span, + ) + # Set the new pointer to resolve to the new buffer + pointer_to_buffer[new_pointer] = new_buffer + # Add the old pointer to the pointer rewriting dict + rewrite_pointer[allocate_pointer] = new_pointer + + def _visit_rewrite(stmt): + if isinstance(stmt, tvm.tir.Call): + # For extern calls, we need to rewrite pairs of arguments corresponding to + # base address load and the length of the load. + new_args = [stmt.args[0]] + for i in range(1, len(stmt.args)): + # If the previous argument was a load, the current should be a length + if isinstance(stmt.args[i - 1], tvm.tir.Load): + load = stmt.args[i - 1] + pointer = load.buffer_var + if pointer in pointer_to_buffer: + new_args.append(np.prod(list(pointer_to_buffer[pointer].shape))) + continue + new_args.append(stmt.args[i]) + + return tvm.tir.Call(stmt.dtype, stmt.op, new_args, stmt.span) + if isinstance(stmt, tvm.tir.Allocate): + # Where a pointer needs rewriting, the allocate for it must be rewritten + allocate_pointer = stmt.buffer_var + if allocate_pointer in pointer_to_buffer: + if pointer_to_buffer[allocate_pointer] in rewrite_buffer: + new_buffer = rewrite_buffer[pointer_to_buffer[allocate_pointer]] + new_pointer = rewrite_pointer[allocate_pointer] + return tvm.tir.Allocate( + new_pointer, + new_buffer.dtype, + new_buffer.shape, + stmt.condition, + stmt.body, + stmt.span, + ) + # The following rewrites would be better expressed by just rewriting the Vars, however + # ir_transform doesn't seem to visit Vars. So instead we do the next best thing and rewrite + # the nodes which contain the Vars. 
+ if isinstance(stmt, tvm.tir.Load): + load_pointer = stmt.buffer_var + if load_pointer in rewrite_pointer: + new_pointer = rewrite_pointer[load_pointer] + element_type = new_pointer.type_annotation.element_type.dtype + return tvm.tir.Load( + element_type, new_pointer, stmt.index, stmt.predicate, stmt.span + ) + if isinstance(stmt, tvm.tir.AttrStmt): + node_pointer = stmt.node + if node_pointer in rewrite_pointer: + return tvm.tir.AttrStmt( + rewrite_pointer[node_pointer], stmt.attr_key, stmt.value, stmt.body, stmt.span + ) + return None + + def _ftransform(f, mod, ctx): + for i, param in enumerate(f.params): + if i in const_dict: + buffer_to_const[f.buffer_map[param]] = const_dict[i].flatten() + pointer_to_buffer[f.buffer_map[param].data] = f.buffer_map[param] + + # First analyse what needs to be rewritten + new_body = tvm.tir.stmt_functor.ir_transform( + f.body, _visit_encode_pre, _visit_encode_post, ["tir.Call", "tir.Allocate"] + ) + # Then perform the rewrites + new_body = tvm.tir.stmt_functor.ir_transform( + f.body, None, _visit_rewrite, ["tir.Call", "tir.Allocate", "tir.Load", "tir.AttrStmt"] + ) + new_buffer_map = {} + # Rewrite the buffer map and const dict to instead use the encoded versions + for i, param in enumerate(f.params): + buffer = f.buffer_map[param] + if buffer in rewrite_buffer: + new_buffer = rewrite_buffer[buffer] + new_buffer_map[param] = new_buffer + new_value = buffer_to_const[new_buffer] + new_const_dict[i] = new_value + elif buffer in buffer_to_const: + new_const_dict[i] = buffer_to_const[buffer] + new_buffer_map[param] = buffer + else: + new_buffer_map[param] = buffer + + new_f = tvm.tir.PrimFunc(f.params, new_body, f.ret_type, new_buffer_map, f.attrs, f.span) + return new_f + + def _encode_constants(mod): + mod, divided_const_dict = DivideConstants(const_dict)(mod) + const_dict.clear() + for key, value in divided_const_dict.items(): + const_dict[key] = value + transform_func = tvm.tir.transform.prim_func_pass( + _ftransform, opt_level=0, name="tir.ethosu.encode_constants" + ) + new_func = transform_func(mod) + return new_func, new_const_dict + + return _encode_constants diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py new file mode 100644 index 000000000000..5d9027bf2078 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/scheduler.py @@ -0,0 +1,277 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Different schedulers for Arm(R) Ethos(TM)-U NPU""" +import tvm + + +def schedule(te_graph, const_dict, cascader=None): + """Schedule a TE graph for NPU compilation. + + Parameters + ---------- + te_graph + The TE graph to schedule. 
+    const_dict : dict of int to numpy.ndarray
+        The constant dictionary.
+    cascader : callable, optional
+        A cascading function to apply optimizing scheduling
+        to the graph.
+
+    Returns
+    -------
+    s : tvm.te.Schedule
+        The completed schedule for the graph.
+
+    """
+    s = tvm.te.create_schedule([t.op for t in te_graph.outputs])
+    if cascader:
+        cascader(te_graph, const_dict, s)
+    inline_no_ops(te_graph, s)
+    schedule_pragmas(s)
+    schedule_cache_reads(s)
+    return s
+
+
+def tile_nd(s, tensor, tile):
+    """Scheduling utility to perform N-dimensional tiling.
+
+    Parameters
+    ----------
+    s : tvm.te.Schedule
+        The schedule to apply the tiling to.
+    tensor : tvm.te.Tensor
+        The tensor to apply the tiling to.
+    tile : tuple
+        The N-dimensional tile size.
+
+    Returns
+    -------
+    outer_indices : list of tvm.tir.IterVar
+        The outer iteration variables.
+    inner_indices : list of tvm.tir.IterVar
+        The inner iteration variables.
+
+    """
+    outer_indices = []
+    inner_indices = []
+    for i, size in enumerate(tile):
+        outer, inner = s[tensor].split(tensor.op.axis[i], size)
+        outer_indices.append(outer)
+        inner_indices.append(inner)
+
+    s[tensor].reorder(*outer_indices, *inner_indices)
+    return outer_indices, inner_indices
+
+
+def total_cascader(stripe_size):
+    """A demo/test cascader which tries to cascade every op in the graph together.
+
+    The desired output stripe size should be specified. Note this only works
+    for single output graphs.
+
+    Parameters
+    ----------
+    stripe_size : tuple
+        The output stripe size.
+
+    Returns
+    -------
+    func : callable
+        The cascading function.
+
+    """
+
+    def _cascader(te_graph, const_dict, sch):
+        scheduled = set()
+
+        def _visit(tensor, stage, ax):
+            if tensor not in scheduled and isinstance(tensor.op, tvm.te.ComputeOp):
+                sch[tensor].compute_at(stage, ax)
+                scheduled.add(tensor)
+                for input_tensor in tensor.op.input_tensors:
+                    _visit(input_tensor, stage, ax)
+
+        assert len(te_graph.outputs) == 1
+        out = te_graph.outputs[0]
+        oi, _ = tile_nd(sch, out, stripe_size)
+        for ax in oi:
+            sch[out].unroll(ax)
+        for input_tensor in out.op.input_tensors:
+            _visit(input_tensor, sch[out], oi[-1])
+
+    return _cascader
+
+
+def copy_constants():
+    """A simple planner which copies all constant data from FLASH -> SRAM.
+
+    Returns
+    -------
+    planner : callable
+        The planning function.
+    """
+
+    def _planner(te_graph, const_dict, sch):
+        planned = set()  # type: ignore
+
+        def _visit(tensor, reader):
+            if tensor not in planned:
+                planned.add(tensor)
+                if isinstance(tensor.op, tvm.te.PlaceholderOp):
+                    index = list(te_graph.inputs).index(tensor)
+                    if index in const_dict:
+                        sch.cache_read(tensor, "global", [reader])
+
+                elif isinstance(tensor.op, tvm.te.ComputeOp):
+                    for input_tensor in tensor.op.input_tensors:
+                        _visit(input_tensor, tensor)
+
+        for output_tensor in te_graph.outputs:
+            _visit(output_tensor, None)
+
+    return _planner
+
+
+def schedule_pragmas(sch):
+    """Add pragmas to the operators that require them.
+
+    This adds the pragmas used for codegen to the NPU ops.
+    They are taken directly from the TE compute op's attributes.
+    Modifies the schedule in-place.
+
+    Parameters
+    ----------
+    sch : tvm.te.Schedule
+        The schedule.
+ + """ + + def _add_pragmas(stage, ax): + if "op" in [attr for attr, val in stage.op.attrs.items()]: + stage.pragma(ax, "op", stage.op.attrs["op"]) + for attr, val in stage.op.attrs.items(): + if attr != "op": + stage.pragma(ax, str(attr), val) + + for stage in sch.stages: + if ( + isinstance(stage.op, tvm.te.ComputeOp) + and len(stage.op.axis) + len(stage.op.reduce_axis) > 0 + ): + # The logic ensures the pragmas are assigned to the inner tiling loops + # rather than the outer ones (which end up getting unrolled). + num_inner_loops = len(stage.op.axis) + len(stage.op.reduce_axis) + ax = stage.leaf_iter_vars[-num_inner_loops] + _add_pragmas(stage, ax) + + +def schedule_cache_reads(sch): + """Schedule cache reads that have been introduced. + + There are two things we need to happen to cache_read stages. They should be tagged + with the 'ethosu_copy' pragma and have all their axes fused to make them 1D. + + Parameters + ---------- + sch : tvm.te.Schedule + The schedule. + + """ + + def _detect_cache_read(stage): + # Try and detect cache_reads by checking if the compute op is identity + if isinstance(stage.op, tvm.te.ComputeOp): + op = stage.op + if "ethosu" in op.name: + return False + axes = op.axis + if len(op.input_tensors) == 1: + tensor = op.input_tensors[0] + try: + identity_op = tensor(*axes) + except ValueError: + return False + if tvm.tir.analysis.expr_deep_equal(identity_op, op.body[0]): + return True + return False + + for stage in sch.stages: + if _detect_cache_read(stage): + fax = stage.fuse(*stage.op.axis) + stage.pragma(fax, "op", "ethosu_copy") + + +def inline_no_ops(te_graph, sch): + """Inline 'no-ops' - operations that in principle do nothing. + + Modifies the schedule in-place. For now we inline reshape and + strided slice - more could be added. + + Parameters + ---------- + te_graph + The TE graph. + sch : tvm.te.Schedule + The schedule. 
+ + """ + no_ops = {"T_reshape", "T_strided_slice"} + scheduled = set() + + def _visit(tensor): + if tensor not in scheduled and isinstance(tensor.op, tvm.te.ComputeOp): + if tensor.op.name in no_ops: + sch[tensor].compute_inline() + scheduled.add(tensor) + for input_tensor in tensor.op.input_tensors: + _visit(input_tensor) + + for out in te_graph.outputs: + _visit(out) + + +class Convolution2DCompute: + """A helper class to manipulate the series of compute ops that make up a 2D convolution.""" + + def __init__(self, read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write): + self.read = read + self.convert_to_nhwc = convert_to_nhwc + self.pad = pad + self.conv2d = conv2d + self.convert_to_nhcwb16 = convert_to_nhcwb16 + self.write = write + + @classmethod + def from_output(cls, out): + write = out + convert_to_nhcwb16 = write.op.input_tensors[0] + conv2d = convert_to_nhcwb16.op.input_tensors[0] + pad = conv2d.op.input_tensors[0] + convert_to_nhwc = pad.op.input_tensors[0] + read = convert_to_nhwc.op.input_tensors[0] + return cls(read, convert_to_nhwc, pad, conv2d, convert_to_nhcwb16, write) + + def split(self, sch, axis, val): + outer, inner = sch[self.write].split(self.write.op.axis[axis], val) + sch[self.write].reorder( + outer, *[ax for ax in self.write.op.axis if ax != self.write.op.axis[axis]], inner + ) + sch[self.write].unroll(outer) + g = sch.create_group(outputs=self.convert_to_nhcwb16, inputs=self.read, include_inputs=True) + g.compute_at(sch[self.write], outer) + return outer diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/spec.py b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py new file mode 100644 index 000000000000..3ecbcd5f3cdc --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/spec.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
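# A usage sketch for the schedulers above (an illustration, not part of the
# patch): hooks such as copy_constants() and total_cascader() are handed to
# lower_to_tir() as the optional `cascader`, where schedule() applies them
# before the TE graph is lowered to TIR. `npu_func` is assumed to be a Relay
# function containing only NPU-supported operators, and the (1, 4, 4, 16)
# stripe shape is an arbitrary example for an NHWC output.
from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir
from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants, total_cascader

# Copy constant data from FLASH to SRAM through cache_read stages.
mod, const_dict = lower_to_tir(npu_func, cascader=copy_constants())
# Alternatively, cascade the whole graph into (1, 4, 4, 16) output stripes.
mod, const_dict = lower_to_tir(npu_func, cascader=total_cascader((1, 4, 4, 16)))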
+"""The TIR serialization specification for Arm(R) Ethos(TM)-U NPU.""" +from typing import Union +from typing import get_type_hints +from inspect import isclass + +import tvm +from tvm.relay.backend.contrib.ethosu import util + + +def create_serial_object(serialized_type, deserialized_elements): + """ + This function will create serialized type that is one of the subclasses + of tvm.relay.backend.contrib.ethosu.tir.spec.SerializableFormat + + Parameters + ---------- + serialized_type : a subclass type of SerializableFormat + + deserialized_elements : list + The list of arguments that needs to packed to create SerializableFormat objects + + Returns + ------- + The constructed object of type serialized_type + """ + + def _create_serial_object(internal_serialized_type, read_element_idx=0): + """The internal function that increments the read_element_idx + when creating nested serial objects""" + arg_len = util.get_arg_count(internal_serialized_type.__init__) - 1 + serial_init_types = get_type_hints(internal_serialized_type.__init__) + serial_init_arg_names = list(serial_init_types.keys()) + serial_init_args = [] + assert arg_len == len(serial_init_arg_names) + for si_arg_name in serial_init_arg_names: + si_arg_type = serial_init_types[si_arg_name] + if isclass(si_arg_type) and issubclass(si_arg_type, SerializableFormat): + sia, read_element_idx = _create_serial_object(si_arg_type, read_element_idx) + serial_init_args.append(sia) + else: + serial_init_args.append(deserialized_elements[read_element_idx]) + read_element_idx += 1 + return internal_serialized_type(*serial_init_args), read_element_idx + + # Just return the primary serial object + return _create_serial_object(serialized_type)[0] + + +class SerializableFormat: + """Base class to retrieve arguments on a predefined ordering""" + + def __iter__(self): + # Note class attribute definition order is preserved - see PEP 520 + for name in self.__dict__: + value = self.__getattribute__(name) + if isinstance(value, SerializableFormat): + yield from list(value) + else: + yield value + + def __getitem__(self, index): + # Note class attribute definition order is preserved - see PEP 520 + name = list(self.__dict__.keys())[index] + return self.__getattribute__(name) + + +class SerialFeatureMap(SerializableFormat): + """Specialization class to retrieve arguments of a Feature Map + (similiar to NpuFeatureMap of Vela) on a predefined ordering""" + + def __init__( + self, + data_type: str, + height: int, + width: int, + channels: int, + tile_height_0: int, + tile_height_1: int, + tile_width_0: int, + tile_address_0: tvm.tir.expr.Load, + tile_address_1: Union[tvm.tir.expr.Load, int], + tile_address_2: Union[tvm.tir.expr.Load, int], + tile_address_3: Union[tvm.tir.expr.Load, int], + scale: float, + zero_point: int, + layout: str, + stride_h: int, + stride_w: int, + stride_c: int, + ): + self.data_type = data_type + self.height = height + self.width = width + self.channels = channels + self.tile_height_0 = tile_height_0 + self.tile_height_1 = tile_height_1 + self.tile_width_0 = tile_width_0 + self.tile_address_0 = tile_address_0 + self.tile_address_1 = tile_address_1 + self.tile_address_2 = tile_address_2 + self.tile_address_3 = tile_address_3 + self.scale = scale + self.zero_point = zero_point + self.layout = layout + self.stride_h = stride_h + self.stride_w = stride_w + self.stride_c = stride_c + + +class SerialKernel(SerializableFormat): + """Specialization class to retrieve arguments of a Kernel + (similiar to NpuKernel of Vela) on a predefined 
ordering""" + + def __init__( + self, + width: int, + height: int, + stride_w: int, + stride_h: int, + dilation_w: int, + dilation_h: int, + ): + self.width = width + self.height = height + self.stride_w = stride_w + self.stride_h = stride_h + self.dilation_w = dilation_w + self.dilation_h = dilation_h + + +class SerialAddressRange(SerializableFormat): + """Specialization class to retrieve arguments of a AddressRange + (similiar to NpuAddressRange of Vela) on a predefined ordering""" + + def __init__(self, address: tvm.tir.expr.Load, length: int): + self.address = address + self.length = length + + +class SerialPadding(SerializableFormat): + """Specialization class to retrieve arguments of a Padding + (similiar to NpuPadding of Vela) on a predefined ordering""" + + def __init__(self, top: int, left: int, bottom: int, right: int): + self.top = top + self.left = left + self.bottom = bottom + self.right = right + + +class SerialActivation(SerializableFormat): + """Specialization class to retrieve arguments of a Activation + (similiar to NpuActivation of Vela) on a predefined ordering""" + + def __init__(self, op: str, clip_min: int, clip_max: int): + self.op = op + self.clip_min = clip_min + self.clip_max = clip_max + + +class Serial2DConvolution(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.conv2d tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + kernel: SerialKernel, + weight: SerialAddressRange, + weight_zero_point: int, + scale_bias: SerialAddressRange, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.kernel = kernel + self.weight = weight + self.weight_zero_point = weight_zero_point + self.scale_bias = scale_bias + self.padding = padding + self.activation = activation + self.upscale = upscale + + +class Serial2DDepthwise(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.depthwise2d tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + kernel: SerialKernel, + weight: SerialAddressRange, + weight_zero_point: int, + scale_bias: SerialAddressRange, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.kernel = kernel + self.weight = weight + self.weight_zero_point = weight_zero_point + self.scale_bias = scale_bias + self.padding = padding + self.activation = activation + self.upscale = upscale + + +class SerialCopy(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.copy tir extern call on a predefined ordering""" + + def __init__( + self, read_address: tvm.tir.expr.Load, length: int, write_address: tvm.tir.expr.Load + ): + self.read_address = read_address + self.length = length + self.write_address = write_address + + +class SerialPooling(SerializableFormat): + """Specialization class to retrieve arguments of + a ethosu.pooling tir extern call on a predefined ordering""" + + def __init__( + self, + ifm: SerialFeatureMap, + ofm: SerialFeatureMap, + pooling_type: str, + pool_shape: SerialKernel, + padding: SerialPadding, + activation: SerialActivation, + upscale: str, + ): + self.ifm = ifm + self.ofm = ofm + self.pooling_type = pooling_type + self.pool_shape = pool_shape + self.padding = padding + self.activation = activation + self.upscale = upscale diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/transform.py 
b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py new file mode 100644 index 000000000000..0403ce2c7e8f --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/transform.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-argument +"""Extract information from the transform operators in TIR.""" +import tvm +from .spec import SerialCopy +from .utils import get_base_address, get_op_attrs + + +def get_copy_params(stmt, producers, consumers): + """Get the parameters necessary to construct a call_extern for a copy. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement of a copy loop nest. + producers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that produces their values. + consumers : dict of tvm.tir.Var to tvm.tir.AttrStmt + A dictionary to associate pointers with the loop nest + that consumes their values. + + Returns + ------- + SerialCopy + The parameters needed to construct a copy. + tvm.tir.Var + The output pointer of the copy operation. + + """ + _, body = get_op_attrs(stmt) + length = body.extent + write_store = body.body + write_base = get_base_address(write_store.index) + read_load = body.body.value + read_base = get_base_address(read_load.index) + dtype = body.body.value.dtype + return ( + SerialCopy( + read_address=tvm.tir.expr.Load(dtype, read_load.buffer_var, read_base), + length=length, + write_address=tvm.tir.expr.Load(dtype, write_store.buffer_var, write_base), + ), + write_store.buffer_var, + None, + ) diff --git a/python/tvm/relay/backend/contrib/ethosu/tir/utils.py b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py new file mode 100644 index 000000000000..7d6fd3bf82d8 --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir/utils.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
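A minimal, self-contained sketch (editorial note, not part of the patch) of the round trip that spec.py's SerializableFormat and create_serial_object implement: iterating a serial object flattens its fields, including nested serial objects, in declaration order, and create_serial_object rebuilds the same nesting from a flat element list by walking the __init__ type hints. The stand-in classes below use plain ints where the real spec classes hold TIR expressions.

    from inspect import isclass
    from typing import get_type_hints


    class SerializableFormat:
        """Flatten attribute values in definition order, recursing into nested serial objects."""

        def __iter__(self):
            for name in self.__dict__:
                value = getattr(self, name)
                if isinstance(value, SerializableFormat):
                    yield from value
                else:
                    yield value


    class Kernel(SerializableFormat):
        def __init__(self, width: int, height: int):
            self.width = width
            self.height = height


    class Conv(SerializableFormat):
        def __init__(self, kernel: Kernel, padding: int):  # nested serial object
            self.kernel = kernel
            self.padding = padding


    def create_serial_object(serialized_type, elements):
        """Rebuild a (possibly nested) serial object from a flat element list."""

        def _create(tp, idx=0):
            args = []
            for arg_type in get_type_hints(tp.__init__).values():
                if isclass(arg_type) and issubclass(arg_type, SerializableFormat):
                    nested, idx = _create(arg_type, idx)
                    args.append(nested)
                else:
                    args.append(elements[idx])
                    idx += 1
            return tp(*args), idx

        return _create(serialized_type)[0]


    flat = list(Conv(Kernel(3, 3), padding=1))   # -> [3, 3, 1]
    rebuilt = create_serial_object(Conv, flat)   # nesting restored from the flat list
    assert list(rebuilt) == flat

This is the property the translator below depends on: a tir extern call's flat argument list can be packed back into Serial2DConvolution, SerialCopy, and the other serial classes without any per-operator unpacking code.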
+# pylint: disable=invalid-name +"""Helper utility functions used by the TIR compiler""" +import tvm +from tvm import arith + + +# TODO(@mbaret): Formalise this with a specification +def get_weights_pointer(tir_extern_call): + """Get the weights pointer from a NPU extern call if it exists""" + if tir_extern_call.args[0] == "ethosu_conv2d": + return tir_extern_call.args[41].buffer_var + return None + + +# TODO(@mbaret): Formalise this with a specification +def get_scale_bias_pointer(tir_extern_call): + """Get the scale_bias pointer from a NPU extern call if it exists""" + if tir_extern_call.args[0] == "ethosu_conv2d": + return tir_extern_call.args[44].buffer_var + return None + + +def get_op_attrs(stmt): + """Iterate through nested attribute statements accumulating their values + in an attribute dictionary. + + The "pragma_" prefix is removed as a convenience. + + Parameters + ---------- + stmt : tvm.tir.AttrStmt + The outermost attribute statement to begin from. + + Returns + ------- + attrs : dict of str to object + The attribute dictionary. + stmt : tvm.tir.Stmt + The body after having collected the final attribute statement. + + """ + attrs = {} + while isinstance(stmt, tvm.tir.AttrStmt): + # The pragma scheduler inserts "pragma_" before all the + # attr names, this is annoying so we get rid of it + attr = stmt.attr_key.replace("pragma_", "") + attrs[attr] = stmt.value + stmt = stmt.body + + return attrs, stmt + + +def get_strides(index, stride_vars): + """Get the striding of given vars in an indexing expression. + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index expression where the stride vars are present. + stride_vars : list of tvm.tir.Var + The vars to determine the striding of. + + Returns + ------- + strides : list of int + The striding of each stride var in the index expression + in the same order as the stride vars were given. + + """ + strides = [1] * len(stride_vars) + dmap = {} + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Var): + dmap[stmt] = arith.IntervalSet(0, 0) + + tvm.tir.stmt_functor.post_order_visit(index, _visit) + min_value = int(arith.Analyzer().int_set(index, dmap).min_value) + for var in dmap: + if var in stride_vars: + # NOTE: Doing this using a [0, 1] interval doesn't work reliably + # Seems to be a bug + dmap[var] = arith.IntervalSet(1, 1) + max_value = int(arith.Analyzer().int_set(index, dmap).max_value) + stride = int(max_value - min_value) + i = stride_vars.index(var) + strides[i] = stride + dmap[var] = arith.IntervalSet(0, 0) + + return strides + + +def get_base_address(index): + """Determine the first (base) address accessed by an index expression. + + Parameters + ---------- + index : tvm.tir.PrimExpr + The index expression to determine the base address of. + + Returns + ------- + base_address: + The first address accessed by the index expression. + + """ + dmap = {} + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Var): + dmap[stmt] = arith.IntervalSet(0, 0) + + tvm.tir.stmt_functor.post_order_visit(index, _visit) + base_address = int(arith.Analyzer().int_set(index, dmap).min_value) + return base_address + + +def get_outer_loops(stmt, layout): + """Get the outer loops of an operator. + + Parameters + ---------- + stmt : tvm.tir.For + The outermost loop. + layout : str + The output tensor layout (NHWC or NHCWB16). + + Returns + ------- + n : tvm.tir.For + The batch loop. + h : tvm.tir.For + The height loop. + w : tvm.tir.For + The width loop. + c : tvm.tir.For + The channels loop. + b : tvm.tir.For + The brick loop. 
None for NHWC + body : tvm.tir.Stmt + The inner body of the loops. + + """ + if layout == "NHWC": + n = stmt + h = n.body + w = h.body + c = w.body + b = tvm.tir.For(tvm.tir.Var("b", "int32"), 0, 0, 0, tvm.tir.Evaluate(0)) + return n, h, w, c, b, c.body + if layout == "NHCWB16": + n = stmt + h = n.body + cb = h.body + w = cb.body + b = w.body + return n, h, w, cb, b, b.body + return None + + +def get_loads(stmt): + """Get the Load statements. + + Parameters + ---------- + stmt : tvm.tir.Stmt + The statement to get the Loads from. + + Returns + ------- + loads : list of tvm.tir.Load + The Loads found. + + """ + loads = [] + + def _visit(s): + if isinstance(s, tvm.tir.Load): + loads.append(s) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return loads + + +def get_stores(stmt): + """Get the Store statements. + + Parameters + ---------- + stmt : tvm.tir.Stmt + The statement to get the Stores from. + + Returns + ------- + stores : list of tvm.tir.Store + The Stores found. + + """ + stores = [] + + def _visit(s): + if isinstance(s, tvm.tir.Store): + stores.append(s) + + tvm.tir.stmt_functor.post_order_visit(stmt, _visit) + return stores diff --git a/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py new file mode 100644 index 000000000000..ce9abcbd683d --- /dev/null +++ b/python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This source will contain code to convert TIR, as produced by +the Relay to TIR compilation process, to Vela API calls to +generate command stream. +""" +from typing import NamedTuple +from enum import auto +from enum import Enum +import numpy as np # type: ignore +import ethosu.vela.api as vapi # type: ignore + +import tvm +from tvm.relay.backend.contrib.ethosu import vela_api +from tvm.relay.backend.contrib.ethosu.tir import spec + + +class BufferType(Enum): + """The buffer types the codegen supports""" + + constant = auto() + input_or_output = auto() + scratch = auto() + input = auto() + output = auto() + + +class BufferInfo(NamedTuple): + """A data structure to hold metadata of the buffer""" + + # If the buffer holds constants, the values will contain that otherwise None + values: np.ndarray + shape: tvm.ir.container.Array + dtype: np.dtype + btype: BufferType + + +def extract_buffer_info(mod, param_dict): + """ + This function is to read the tvm.IRModule that + contains Relay to TIR compiled IRModule. Thereafter, + this will extract the buffer information as the shape + and constant data (if any). + + Parameters + ---------- + mod : tvm.IRModule + The NPU TIR IRModule. 
+ param_dict : dict + A dictionary containing param idx --> const numpy.NDArray + Returns + ------- + dict + a dictionary of buffer names --> BufferInfo + """ + buffer_info = dict() + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + for idx, const_data in param_dict.items(): + param = primfunc.params[idx] + buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + const_data, const_data.shape, const_data.dtype, BufferType.constant + ) + + for param in primfunc.params: + if primfunc.buffer_map[param].data not in buffer_info.keys(): + buffer_info[primfunc.buffer_map[param].data] = BufferInfo( + None, + primfunc.buffer_map[param].shape, + primfunc.buffer_map[param].dtype, + BufferType.input_or_output, + ) + + def populate_allocate_buffer_info(stmt): + if isinstance(stmt, tvm.tir.stmt.Allocate): + allocate = stmt + buffer_info[allocate.buffer_var] = BufferInfo( + None, + allocate.extents, + allocate.dtype, + BufferType.scratch, + ) + + tvm.tir.stmt_functor.post_order_visit(primfunc.body, populate_allocate_buffer_info) + + return buffer_info + + +def _convert_clip_bounds(npu_op): + """ + This function will convert the min and max value + of clip activations to non quantized floats as + expected by the API. + Parameters + ---------- + npu_op : ethosu.vela.api.NpuBlockOperation + """ + clip_min_quant = npu_op.activation.min + clip_max_quant = npu_op.activation.max + clip_min_actual = ( + clip_min_quant - npu_op.ofm.quantization.zero_point + ) * npu_op.ofm.quantization.scale_f32 + clip_max_actual = ( + clip_max_quant - npu_op.ofm.quantization.zero_point + ) * npu_op.ofm.quantization.scale_f32 + npu_op.activation.min = clip_min_actual + npu_op.activation.max = clip_max_actual + + +def translate_ethosu_conv2d(tir_extern_call): + """This function will translate a tir extern_call + as produced by Relay to TIR compilation. + Parameters + ---------- + tir_extern_call : tvm.tir.Call + This should be an tir external call that has a agreed upon ordering + for TIR Compiler. See Serial2DConvolution in + tvm/relay/backend/contrib/ethosu/tir/spec.py for the ordering. 
+ + Returns + ------- + ethosu.vela.api.NpuConv2DOperation + The vela object containing the params of ethosu_conv2d + weights_zero_point : int + The zero point of the weights + """ + # We skip the first element as it is the extern_call function name + serial_object = spec.create_serial_object(spec.Serial2DConvolution, tir_extern_call.args[1:]) + return _create_npu_op_conv2d(serial_object) + + +def _create_npu_op_conv2d(serial_2d_convolution): + """This is a helper function to capture a list + of arguments to create Vela NpuConv2DOperation object + """ + npu_conv2d_op = vapi.NpuConv2DOperation() + npu_conv2d_op.ifm = _create_npu_feature_map(serial_2d_convolution.ifm) + npu_conv2d_op.ofm = _create_npu_feature_map(serial_2d_convolution.ofm) + npu_conv2d_op.kernel = _create_npu_kernel(serial_2d_convolution.kernel) + npu_conv2d_op.weights = [_create_npu_address_range(serial_2d_convolution.weight)] + weights_zero_point = np.int64(serial_2d_convolution.weight_zero_point.value) + npu_conv2d_op.biases = [_create_npu_address_range(serial_2d_convolution.scale_bias)] + npu_conv2d_op.padding = _create_npu_padding(serial_2d_convolution.padding) + + npu_conv2d_op.activation = _create_npu_activation(serial_2d_convolution.activation) + if ( + npu_conv2d_op.activation + and npu_conv2d_op.activation.op_type == vapi.NpuActivationOp.NONE_OR_RELU + ): + _convert_clip_bounds(npu_conv2d_op) + + npu_conv2d_op.upscale = _create_npu_resampling_mode(serial_2d_convolution.upscale) + target_accel_type = vela_api.get_target_accel_type() # type: ignore + block_config = vela_api.get_optimal_block_config(npu_conv2d_op, target_accel_type) + npu_conv2d_op.block_config = block_config + weights_shape_ohwi = [ + npu_conv2d_op.ofm.shape.depth, + npu_conv2d_op.kernel.height, + npu_conv2d_op.kernel.width, + npu_conv2d_op.ifm.shape.depth, + ] + npu_conv2d_op.block_traversal = vela_api.calculate_block_traversal_mode( + is_depthwise=False, + weights_shape_ohwi=weights_shape_ohwi, + ifm_bitdepth=npu_conv2d_op.ifm.data_type.size_in_bits(), + ) + return npu_conv2d_op, weights_zero_point + + +def _create_npu_feature_map(serial_feature_map): + """This is a helper function to capture a list + of arguments to create Vela NpuFeatureMap object + """ + layout_map = {"NHWC": vapi.NpuLayout.NHWC, "NHCWB16": vapi.NpuLayout.NHCWB16} + datatype_map = { + "uint8": vapi.NpuDataType.UINT8, + "int8": vapi.NpuDataType.INT8, + "uint16": vapi.NpuDataType.UINT16, + "int16": vapi.NpuDataType.INT16, + "int32": vapi.NpuDataType.INT32, + } + layout = str(serial_feature_map.layout.value) + data_type = str(serial_feature_map.data_type.value) + assert layout in layout_map.keys() + assert data_type in datatype_map.keys() + nfm = vapi.NpuFeatureMap() + nfm.data_type = datatype_map[data_type] + nfm.shape = vapi.NpuShape3D( + int(serial_feature_map.height.value), + int(serial_feature_map.width.value), + int(serial_feature_map.channels.value), + ) + nfm.tiles = vapi.NpuTileBox( + int(serial_feature_map.tile_height_0.value), + int(serial_feature_map.tile_height_1.value), + int(serial_feature_map.tile_width_0.value), + [ + serial_feature_map.tile_address_0, + serial_feature_map.tile_address_1, + serial_feature_map.tile_address_2, + serial_feature_map.tile_address_3, + ], + ) + nfm.quantization = _create_npu_quantization( + serial_feature_map.scale, serial_feature_map.zero_point + ) + nfm.layout = layout_map[layout] + nfm.strides = vapi.NpuShape3D( + int(serial_feature_map.stride_h.value), + int(serial_feature_map.stride_w.value), + 
int(serial_feature_map.stride_c.value), + ) + return nfm + + +def _create_npu_kernel(serial_kernel): + """This is a helper function to capture a list + of arguments to create Vela NpuKernel object + """ + nknl = vapi.NpuKernel( + w=int(serial_kernel.width.value), + h=int(serial_kernel.height.value), + stride_x=int(serial_kernel.stride_w.value), + stride_y=int(serial_kernel.stride_h.value), + dilation_x=int(serial_kernel.dilation_w.value), + dilation_y=int(serial_kernel.dilation_h.value), + ) + return nknl + + +def _create_npu_address_range(serial_address_range): + """This is a helper function to capture a list + of arguments to create Vela NpuAddressRange object + """ + addr_range = vapi.NpuAddressRange( + # region will be updated later + region=0, + address=serial_address_range.address, + length=int(serial_address_range.length.value), + ) + return addr_range + + +def _create_npu_quantization( + scale, + zero_point, +): + """This is a helper function to capture a list + of arguments to create Vela NpuQuantization object + """ + # Scale could be an ndarray if per-channel quantization is available + if not isinstance(scale, tvm.tir.expr.Load): + if isinstance(scale.value, float): + scale = np.single(scale.value) + else: + assert isinstance(scale.value.value, float) + scale = np.single(scale.value.value) + q_params = vapi.NpuQuantization(scale_f32=scale, zero_point=zero_point.value) + return q_params + + +def _create_npu_weights_zero_point( + zero_point, +): + """This is a helper function to capture the weights zero point""" + return zero_point.value + + +def _create_npu_padding(serial_padding): + """This is a helper function to capture a list + of arguments to create Vela NpuPadding object""" + padding = vapi.NpuPadding( + top=int(serial_padding.top.value), + left=int(serial_padding.left.value), + bottom=int(serial_padding.bottom.value), + right=int(serial_padding.right.value), + ) + return padding + + +def _create_npu_activation(serial_activation): + """This is a helper function to capture a list + of arguments to create Vela NpuActivation object""" + if serial_activation.op == "NONE": + return None + if ( + serial_activation.op == "CLIP" + and serial_activation.clip_min == 0 + and serial_activation.clip_max == 0 + ): + return None + op_map = { + "CLIP": vapi.NpuActivationOp.NONE_OR_RELU, + "TANH": vapi.NpuActivationOp.TANH, + "SIGMOID": vapi.NpuActivationOp.SIGMOID, + } + op = str(serial_activation.op.value) + assert op in op_map.keys() + act_op = vapi.NpuActivation(op_map[op]) + act_op.min = int(serial_activation.clip_min.value) + act_op.max = int(serial_activation.clip_max.value) + return act_op + + +def _create_npu_resampling_mode( + mode, +): + """This is a helper function to capture a list + of arguments to create Vela NpuResamplingMode object""" + mode_map = { + "NONE": vapi.NpuResamplingMode.NONE, + "NEAREST": vapi.NpuResamplingMode.NEAREST, + "TRANSPOSE": vapi.NpuResamplingMode.TRANSPOSE, + } + mode = str(mode.value) + assert mode in mode_map.keys() + return mode_map[mode] diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py index e9d89d33e6f0..0919d3fe7a5f 100644 --- a/python/tvm/relay/backend/contrib/ethosu/util.py +++ b/python/tvm/relay/backend/contrib/ethosu/util.py @@ -21,6 +21,7 @@ Refer to the description inside such functions """ +from inspect import signature from enum import Enum from typing import Union, Tuple, Dict, Optional import numpy as np # type: ignore @@ -138,6 +139,12 @@ def round_up(a: int, b: int) -> 
int: return ((a + b - 1) // b) * b +def get_accelerator_config(): + """Get the variant of the accelerator to compile for""" + compiler_attrs = tvm.get_global_func("relay.ext.ethosu.get_compiler_attrs")() + return compiler_attrs.accelerator_config + + # pylint: disable=unused-argument def partition_for_ethosu( mod: tvm.ir.IRModule, params: Optional[Dict[str, tvm.runtime.NDArray]] = None, **opts @@ -173,6 +180,13 @@ def partition_for_ethosu( return mod +def get_arg_count(func): + """Helper function to get the number of + arguments in a python function""" + sig = signature(func) + return len(sig.parameters) + + def get_dim_value(layout: str, dim: int): """This is a helper function to retrieve the value of the dimension given the shape and the layout diff --git a/python/tvm/relay/backend/contrib/ethosu/vela_api.py b/python/tvm/relay/backend/contrib/ethosu/vela_api.py index 72ae18123b3d..5009c3157c77 100644 --- a/python/tvm/relay/backend/contrib/ethosu/vela_api.py +++ b/python/tvm/relay/backend/contrib/ethosu/vela_api.py @@ -28,6 +28,7 @@ from ethosu.vela import api as vapi # type: ignore from tvm.relay.backend.contrib.ethosu import util # type: ignore +from tvm.relay.backend.contrib.ethosu import tir_to_cs_translator as tirtocs # pylint: disable=invalid-name logger = logging.getLogger("Ethos-U") @@ -111,6 +112,53 @@ def _get_optimal_block_config(all_valid_block_configs: List[vapi.NpuShape3D]) -> return max_area_depth_block_configs[0] +def encode_weights(tir_extern_call, values, accel_type): + """This is an API function to compress weights by passing + a tir_extern_call to NPU Convolution operation and values. + + Parameters + ---------- + tir_extern_call : tvm.tir.Call + tir_extern_call to NPU Convolution operation + values : numpy.ndarray + The constant flattened weight data in OHWI layout + accel_type : ethosu.vela.api.NpuAccelerator + The NPU accelerator variant + + Returns + ------- + bytearray + Compressed weights + """ + supported_ops = ["ethosu_conv2d"] + op = str(tir_extern_call.args[0].value) + assert op in supported_ops + npu_op, weights_zero_point = tirtocs.translate_ethosu_conv2d(tir_extern_call) + block_config = get_optimal_block_config(npu_op, accel_type) + # The weight layout is assumed to be flat OHWI, always. + assert len(values.shape) == 1 + shape_ohwi = ( + npu_op.ofm.shape.depth, + npu_op.kernel.height, + npu_op.kernel.width, + npu_op.ifm.shape.depth, + ) + assert values.size == np.prod(shape_ohwi) + values = np.reshape(values, shape_ohwi) + return compress_weights( + weights=values, + weights_zp=weights_zero_point, + # The weight layout is assumed to be OHWI, always. 
+ weights_layout="OHWI", + ifm_bitdepth=npu_op.ifm.data_type.size_in_bits(), + block_depth=block_config.depth, + dilation=(npu_op.kernel.dilation_x, npu_op.kernel.dilation_y), + accel_type=accel_type, + # TODO(@manupa-arm): change this when we support depthwise + is_depthwise=False, + ) + + def compress_weights( weights: np.ndarray, weights_zp: int, @@ -308,3 +356,17 @@ def _calculate_hw_bias_scales( hw_bias_scales = [_quantize_scale(bs) for bs in bias_scales] return hw_bias_scales + + +def get_target_accel_type(): + """This is a helper function to convert cli accelerator type str argument + to NpuAccelerator""" + npu_accel_str_map = { + "ethos-u55-256": vapi.NpuAccelerator.Ethos_U55_256, + "ethos-u55-128": vapi.NpuAccelerator.Ethos_U55_128, + "ethos-u55-64": vapi.NpuAccelerator.Ethos_U55_64, + "ethos-u55-32": vapi.NpuAccelerator.Ethos_U55_32, + } + accel_type_str = util.get_accelerator_config() + assert accel_type_str in npu_accel_str_map.keys(), f"{accel_type_str} is not supported" + return npu_accel_str_map[accel_type_str] diff --git a/src/relay/backend/contrib/ethosu/compiler_attrs.cc b/src/relay/backend/contrib/ethosu/compiler_attrs.cc new file mode 100644 index 000000000000..6a87d11d5d6a --- /dev/null +++ b/src/relay/backend/contrib/ethosu/compiler_attrs.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../../op/make_op.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! \brief Attributes to store the compiler options for Arm(R) Ethos(TM)-U NPU. 
*/ +struct EthosUCompilerConfigNode : public tvm::AttrsNode { + String accelerator_config; + + TVM_DECLARE_ATTRS(EthosUCompilerConfigNode, "ext.attrs.EthosUCompilerConfigNode") { + TVM_ATTR_FIELD(accelerator_config) + .describe( + "The class of Arm(R) Ethos(TM)-U NPU; possible values = {ethos-u55-32, ethos-u55-64, " + "ethos-u55-128, ethos-u55-256}") + .set_default("ethos-u55-256"); + } +}; + +class EthosUCompilerConfig : public Attrs { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(EthosUCompilerConfig, Attrs, EthosUCompilerConfigNode); +}; + +TVM_REGISTER_NODE_TYPE(EthosUCompilerConfigNode); +TVM_REGISTER_PASS_CONFIG_OPTION("relay.ext.ethosu.options", EthosUCompilerConfig); + +auto GetCompilerAttrs() { + auto ctx = transform::PassContext::Current(); + auto cfg = ctx->GetConfig("relay.ext.ethosu.options"); + if (!cfg.defined()) { + cfg = AttrsWithDefaultValues(); + } + return cfg; +} +TVM_REGISTER_GLOBAL("relay.ext.ethosu.get_compiler_attrs").set_body_typed(GetCompilerAttrs); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/ethosu/to_te_graph.cc b/src/relay/backend/contrib/ethosu/to_te_graph.cc new file mode 100644 index 000000000000..9646c39da089 --- /dev/null +++ b/src/relay/backend/contrib/ethosu/to_te_graph.cc @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/backend/contrib/ethosu/to_te_graph.cc + * \brief Lower a Relay function to a TE graph. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../compile_engine.h" +#include "../../utils.h" + +namespace tvm { +namespace relay { +namespace contrib { +namespace ethosu { + +/*! \brief Node container to represent a Tensor Expression graph. */ +class TEGraphNode : public Object { + public: + /* \brief The inputs to the graph */ + tvm::Array inputs; + /* \brief The outputs to the graph */ + tvm::Array outputs; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("inputs", &inputs); + v->Visit("outputs", &outputs); + } + + static constexpr const char* _type_key = "relay.TEGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object); +}; + +class TEGraph : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode); +}; + +TVM_REGISTER_NODE_TYPE(TEGraphNode); + +Array GetShape(const Array& shape) { + // for now, we always use int32 shape when possible + // even if the result of shape inference becomes int64. 
+ Array res; + for (IndexExpr val : shape) { + const int64_t* pval = tir::as_const_int(val); + if (pval != nullptr) { +#ifndef TVM_INDEX_DEFAULT_I64 + ICHECK_LE(pval[0], std::numeric_limits::max()); + ICHECK_GE(pval[0], std::numeric_limits::min()); + res.push_back(IntImm(DataType::Int(32), *pval)); +#else + res.push_back(val); +#endif // TVM_INDEX_DEFAULT_I64 + } else if (val->IsInstance()) { + res.push_back(val.as()->ToVar()); + } else { + res.push_back(val); + } + } + return res; +} + +class RelayToTE : public backend::MemoizedExprTranslator> { + public: + RelayToTE() = default; + + TEGraph Lower(const Function& prim_func) { + auto graph_node = make_object(); + for (Var param : prim_func->params) { + Array inputs; + if (const auto* ttype = param->checked_type().as()) { + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + graph_node->inputs.push_back(tensor); + inputs.push_back(tensor); + } else { + // flatten tuple of tensor type. + const auto* tuple_type = param->type_as(); + for (Type field : tuple_type->fields) { + const auto* ttype = field.as(); + ICHECK(ttype != nullptr); + tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype); + graph_node->inputs.push_back(tensor); + inputs.push_back(tensor); + } + } + memo_[param] = inputs; + } + graph_node->outputs = this->VisitExpr(prim_func->body); + return TEGraph(graph_node); + } + + Array VisitExpr_(const VarNode* op) final { + LOG(FATAL) << "Free variable " << op->name_hint(); + return {}; + } + + Array VisitExpr_(const ConstantNode* op) final { + using tir::make_const; + ICHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = DataType(op->data->dtype); + auto value = te::compute( + {}, + [&](const Array&) { + if (dtype == DataType::Int(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Int(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(32)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Float(64)) { + return make_const(dtype, static_cast(data)[0]); + } else if (dtype == DataType::Bool()) { + return make_const(dtype, static_cast(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::PrimExpr(); + } + }, + "compile_engine_const", topi::kBroadcast); + return {value}; + } + + Array VisitExpr_(const CallNode* call_node) final { + static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call"); + ICHECK(flower_call) << "relay.backend.lower_call is not registered."; + + Array inputs; + int count_tuple = 0; + for (Expr arg : call_node->args) { + if (arg->checked_type().as()) { + ++count_tuple; + } + for (te::Tensor tensor : VisitExpr(arg)) { + inputs.push_back(tensor); + } + } + if (count_tuple) { + ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input"; + } + + ICHECK(call_node->op.as()) << "Primitive function only allows call into primitive ops"; + Op op = Downcast(call_node->op); + + Array outputs; + LoweredOutput lowered_out = + (*flower_call)(GetRef(call_node), inputs, tvm::Target("llvm")); + outputs = lowered_out->outputs; + + if (outputs.size() != 1) { + const auto* tuple_type = call_node->checked_type().as(); + ICHECK(tuple_type) << "Expect output to be a tuple type"; + ICHECK_EQ(tuple_type->fields.size(), outputs.size()); + } + return outputs; + } + + Array VisitExpr_(const FunctionNode* op) final { + LOG(FATAL) << "Do not support sub function"; + return Array(); + } + + 
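+  // Let bindings: memoize the tensors computed for the bound value so that
+  // references to the variable in the body reuse them instead of re-lowering.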
Array VisitExpr_(const LetNode* op) final { + Array val = VisitExpr(op->value); + ICHECK(!memo_.count(op->var)); + memo_[op->var] = val; + return VisitExpr(op->body); + } + + Array VisitExpr_(const TupleNode* op) final { + Array fields; + for (Expr field : op->fields) { + ICHECK(field->checked_type().as()) << "Only allow Tuple of Tensor"; + Array res = VisitExpr(field); + ICHECK_EQ(res.size(), 1); + fields.push_back(res[0]); + } + return fields; + } + + Array VisitExpr_(const TupleGetItemNode* op) final { + const auto* tuple_type = op->tuple->type_as(); + Array tuple = VisitExpr(op->tuple); + ICHECK_EQ(tuple_type->fields.size(), tuple.size()); + ICHECK_GE(op->index, 0); + ICHECK_LT(static_cast(op->index), tuple.size()); + return {tuple[op->index]}; + } +}; + +TVM_REGISTER_GLOBAL("relay.backend.contrib.ethosu.LowerToTE") + .set_body_typed([](Function prim_func) { return RelayToTE().Lower(prim_func); }); + +} // namespace ethosu +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py new file mode 100644 index 000000000000..fc795c066cb6 --- /dev/null +++ b/tests/python/contrib/test_ethosu/infra.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +This module provides infrastructure to verify the correctness of +the command stream produced. + +Currently it will invoke vela to generate a vela-optimized tflite +in which the command stream is contained as a custom operator. +This class include methods to parse the custom operator to extract +the command stream and perform an equivalency check for single operator +test cases. 
+""" + +import numpy +from enum import IntEnum + +import tvm +from tvm import relay +import tvm.relay.backend.contrib.ethosu.op as ethosu_ops +from tvm.topi.nn.utils import get_pad_tuple + + +class AttachType(IntEnum): + kGroupRoot = 1 + kInline = 2 + kInlinedAlready = 3 + kScope = 4 + kScanUpdate = 5 + + +def generate_weights_data(shape, dtype): + size = 1 + for dim in shape: + size *= dim + return (numpy.arange(size) % 255).reshape(shape).astype(dtype) + + +def get_convolutional_args(call, include_buffers=False, remove_constants=False): + """A method to extract the arguments from conv2d or depthwise2d extern call.""" + args = call.args + conv_args = [] + remove_indices = [0] + + if remove_constants: + remove_indices += [41, 42, 44, 45] + + for i, arg in enumerate(args): + if i in remove_indices: + continue + elif isinstance(arg, tvm.tir.expr.IntImm) or isinstance(arg, tvm.tir.expr.FloatImm): + conv_args.append(arg.value) + elif isinstance(arg, tvm.tir.expr.Load) and not include_buffers: + conv_args.append(arg.index) + else: + conv_args.append(arg) + + return conv_args + + +def make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation="NONE", + ifm_layout="NHWC", + ofm_layout="NHWC", + weight_dtype="int8", +): + # conv params + weight_shape = (ofm_channels, kernel_shape[0], kernel_shape[1], ifm_channels) + padding = get_pad_tuple(padding, kernel_shape) + + scale_bias_data = generate_weights_data((weight_shape[0], 10), "uint8") + scale_bias = relay.const(scale_bias_data, dtype="uint8") + weight_data = generate_weights_data(weight_shape, "int8") + weight = relay.const(weight_data, dtype=weight_dtype) + conv = ethosu_ops.ethosu_conv2d( + ifm, + weight, + scale_bias, + lut=relay.const([], dtype="int8"), + ifm_scale=0.5, + ifm_zero_point=10, + weight_zero_point=12, + ofm_scale=0.25, + ofm_zero_point=14, + kernel_shape=kernel_shape, + ofm_channels=ofm_channels, + strides=strides, + padding=padding, + dilation=dilation, + activation=activation, + clip_min=10 if activation == "CLIP" else 0, + clip_max=100 if activation == "CLIP" else 0, + upscale="NONE", + ifm_layout=ifm_layout, + ofm_layout=ofm_layout, + ) + return conv diff --git a/tests/python/contrib/test_ethosu/test_attr_passing.py b/tests/python/contrib/test_ethosu/test_attr_passing.py new file mode 100644 index 000000000000..a2fbe1888d2a --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_attr_passing.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu import util + + +def test_compiler_attr(): + config = { + "accelerator_config": "ethos-u55-32", + } + with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}): + with tvm.target.Target("c -device=micro_dev"): + assert util.get_accelerator_config() == config["accelerator_config"] + + +def test_compiler_attr_default(): + default_config = { + "accelerator_config": "ethos-u55-256", + } + with tvm.transform.PassContext(opt_level=3): + with tvm.target.Target("c -device=micro_dev"): + assert util.get_accelerator_config() == default_config["accelerator_config"] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_compiler.py b/tests/python/contrib/test_ethosu/test_compiler.py new file mode 100644 index 000000000000..4df6311a230c --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_compiler.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir + + +def test_lower_to_tir(): + data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8") + weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8") + p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32") + conv = relay.nn.conv2d( + data, + weight, + kernel_size=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + out_dtype="int32", + ) + multiply = relay.multiply(relay.const(-22, dtype="int32"), p2) + tile = relay.tile(multiply, reps=(1, 1, 1, 1001)) + subtract = relay.subtract(conv, tile) + func = subtract + expr = relay.Function(relay.analysis.free_vars(func), func) + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + lower_to_tir(mod["main"]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py new file mode 100644 index 000000000000..0e546ae2fd24 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_encode_constants.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
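As an editorial aside, the accelerator_config option exercised by test_compiler_attr above is the same value vela_api.get_target_accel_type (added earlier in this patch) reads and maps onto an ethosu.vela NpuAccelerator; a sketch, assuming ethosu.vela is installed:

    import tvm
    from tvm.relay.backend.contrib.ethosu import vela_api

    config = {"accelerator_config": "ethos-u55-128"}
    with tvm.transform.PassContext(opt_level=3, config={"relay.ext.ethosu.options": config}):
        with tvm.target.Target("c -device=micro_dev"):
            # Resolves to ethosu.vela.api.NpuAccelerator.Ethos_U55_128
            accel = vela_api.get_target_accel_type()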
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm import tir +from tvm import script +from tvm import relay +from tvm.script import ty +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute + +from infra import make_ethosu_conv2d + + +# fmt: off +@tvm.script.tir +class WeightStreamOnly: + def main(placeholder: ty.handle, ethosu_write: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_7, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_5, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_9 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_3, [112], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_1, [128], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([128], "uint8", "global") + placeholder_d_global = tir.allocate([32], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_6.data, 0), 128, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 128, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 32, tir.load("uint8", placeholder_d_global, 
0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_7.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 112, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_9.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 112, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_weight_stream_only(): + def _planner(te_graph, const_dict, sch): + weights = te_graph.inputs[1] + bias = te_graph.inputs[2] + out = te_graph.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 2) + cache_weights = sch.cache_read(weights, "global", [conv_compute.conv2d]) + cache_bias = sch.cache_read(bias, "global", [conv_compute.conv2d]) + sch[cache_weights].compute_at(sch[out], co) + sch[cache_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + ifm, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func, cascader=_planner) + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = WeightStreamOnly() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = {2: 128, 3: 32, 4: 112, 5: 32, 6: 112, 7: 32, 8: 112, 9: 32} + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +# fmt: off +@tvm.script.tir +class DirectReadOnly: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: 
ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_3, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([4096], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 592, 12, tir.load("uint8", buffer_2.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 160, 12, tir.load("uint8", buffer_3.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_direct_read_only(): + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + conv2 = make_ethosu_conv2d( + conv1, + 16, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = DirectReadOnly() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = {1: 592, 2: 160, 3: 160, 4: 80} + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +# fmt: off +@tvm.script.tir +class MixedRead: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, placeholder_5: ty.handle, placeholder_6: ty.handle, placeholder_7: ty.handle, placeholder_8: ty.handle, placeholder_9: ty.handle, placeholder_10: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_7, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_5, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_3, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + 
buffer_3 = tir.match_buffer(placeholder_4, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_4 = tir.match_buffer(placeholder_9, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_5 = tir.match_buffer(placeholder_6, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_11 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_6 = tir.match_buffer(placeholder_1, [592], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_7 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_8 = tir.match_buffer(placeholder_8, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_9 = tir.match_buffer(placeholder_10, [32], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([4096], "int8", "global") + placeholder_global = tir.allocate([80], "uint8", "global") + placeholder_d_global = tir.allocate([32], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_11.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 16, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_6.data, 0), 592, 12, tir.load("uint8", buffer_7.data, 0), 160, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_2.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_3.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_5.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 2), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_8.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", 
ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 4), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_4.data, 0), 80, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_9.data, 0), 32, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 16, 16, 0, 16, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 16, 1, "int8", 16, 16, 2, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 6), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 80, 12, tir.load("uint8", placeholder_d_global, 0), 32, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_mixed_read(): + def _planner(te_graph, const_dict, sch): + weight = te_graph.inputs[4] + scale_bias = te_graph.inputs[5] + out = te_graph.outputs[0] + conv_compute = Convolution2DCompute.from_output(out) + co = conv_compute.split(sch, 3, 2) + cache_weight = sch.cache_read(weight, "global", [conv_compute.conv2d]) + cache_scale_bias = sch.cache_read(scale_bias, "global", [conv_compute.conv2d]) + sch[cache_weight].compute_at(sch[out], co) + sch[cache_scale_bias].compute_at(sch[out], co) + + def _get_func(): + ifm = relay.var("ifm", shape=(1, 16, 16, 32), dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + 32, + 16, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + conv2 = make_ethosu_conv2d( + conv1, + 16, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, consts = lower_to_tir(func, cascader=_planner) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = MixedRead() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + reference_const_sizes = { + 1: 592, + 2: 160, + 4: 80, + 5: 32, + 6: 80, + 7: 32, + 8: 80, + 9: 32, + 10: 80, + 11: 32, + } + test_const_sizes = {} + for key, value in consts.items(): + test_const_sizes[key] = len(value) + + assert reference_const_sizes == test_const_sizes + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_extract_constants.py b/tests/python/contrib/test_ethosu/test_extract_constants.py new file mode 100644 index 000000000000..98094d8a4ed4 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_extract_constants.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import extract_constants + +import numpy as np + + +def test_extract_constants_single(): + def _get_func(): + var_input = relay.var("data", shape=(10, 10), dtype="uint8") + const_data = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_input = relay.const(const_data, dtype="uint8") + out = relay.add(var_input, const_input) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func, const_input + + def _expected(): + var_input1 = relay.var("data", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("p1", shape=(10, 10), dtype="uint8") + out = relay.add(var_input1, var_input2) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func, const = _get_func() + new_func, const_dict = extract_constants(func) + assert tvm.ir.structural_equal(new_func, _expected()) + assert 1 in const_dict + assert (const_dict[1] == const.data.asnumpy()).all() + + +def test_extract_constants_multi(): + def _get_func(): + var_input1 = relay.var("data1", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("data2", shape=(10, 10), dtype="uint8") + const_data_1 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_2 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_3 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_data_4 = np.random.uniform(0, 255, (10, 10)).astype("uint8") + const_input_1 = relay.const(const_data_1, dtype="uint8") + const_input_2 = relay.const(const_data_2, dtype="uint8") + const_input_3 = relay.const(const_data_3, dtype="uint8") + const_input_4 = relay.const(const_data_4, dtype="uint8") + out = relay.add(var_input1, var_input2) + out = relay.add(out, const_input_1) + out = relay.add(out, const_input_2) + out = relay.add(out, const_input_3) + out = relay.add(out, const_input_4) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func, [const_input_1, const_input_2, const_input_3, const_input_4] + + def _expected(): + var_input1 = relay.var("data1", shape=(10, 10), dtype="uint8") + var_input2 = relay.var("data2", shape=(10, 10), dtype="uint8") + var_input3 = relay.var("p1", shape=(10, 10), dtype="uint8") + var_input4 = relay.var("p2", shape=(10, 10), dtype="uint8") + var_input5 = relay.var("p3", shape=(10, 10), dtype="uint8") + var_input6 = relay.var("p4", shape=(10, 10), dtype="uint8") + out = relay.add(var_input1, var_input2) + out = relay.add(out, var_input3) + out = relay.add(out, var_input4) + out = relay.add(out, var_input5) + out = relay.add(out, var_input6) + func = relay.Function(relay.analysis.free_vars(out), out) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func, consts = _get_func() + new_func, const_dict = extract_constants(func) + assert tvm.ir.structural_equal(new_func, _expected()) + 
for i, const in enumerate(consts): + assert i + 2 in const_dict + assert (const_dict[i + 2] == consts[i].data.asnumpy()).all() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_lower_to_te.py b/tests/python/contrib/test_ethosu/test_lower_to_te.py new file mode 100644 index 000000000000..cabd68b4e8d2 --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_lower_to_te.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +from tvm import relay +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te +from tvm.relay.backend.contrib.ethosu.tir.scheduler import Convolution2DCompute +import tvm.relay.backend.contrib.ethosu.op as ethosu_ops + + +def test_ethosu_conv2d(): + ifm = relay.var("ifm", shape=(1, 10, 20, 30), dtype="uint8") + weight = relay.var("weight", shape=(40, 3, 3, 30), dtype="uint8") + scale_bias = relay.var("scale_bias", shape=(40, 10), dtype="uint8") + lut = relay.var("lut", shape=(), dtype="uint8") + conv = ethosu_ops.ethosu_conv2d( + ifm, + weight, + scale_bias, + lut, + ifm_scale=0.5, + ifm_zero_point=10, + weight_zero_point=12, + ofm_scale=0.25, + ofm_zero_point=14, + ofm_channels=40, + padding=(1, 1, 1, 1), + kernel_shape=(3, 3), + strides=(1, 1), + dilation=(1, 1), + ) + expr = relay.Function(relay.analysis.free_vars(conv), conv) + mod = tvm.IRModule.from_expr(expr) + mod = relay.transform.InferType()(mod) + lowered = lower_to_te(mod["main"]) + assert len(lowered.outputs) == 1 + assert len(lowered.inputs) == 4 + conv2d_compute = Convolution2DCompute.from_output(lowered.outputs[0]) + assert conv2d_compute.conv2d.name == "ethosu_conv2d" + input_shapes = set() + for inp in lowered.inputs: + input_shapes.add(tuple([x.value for x in inp.shape])) + assert input_shapes == {(40, 10), (1, 10, 20, 30), (40, 3, 3, 30), ()} + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_conv2d.py b/tests/python/contrib/test_ethosu/test_replace_conv2d.py new file mode 100644 index 000000000000..96fe56d1778e --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_replace_conv2d.py @@ -0,0 +1,547 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +import tvm.script +from tvm.script import tir, ty +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import total_cascader +from infra import make_ethosu_conv2d, get_convolutional_args + + +@pytest.mark.parametrize( + "trial", + [ + [(1, 8, 8, 3), 3, 16, (1, 1), (2, 1), (1, 1), (1, 1), "TANH", "NHWC", "NHWC"], + [(1, 8, 8, 3), 3, 16, (1, 1), (0, 0), (1, 1), (1, 1), "NONE", "NHWC", "NHWC"], + [(1, 1, 1, 1), 1, 16, (1, 1), (0, 0), (1, 1), (1, 1), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (1, 2), "SIGMOID", "NHWC", "NHWC"], + [(1, 8, 2, 8, 16), 18, 12, (1, 1), (2, 1), (1, 1), (1, 1), "CLIP", "NHCWB16", "NHWC"], + [(1, 7, 9, 4), 4, 71, (3, 2), (1, 2), (2, 1), (1, 2), "CLIP", "NHWC", "NHCWB16"], + [(1, 4, 12, 9, 16), 182, 67, (2, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 7, 9, 4), 4, 13, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHWC"], + [(1, 7, 9, 4), 4, 71, (3, 2), (1, 2), (2, 1), (2, 2), "CLIP", "NHWC", "NHCWB16"], + [ + (1, 13, 12, 19, 16), + 182, + 67, + (1, 3), + (5, 3), + (2, 1), + (2, 1), + "CLIP", + "NHCWB16", + "NHCWB16", + ], + ], +) +def test_conv2d_single(trial): + def _get_func( + ifm_shape, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv = make_ethosu_conv2d( + ifm, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + # TODO(@mbaret) Fix the tests for these known failures + # These are anticipated to actually be correct, just a testing issue to do with + # equivalent convolutions. 
+ known_failures = [ + [(1, 3, 12, 9, 16), 182, 67, (2, 3), (1, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + [(1, 2, 12, 9, 16), 182, 67, (1, 3), (6, 3), (2, 2), (1, 1), "CLIP", "NHCWB16", "NHCWB16"], + ] + func = _get_func(*trial) + mod, _ = lower_to_tir(func) + data = [] + + def _visit(stmt): + if isinstance(stmt, tvm.tir.Call): + data.append(get_convolutional_args(stmt, remove_constants=True)) + + tvm.tir.stmt_functor.post_order_visit(mod["main"].body, _visit) + ( + ifm_shape, + ifm_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + activation, + ifm_layout, + ofm_layout, + ) = trial + dilated_kernel_h = (kernel_shape[0] - 1) * dilation[0] + 1 + dilated_kernel_w = (kernel_shape[1] - 1) * dilation[1] + 1 + if ifm_layout == "NHWC": + ifm_stride_c = 1 + ifm_stride_w = ifm_shape[3] + ifm_stride_h = ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[2] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + else: + ifm_stride_w = 16 + ifm_stride_c = 16 * ifm_shape[3] + ifm_stride_h = 16 * ifm_shape[2] * ifm_shape[3] + ofm_height = (ifm_shape[1] - dilated_kernel_h + padding[0] + padding[0]) // strides[0] + 1 + ofm_width = (ifm_shape[3] - dilated_kernel_w + padding[1] + padding[1]) // strides[1] + 1 + + if ofm_layout == "NHWC": + ofm_stride_c = 1 + ofm_stride_w = ofm_channels if ofm_width > 1 else 1 + ofm_stride_h = ofm_channels * ofm_width if ofm_height > 1 else 1 + else: + ofm_stride_w = 16 + ofm_stride_c = 16 * ofm_width + ofm_stride_h = 16 * ofm_width * ((ofm_channels - 1) // 16 + 1) + + answer = [ + "int8", + ifm_shape[1], + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + ifm_channels, + ifm_shape[1], + 0, + ifm_shape[2] if ifm_layout == "NHWC" else ifm_shape[3], + 0, + 0, + 0, + 0, + 0.5, + 10, + ifm_layout, + ifm_stride_h, + ifm_stride_w, + ifm_stride_c, + "int8", + ofm_height, + ofm_width, + ofm_channels, + ofm_height, + 0, + ofm_width, + 0, + 0, + 0, + 0, + 0.25, + 14, + ofm_layout, + ofm_stride_h, + ofm_stride_w, + ofm_stride_c, + kernel_shape[1], + kernel_shape[0], + strides[1], + strides[0], + dilation[1], + dilation[0], + 12, + padding[0], + padding[1], + padding[0], + padding[1], + activation, + 10 if activation == "CLIP" else 0, + 100 if activation == "CLIP" else 0, + "NONE", + ] + assert data[0] == answer, data[0] + + +# fmt: off +@tvm.script.tir +class Conv2dDoubleCascade1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_3, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_1, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([1024], "int8", "global") + 
tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 3, 8, 0, 4, tir.load("int8", placeholder_5.data, 12), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 32, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 160, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 4, 32, 8, 0, 4, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 128, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 32), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 304, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_1, [1312], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_3, [2608], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 8, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([1536], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 256), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, 
tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 48), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 3, 1, "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_2.data, 0), 1312, 12, tir.load("uint8", buffer_1.data, 0), 320, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 32, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 8, 8, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 64, 8, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 2608, 12, tir.load("uint8", buffer.data, 0), 80, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade3: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 20, 4, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_3, [1744], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_4, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_2, [320], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_1, [880], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 16, 16, 3], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([2560], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 16, 3, 8, 0, 16, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 32, 8, 0, 8, tir.load("int8", ethosu_write_2, 512), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 2, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 12, 16, 3, 12, 0, 16, tir.load("int8", placeholder_5.data, 192), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + 
tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 10, 8, 32, 10, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 8, 4, 8, 8, 0, 4, tir.load("int8", ethosu_write_1.data, 256), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 16, 3, 4, 0, 16, tir.load("int8", placeholder_5.data, 576), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 48, 3, 1, "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 256, 32, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer_3.data, 0), 880, 12, tir.load("uint8", buffer_2.data, 0), 320, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 4, 8, 32, 4, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 256, 32, 1, "int8", 4, 4, 8, 4, 0, 4, tir.load("int8", ethosu_write_1.data, 512), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 32, 8, 1, 2, 3, 2, 1, 2, 1, tir.load("uint8", buffer.data, 0), 1744, 12, tir.load("uint8", buffer_1.data, 0), 80, 0, 1, 2, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dDoubleCascade4: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, placeholder_3: ty.handle, placeholder_4: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_1, [1456], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [352], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_5 = tir.match_buffer(placeholder, [1, 8, 1, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 2, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_2 = tir.match_buffer(placeholder_4, [272], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_3 = tir.match_buffer(placeholder_3, [11040], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + ethosu_write_2 = tir.allocate([2304], "int8", "global") + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 384), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 384), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 6, 8, 3, 6, 0, 8, tir.load("int8", placeholder_5.data, 256), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 128, 16, 1, "int8", 5, 8, 35, 5, 0, 8, 
tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 384, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 1456, 12, tir.load("uint8", buffer_1.data, 0), 352, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 8, 35, 5, 0, 8, tir.load("int8", ethosu_write_2, 0), 0, 0, 0, tir.float32(0.5), 10, "NHCWB16", 384, 16, 128, "int8", 4, 8, 26, 4, 0, 8, tir.load("int8", ethosu_write_1.data, 1024), 0, 0, 0, tir.float32(0.25), 14, "NHCWB16", 256, 16, 128, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_3.data, 0), 11040, 12, tir.load("uint8", buffer_2.data, 0), 272, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [ + Conv2dDoubleCascade1(), + (1, 8, 8, 3), + 3, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + "NHWC", + (1, 8, 4, 8), + ], + [ + Conv2dDoubleCascade2(), + (1, 8, 8, 3), + 3, + 32, + 8, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + "NHWC", + (1, 4, 8, 8), + ], + [ + Conv2dDoubleCascade3(), + (1, 16, 16, 3), + 3, + 32, + 8, + (3, 2), + (2, 1), + (1, 2), + (1, 2), + "NHWC", + (1, 8, 4, 8), + ], + [ + Conv2dDoubleCascade4(), + (1, 8, 1, 8, 16), + 3, + 35, + 26, + (3, 3), + (1, 1), + (1, 1), + (1, 1), + "NHCWB16", + (1, 4, 2, 8, 16), + ], + ], +) +def test_conv2d_double_cascade(trial): + def _get_func( + ifm_shape, + ifm_channels, + mid_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + layout, + ): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv1 = make_ethosu_conv2d( + ifm, + ifm_channels, + mid_channels, + kernel_shape, + padding, + strides, + dilation, + "NONE", + layout, + layout, + ) + conv2 = make_ethosu_conv2d( + conv1, + mid_channels, + ofm_channels, + kernel_shape, + padding, + strides, + dilation, + "NONE", + layout, + layout, + ) + func = relay.Function(relay.analysis.free_vars(conv2), conv2) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params[:-1]) + mod, _ = lower_to_tir(func, cascader=total_cascader(params[-1])) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# fmt: off +@tvm.script.tir +class Conv2dInlineCopy1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 10, 12, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 8, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 8, 8, 4, 8, 0, 8, tir.load("int8", placeholder_3.data, 120), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 96, 8, 1, "int8", 8, 8, 16, 8, 0, 8, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer.data, 0), 848, 12, tir.load("uint8", buffer_1.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ 
= None + + +@tvm.script.tir +class Conv2dInlineCopy2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 3, 5, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 7, 9, 5], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [656], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 3, 5, 3, 3, 0, 5, tir.load("int8", placeholder_3.data, 146), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 45, 5, 1, "int8", 3, 5, 16, 3, 0, 5, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 80, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 656, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [Conv2dInlineCopy1(), (1, 10, 12, 8), (0, 1, 3, 0), (1, 9, 11, 4)], + [Conv2dInlineCopy2(), (1, 7, 9, 5), (0, 3, 2, 1), (1, 6, 7, 4)], + ], +) +def test_conv2d_inline_copy(trial): + def _get_func(ifm_shape, lower, upper, ofm_channels=16): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + sliced = relay.strided_slice(ifm, lower, upper) + conv = make_ethosu_conv2d( + sliced, upper[3] - lower[3], ofm_channels, (3, 3), (1, 1), (1, 1), (1, 1) + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params) + mod, _ = lower_to_tir(func) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# fmt: off +@tvm.script.tir +class Conv2dInlineReshape1: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [4, 6, 8, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", 
ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape2: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 24, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape3: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [192, 1], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + 
__tvm_meta__ = None + + +@tvm.script.tir +class Conv2dInlineReshape4: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 8, 6, 16], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer = tir.match_buffer(placeholder_2, [160], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [192], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [848], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + # body + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 1, 1, 0, 1, "NONE", 0, 0, "NONE", dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 5, 6, 4, 5, 0, 6, tir.load("int8", placeholder_3.data, 72), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 24, 4, 1, "int8", 4, 6, 16, 4, 0, 6, tir.load("int8", ethosu_write_1.data, 384), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 96, 16, 1, 3, 3, 1, 1, 1, 1, tir.load("uint8", buffer_1.data, 0), 848, 12, tir.load("uint8", buffer.data, 0), 160, 0, 1, 1, 1, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +@pytest.mark.parametrize( + "trial", + [ + [Conv2dInlineReshape1(), (4, 6, 8, 1), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape2(), (1, 4 * 6, 8), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape3(), (4 * 6 * 8, 1), (1, 8, 6, 4), "NHWC"], + [Conv2dInlineReshape4(), (4 * 6 * 8,), (1, 8, 6, 4), "NHWC"], + ], +) +def test_conv2d_inline_reshape(trial): + def _get_func(ifm_shape, reshaped, ifm_layout): + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + ifm_reshaped = relay.reshape(ifm, reshaped) + conv = make_ethosu_conv2d( + ifm_reshaped, reshaped[3], 16, (3, 3), (1, 1), (1, 1), (1, 1), "NONE", ifm_layout + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + reference_mod = trial[0] + params = trial[1:] + func = _get_func(*params) + mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 6, 16))) + script = tvm.script.asscript(mod, True) + mod = tvm.script.from_source(script) + tvm.ir.assert_structural_equal(mod["main"], reference_mod["main"], True) + + +# TODO(@mbaret) Fix this case +@pytest.mark.xfail(raises=TypeError, strict=True) +def test_conv2d_big_pad(): + def _get_func(): + ifm_shape = (1, 2, 2, 8) + ifm = relay.var("ifm", shape=ifm_shape, dtype="int8") + conv = make_ethosu_conv2d(ifm, ifm_shape[3], 16, (1, 1), (7, 7), (1, 1), (1, 1), "NHWC") + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=total_cascader((1, 4, 4, 16))) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_replace_copy.py b/tests/python/contrib/test_ethosu/test_replace_copy.py new file mode 100644 index 000000000000..222dccacc906 --- /dev/null +++ 
b/tests/python/contrib/test_ethosu/test_replace_copy.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +import tvm +import tvm.script +from tvm.script import tir, ty +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_tir +from tvm.relay.backend.contrib.ethosu.tir.scheduler import copy_constants + +from infra import make_ethosu_conv2d + + +# fmt: off +@tvm.script.tir +class ReferenceModule: + def main(placeholder: ty.handle, placeholder_1: ty.handle, placeholder_2: ty.handle, ethosu_write: ty.handle) -> None: + # function attr dict + tir.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}) + buffer = tir.match_buffer(placeholder_2, [80], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + placeholder_3 = tir.match_buffer(placeholder, [1, 16, 16, 32], dtype="int8", elem_offset=0, align=128, offset_factor=1) + buffer_1 = tir.match_buffer(placeholder_1, [304], dtype="uint8", elem_offset=0, align=128, offset_factor=1) + ethosu_write_1 = tir.match_buffer(ethosu_write, [1, 16, 16, 8], dtype="int8", elem_offset=0, align=128, offset_factor=1) + # body + placeholder_global = tir.allocate([304], "uint8", "global") + placeholder_d_global = tir.allocate([80], "uint8", "global") + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer_1.data, 0), 304, tir.load("uint8", placeholder_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_copy", tir.load("uint8", buffer.data, 0), 80, tir.load("uint8", placeholder_d_global, 0), dtype="handle")) + tir.evaluate(tir.call_extern("ethosu_conv2d", "int8", 16, 16, 32, 16, 0, 16, tir.load("int8", placeholder_3.data, 0), 0, 0, 0, tir.float32(0.5), 10, "NHWC", 512, 32, 1, "int8", 16, 16, 8, 16, 0, 16, tir.load("int8", ethosu_write_1.data, 0), 0, 0, 0, tir.float32(0.25), 14, "NHWC", 128, 8, 1, 1, 1, 1, 1, 1, 1, tir.load("uint8", placeholder_global, 0), 304, 12, tir.load("uint8", placeholder_d_global, 0), 80, 0, 0, 0, 0, "NONE", 0, 0, "NONE", dtype="handle")) + __tvm_meta__ = None +# fmt: on + + +def test_copy(): + def _get_func(): + data = relay.var("data", shape=(1, 16, 16, 32), dtype="int8") + conv = make_ethosu_conv2d( + data, + 32, + 8, + (1, 1), + (0, 0), + (1, 1), + (1, 1), + ) + func = relay.Function(relay.analysis.free_vars(conv), conv) + func = run_opt_pass(func, relay.transform.InferType()) + return func + + func = _get_func() + mod, _ = lower_to_tir(func, cascader=copy_constants()) + + script = tvm.script.asscript(mod, True) + test_mod = tvm.script.from_source(script) + reference_mod = ReferenceModule() + tvm.ir.assert_structural_equal(test_mod["main"], reference_mod["main"], True) + + +if __name__ == 
"__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_scheduler.py b/tests/python/contrib/test_ethosu/test_scheduler.py new file mode 100644 index 000000000000..b07f8ea7f48b --- /dev/null +++ b/tests/python/contrib/test_ethosu/test_scheduler.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +pytest.importorskip("ethosu.vela") +from tvm import relay +from tvm.relay.testing import run_opt_pass +from tvm import te, topi +from tvm.relay.backend.contrib.ethosu.tir.scheduler import ( + tile_nd, + schedule_pragmas, + inline_no_ops, + total_cascader, + copy_constants, + schedule_cache_reads, +) +from tvm.relay.backend.contrib.ethosu.tir.compiler import lower_to_te, extract_constants +from infra import AttachType, make_ethosu_conv2d + + +class TestTEGraph: + def __init__(self, inputs, outputs): + self.inputs = inputs + self.outputs = outputs + + +def test_tile_nd(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + out = topi.nn.relu(input) + sch = te.create_schedule([out.op]) + outer_iters, inner_iters = tile_nd(sch, out, (3, 4)) + assert tuple(sch[out].leaf_iter_vars) == (*outer_iters, *inner_iters) + + +def test_schedule_pragmas(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + out = te.compute( + (12, 12), + lambda i, j: input[i, j], + attrs={ + "op": "unity", + "info": 1, + }, + ) + sch = te.create_schedule([out.op]) + sch[out].split(out.op.axis[0], 3) + schedule_pragmas(sch) + iter_var = sch[out].leaf_iter_vars[1] + assert list(sch[out].iter_var_attrs[iter_var].pragma_keys) == ["op", "info"] + assert list(sch[out].iter_var_attrs[iter_var].pragma_values) == ["unity", 1] + + +def test_schedule_pragmas_for_const(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + const = te.compute((), lambda: 2) + add = topi.add(input, const) + sch = te.create_schedule([add.op]) + schedule_pragmas(sch) + + +def test_inline_no_ops(): + input = relay.var("input", shape=(12, 12), dtype="uint8") + slice = relay.strided_slice(input, [0, 0], [6, 6]) + relu1 = relay.nn.relu(slice) + reshape = relay.reshape(relu1, (36,)) + relu2 = relay.nn.relu(reshape) + func = relay.Function(relay.analysis.free_vars(relu2), relu2) + func = run_opt_pass(func, relay.transform.InferType()) + + te_graph = lower_to_te(func) + sch = te.create_schedule([te_graph.outputs[0].op]) + inline_no_ops(te_graph, sch) + reshape_tensor = te_graph.outputs[0].op.input_tensors[0] + slice_tensor = reshape_tensor.op.input_tensors[0].op.input_tensors[0] + assert sch[reshape_tensor].attach_type == AttachType.kInline + assert sch[slice_tensor].attach_type == AttachType.kInline + + +def test_total_cascader(): + input = te.placeholder((12, 12), dtype="uint8", name="input") + relu1 = topi.nn.relu(input) + 
relu2 = topi.nn.relu(relu1) + relu3 = topi.nn.relu(relu2) + sch = te.create_schedule([relu3.op]) + cascader = total_cascader((4, 4)) + cascader(TestTEGraph([input], [relu3]), {}, sch) + assert sch[relu1].attach_type == AttachType.kScope + assert sch[relu2].attach_type == AttachType.kScope + assert sch[relu3].attach_type == AttachType.kGroupRoot + # Check that the attaches are at the correct iter var + assert sch[relu1].attach_ivar == sch[relu3].leaf_iter_vars[1] + assert sch[relu2].attach_ivar == sch[relu3].leaf_iter_vars[1] + + +def test_copy_constants(): + ifm_a = relay.var("IFM_A", shape=(1, 26, 26, 32), dtype="int8") + conv_a = make_ethosu_conv2d(ifm_a, 32, 8, (3, 3), (0, 0), (1, 1), (1, 1)) + conv_b = make_ethosu_conv2d(conv_a, 8, 4, (1, 1), (0, 0), (1, 1), (1, 1)) + func = relay.Function(relay.analysis.free_vars(conv_b), conv_b) + func = run_opt_pass(func, relay.transform.InferType()) + + func, const_dict = extract_constants(func) + te_graph = lower_to_te(func) + + sch = te.create_schedule([te_graph.outputs[0].op]) + planner = copy_constants() + planner(te_graph, const_dict, sch) + assert len(sch.stages) == 21 + assert ".global" in sch.stages[5].op.name + assert ".global" in sch.stages[7].op.name + assert ".global" in sch.stages[15].op.name + assert ".global" in sch.stages[17].op.name + + +def test_schedule_cache_reads(): + a = te.placeholder((12, 12), dtype="uint8", name="a") + b = te.placeholder((12, 12), dtype="uint8", name="b") + add = topi.add(a, b) + sch = te.create_schedule([add.op]) + cr = sch.cache_read(b, "global", [add]) + schedule_cache_reads(sch) + assert len(sch.stages) == 4 + assert len(sch[cr].leaf_iter_vars) == 1 + iv = sch[cr].leaf_iter_vars[0] + assert list(sch[cr].iter_var_attrs[iv].pragma_keys) == ["op"] + assert list(sch[cr].iter_var_attrs[iv].pragma_values) == ["ethosu_copy"] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/contrib/test_ethosu/test_vela_api.py b/tests/python/contrib/test_ethosu/test_vela_api.py index d9b22d10e9e2..a86dd919d5ca 100644 --- a/tests/python/contrib/test_ethosu/test_vela_api.py +++ b/tests/python/contrib/test_ethosu/test_vela_api.py @@ -24,7 +24,9 @@ import tvm from tvm import tir from tvm.script import ty +from tvm.tir import stmt_functor from tvm.relay.backend.contrib.ethosu import vela_api +import tvm.relay.backend.contrib.ethosu.tir_to_cs_translator as tirtocs ACCEL_TYPES = [ vapi.NpuAccelerator.Ethos_U55_256, @@ -451,5 +453,104 @@ def create_mock(test_vec): verify(_test_vec, mock_obj, packed_biases) +def extract_ethosu_conv2d_extern_calls(mod): + """This function will obtain all ethosu_conv2d + calls from a NPU TIR module + + Parameters + ---------- + mod : tvm.IRModule + This is a NPU TIR Module + + Returns + ------- + list + List of tvm.tir.Call objects + that are tir extern calls + for ethosu_conv2d + """ + # There should only be a single function + assert len(mod.functions.items()) == 1 + primfunc = mod.functions.items()[0][1] + + ethosu_conv2d_calls = list() + + def populate_ethosu_conv2d_calls(stmt): + if ( + isinstance(stmt, tvm.tir.Call) + and stmt.op.name == "tir.call_extern" + and stmt.args[0] == "ethosu_conv2d" + ): + ethosu_conv2d_calls.append(stmt) + + stmt_functor.post_order_visit(primfunc.body, populate_ethosu_conv2d_calls) + return ethosu_conv2d_calls + + +@pytest.mark.parametrize( + "accel", + ACCEL_TYPES, +) +def test_encode_weights(accel): + test_vecs = [ + { + # Stimulus + "tir_module": Module1(), + "param_dict": { + 1: np.random.randint(np.iinfo("uint8").min, 
np.iinfo("uint8").max, [48], "uint8"), + 2: np.random.randint(np.iinfo("int32").min, np.iinfo("int32").max, [16], "int32"), + }, + "accel_type": accel, + # Reference outputs + "block_traversal": vapi.NpuBlockTraversal.PART_KERNEL_FIRST, + }, + ] + + def create_mock(test_vec): + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_encode_weights" + ) as mock_enc_w: + with patch( + "tvm.relay.backend.contrib.ethosu.vela_api.vapi.npu_find_block_configs" + ) as mock_blk_cfg: + mock_blk_cfg.return_value = [vapi.NpuShape3D(8, 8, 8)] + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"]) + buffer_info = tirtocs.extract_buffer_info( + test_vec["tir_module"], test_vec["param_dict"] + ) + for ethosu_conv2d_call in ethosu_conv2d_calls: + npu_op, _ = tirtocs.translate_ethosu_conv2d(ethosu_conv2d_call) + weights = buffer_info[npu_op.weights[0].address.buffer_var][0] + vela_api.encode_weights(ethosu_conv2d_call, weights, accel) + return mock_enc_w + + def verify(test_vec, mock_enc_w): + ethosu_conv2d_calls = extract_ethosu_conv2d_extern_calls(test_vec["tir_module"]) + buffer_info = tirtocs.extract_buffer_info(test_vec["tir_module"], test_vec["param_dict"]) + for ethosu_conv2d_call in ethosu_conv2d_calls: + npu_op, w_zero_point = tirtocs.translate_ethosu_conv2d(ethosu_conv2d_call) + weights = buffer_info[npu_op.weights[0].address.buffer_var][0] + + assert mock_enc_w.call_args[1]["accelerator"] == accel + assert ( + mock_enc_w.call_args[1]["weights_volume"].flatten() + == weights.astype(np.int64) - w_zero_point + ).all() + assert mock_enc_w.call_args[1]["dilation_xy"] == ( + npu_op.kernel.dilation_x, + npu_op.kernel.dilation_y, + ) + assert mock_enc_w.call_args[1]["dilation_xy"] == ( + npu_op.kernel.dilation_x, + npu_op.kernel.dilation_y, + ) + assert mock_enc_w.call_args[1]["ifm_bitdepth"] == npu_op.ifm.data_type.size_in_bits() + assert mock_enc_w.call_args[1]["block_traversal"] == test_vec["block_traversal"] + + for _test_vec in test_vecs: + _mock_enc_w = create_mock(_test_vec) + verify(_test_vec, _mock_enc_w) + + if __name__ == "__main__": pytest.main([__file__])