Skip to content

Commit

Permalink
[BYOC][ACL] Depthwise convolution support
Browse files Browse the repository at this point in the history
Added support for depthwise convolution. ACL only supports depth-wise convolution when kernel size is 3x3 and 5x5 and strides are (1, 1) or (2, 2), if this is not the case then fallback to TVM.

Also rework tests to remove non-deterministic trials.

*Compute Library for the Arm Architecture (ACL).
*All credits to Luke Hutton @lhutton1

Change-Id: Ida1f5802a65377b84325edf14a0149242c1af857
  • Loading branch information
lhutton1 authored and d-smirnov committed Jan 8, 2021
1 parent 29da763 commit f02e26e
Show file tree
Hide file tree
Showing 11 changed files with 450 additions and 221 deletions.
8 changes: 5 additions & 3 deletions docs/deploy/arm_compute_lib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
specific language governing permissions and limitations
under the License.
Relay Arm :sup:`®` Compute Library Integration
Relay Arm:sup:`®` Compute Library Integration
==============================================
**Author**: `Luke Hutton <https://github.com/lhutton1>`_

Expand Down Expand Up @@ -195,12 +195,14 @@ Operator support
| | Simple: nn.conv2d |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? |
| | |
| | (only groups = 1 supported) |
| | Normal and depth-wise (when kernel is 3x3 or 5x5 and strides are 1x1 |
| | or 2x2) convolution supported. Grouped convolution is not supported. |
+----------------------+-------------------------------------------------------------------------+
| qnn.conv2d | uint8: |
| | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize |
| | |
| | (only groups = 1 supported) |
| | Normal and depth-wise (when kernel is 3x3 or 5x5 and strides are 1x1 |
| | or 2x2) convolution supported. Grouped convolution is not supported. |
+----------------------+-------------------------------------------------------------------------+
| nn.dense | fp32: |
| | Simple: nn.dense |
Expand Down
109 changes: 105 additions & 4 deletions python/tvm/relay/op/contrib/arm_compute_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@
import numpy as np
import tvm

from tvm._ffi import register_func
from tvm.relay.expr import const
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.testing.temp_op_attr import TempOpAttr

from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr
from .register import register_pattern_table
from ..strategy.generic import is_depthwise_conv2d


def is_arm_compute_runtime_enabled():
Expand Down Expand Up @@ -71,6 +74,61 @@ def partition_for_arm_compute_lib(mod, params=None):
return seq(mod)


@register_func("relay.ext.arm_compute_lib.optimize")
def preprocess_module(mod):
"""
Pre-process a module containing functions ready for ACL codegen. For now we enforce OHWI
kernel layout and fold the transforms away.
Parameters
----------
mod : Module
The module to run passes on.
Returns
-------
preprocessed_mod : The processed module.
"""

def convert_layout_conv2d(conv2d_function):
def convert_conv(attrs, inputs, tinfos, desired_layouts):
new_attrs = dict(attrs)
data_info = tinfos[0]
weight_info = tinfos[1]
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
new_attrs["data_layout"] = desired_data_layout
new_attrs["kernel_layout"] = desired_kernel_layout

if is_depthwise_conv2d(
data_info.shape,
attrs["data_layout"],
weight_info.shape,
attrs["kernel_layout"],
attrs["groups"],
):
dkl = desired_kernel_layout
new_attrs["kernel_layout"] = dkl[3] + dkl[1:3] + dkl[0]
return conv2d_function(*inputs, **new_attrs)

return convert_conv

with TempOpAttr(
"nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
), TempOpAttr(
"qnn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.qnn.op.conv2d)
):
seq = tvm.transform.Sequential(
[
transform.ConvertLayout(
{"nn.conv2d": ["NHWC", "OHWI"], "qnn.conv2d": ["NHWC", "OHWI"]}
),
transform.FoldConstant(),
]
)
preprocessed_mod = seq(mod)
return preprocessed_mod


@register_pattern_table("arm_compute_lib")
def arm_compute_lib_pattern_table():
"""Get the ACL pattern table."""
Expand Down Expand Up @@ -236,8 +294,6 @@ def _func_wrapper(expr):
def conv2d(expr):
"""Check if the external ACL codegen for conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False
if attrs.data_layout != "NHWC":
return False
if attrs.out_dtype != "float32" and attrs.out_dtype != "":
Expand All @@ -248,14 +304,25 @@ def conv2d(expr):
kernel_typ = args[1].checked_type
if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32":
return False
is_depthwise = is_depthwise_conv2d(
data_typ.shape,
attrs["data_layout"],
kernel_typ.shape,
attrs["kernel_layout"],
attrs["groups"],
)
if is_depthwise:
return depthwise_conv2d(attrs, args)
# ACL doesn't support grouped convolution
if attrs.groups != 1 and not is_depthwise:
return False
return True


def qnn_conv2d(expr):
"""Check if the external ACL codegen for qnn.conv2d should be used."""
attrs, args = expr.attrs, expr.args
if attrs.groups != 1:
return False

if attrs.data_layout != "NHWC":
return False
if attrs.out_dtype != "int32" and attrs.out_dtype != "":
Expand All @@ -266,6 +333,40 @@ def qnn_conv2d(expr):
kernel_typ = args[1].checked_type
if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "uint8":
return False
is_depthwise = is_depthwise_conv2d(
data_typ.shape,
attrs["data_layout"],
kernel_typ.shape,
attrs["kernel_layout"],
attrs["groups"],
)
if is_depthwise:
return depthwise_conv2d(attrs, args)
# ACL doesn't support grouped convolution
if attrs.groups != 1 and not is_depthwise:
return False
return True


def depthwise_conv2d(attrs, args):
"""Check if the external ACL codegen for depthwise convolution should be used.
Note
----
Relay does not have a depthwise conv2d operator whilst ACL does. We simply
separate the checks for depthwise for clarity.
"""
kernel_typ = args[1].checked_type
# Only supports 3x3, 5x5 depthwise
if (
kernel_typ.shape[0] not in [3, 5]
or kernel_typ.shape[1] not in [3, 5]
or kernel_typ.shape[0] != kernel_typ.shape[1]
):
return False
# Stride must be (1, 1) or (2, 2)
if (attrs.strides[0], attrs.strides[1]) not in [(1, 1), (2, 2)]:
return False
return True


Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@

import tvm
from tvm import te
import tvm.relay as relay
import tvm.relay.op as op
from tvm.relay import Prelude
from tvm import relay
from tvm.relay import op
from tvm.relay.prelude import Prelude
from tvm.testing import enabled_targets

from . import mlp
Expand Down
48 changes: 22 additions & 26 deletions src/relay/backend/contrib/arm_compute_lib/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/type.h>
#include <tvm/tir/analysis.h>

#include <memory>
#include <string>
Expand Down Expand Up @@ -126,7 +127,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
nodes.activation = current_call;
current_call = current_call->args[0].as<CallNode>();
}
if (backend::IsOp(current_call, "nn.bias_add")) {
if (backend::IsOp(current_call, "add")) {
nodes.bias = current_call;
current_call = current_call->args[0].as<CallNode>();
}
Expand Down Expand Up @@ -154,19 +155,32 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
*/
std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) {
CompositeConvNode nodes = UnpackCompositeConvolution(cn);
std::string name = "nn.conv2d";

const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>();
ICHECK(conv_attr);
ICHECK(conv_attr->kernel_layout == "OHWI")
<< "Kernel layout must be OHWI, has the module been pre-processed correctly?";

std::string name;
std::string name_prefix = "nn";

// Distinguish between normal and depth-wise convolution
if (conv_attr->channels.defined() &&
tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) &&
conv_attr->groups != 1) {
name = "depthwise_conv2d";
ICHECK(conv_attr->kernel_layout == "IHWO")
<< "Kernel layout must be IHWO, has the module been pre-processed correctly?";
} else {
name = "conv2d";
ICHECK(conv_attr->kernel_layout == "OHWI")
<< "Kernel layout must be OHWI, has the module been pre-processed correctly?";
}

// Inputs must be added in the same order they appear in the relay graph.
std::vector<JSONGraphNodeEntry> inputs;
inputs.push_back(VisitExpr(cn->args[0])[0]);
inputs.push_back(VisitExpr(nodes.conv->args[1])[0]);
if (nodes.requantize) {
name = "qnn.conv2d";
name_prefix = "qnn";
inputs.push_back(VisitExpr(nodes.conv->args[2])[0]); // input zero-point
inputs.push_back(VisitExpr(nodes.conv->args[3])[0]); // kernel zero-point
inputs.push_back(VisitExpr(nodes.conv->args[4])[0]); // input scale
Expand All @@ -180,7 +194,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
inputs.push_back(VisitExpr(nodes.requantize->args[4])[0]); // output zero-point
}

auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
auto json_node = std::make_shared<JSONGraphNode>(name_prefix + "." + name, "kernel", inputs, 1);
SetCallNodeAttribute(json_node, nodes.conv);

// Override attributes
Expand Down Expand Up @@ -224,10 +238,11 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
nodes.requantize = current_call;
current_call = current_call->args[0].as<CallNode>();
}
if (backend::IsOp(current_call, "nn.bias_add")) {
if (backend::IsOp(current_call, "add")) {
nodes.bias = current_call;
current_call = current_call->args[0].as<CallNode>();
}

// Enforce a dense node exists at this point during traversal
if (nodes.requantize) {
ICHECK(backend::IsOp(current_call, "qnn.dense"));
Expand Down Expand Up @@ -329,25 +344,6 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer {
}
};

/*!
* \brief Pre-process a module containing functions ready for ACL codegen.
*
* For now we enforce OHWI kernel layout and fold the transforms away.
*
* \param mod The module to be pre-processed.
* \return The processed module.
*/
IRModule PreProcessModule(const IRModule& mod) {
IRModule preprocessed_module;
tvm::Map<String, Array<String>> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}},
{"qnn.conv2d", {"NHWC", "OHWI"}}};
preprocessed_module = transform::ConvertLayout(desired_layouts)(mod);
preprocessed_module = transform::FoldConstant()(preprocessed_module);
return preprocessed_module;
}

TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib.optimize").set_body_typed(PreProcessModule);

/*!
* \brief Create a runtime module for ACL.
*
Expand Down
69 changes: 63 additions & 6 deletions src/runtime/contrib/arm_compute_lib/acl_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <arm_compute/core/Types.h>
#include <arm_compute/runtime/NEON/functions/NEArithmeticAddition.h>
#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h>
#include <arm_compute/runtime/NEON/functions/NEElementwiseOperations.h>
#include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h>
#include <arm_compute/runtime/NEON/functions/NEPoolingLayer.h>
Expand Down Expand Up @@ -131,6 +132,9 @@ class ACLRuntime : public JSONRuntimeBase {
if ("nn.conv2d" == op_name || "qnn.conv2d" == op_name) {
CreateConvolution2DLayer(&layer_, node, mm);
num_pools++;
} else if ("nn.depthwise_conv2d" == op_name || "qnn.depthwise_conv2d" == op_name) {
CreateDepthwiseConvolution2DLayer(&layer_, node, mm);
num_pools++;
} else if ("nn.dense" == op_name || "qnn.dense" == op_name) {
CreateFullyConnectedLayer(&layer_, node, mm);
num_pools++;
Expand Down Expand Up @@ -227,12 +231,7 @@ class ACLRuntime : public JSONRuntimeBase {
arm_compute::ActivationLayerInfo act_info;
if (node.HasAttr("activation_type")) {
std::string activation_type = node.GetAttr<std::vector<std::string>>("activation_type")[0];
if (activation_type == "relu") {
act_info = arm_compute::ActivationLayerInfo(
arm_compute::ActivationLayerInfo::ActivationFunction::RELU);
} else {
LOG(FATAL) << "Unsupported activation function";
}
act_info = MakeACLActivationInfo(activation_type);
}

arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1]));
Expand Down Expand Up @@ -269,6 +268,64 @@ class ACLRuntime : public JSONRuntimeBase {
layer->function = function;
}

/*!
* \brief Create a 2D depthwise convolution layer.
*
* \param layer The ACL layer to build. Containing inputs, outputs and the ACL function.
* \param node The JSON representation of the operator.
* \param mm The ACL conv2d layer can request auxiliary memory from TVM.
*/
void CreateDepthwiseConvolution2DLayer(
CachedLayer* layer, const JSONGraphNode& node,
const std::shared_ptr<arm_compute::MemoryManagerOnDemand>& mm) {
std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
std::vector<std::string> dilation = node.GetAttr<std::vector<std::string>>("dilation");
arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides);

arm_compute::ActivationLayerInfo act_info;
if (node.HasAttr("activation_type")) {
std::string activation_type = node.GetAttr<std::vector<std::string>>("activation_type")[0];
act_info = MakeACLActivationInfo(activation_type);
}

arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1]));

// Collect inputs and outputs, handling both nn.conv2d and qnn.conv2d cases.
std::vector<JSONGraphNodeEntry> inputs = node.GetInputs();
size_t num_inputs = inputs.size();
bool has_bias;
if (node.GetOpName() == "qnn.depthwise_conv2d") {
CHECK(num_inputs >= 8U && num_inputs <= 9U)
<< "Quantized convolution requires 9 inputs with a bias, 8 inputs without.";
has_bias = num_inputs == 9;
layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2]));
layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[1], &inputs[5], &inputs[3]));
if (has_bias) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[6]));
}
layer->outputs.push_back(
MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias]));
} else {
CHECK(num_inputs >= 2U && num_inputs <= 3U)
<< "Convolution requires 3 inputs with a bias, 2 inputs without.";
has_bias = num_inputs == 3;
for (const auto& i : inputs) {
layer->inputs.push_back(MakeACLTensorFromJSONEntry(i));
}
layer->outputs.push_back(MakeACLTensorFromJSONNode(node));
}

// Depth multiplier is the final dimension in acl weights tensor (IWH*M*)
int depth_multiplier = layer->inputs[1].info()->tensor_shape()[3];

auto function = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer>(mm);
function->configure(&layer->inputs[0], &layer->inputs[1],
has_bias ? &layer->inputs[2] : nullptr, &layer->outputs[0], pad_stride_info,
depth_multiplier, act_info, dilation_2d);
layer->function = function;
}

/*!
* \brief Create a fully connected (dense) layer.
*
Expand Down
Loading

0 comments on commit f02e26e

Please sign in to comment.