Skip to content

Commit

Permalink
[onert-micro] Add SVDF cmsis-nn (#13719)
Browse files Browse the repository at this point in the history
This draft adds svdf cmsis-nn kernel.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]
  • Loading branch information
BalyshevArtem authored Aug 23, 2024
1 parent 3d2ffbf commit aa86e3d
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 10 deletions.
12 changes: 12 additions & 0 deletions onert-micro/onert-micro/include/core/OMKernelData.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,18 @@ struct SliceParams
int32_t size[5];
};

struct SVDFQuantParams
{
int32_t input_zero_point;
int32_t output_zero_point;
int32_t activation_state_zero_point;
int32_t effective_scale_1_a;
int effective_scale_1_b;
int32_t effective_scale_2_a;
int effective_scale_2_b;
int rank;
};

} // namespace core
} // namespace onert_micro

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ REGISTER_KERNEL(MAX_POOL_2D, MaxPool2D)
REGISTER_KERNEL(SOFTMAX, Softmax)
#/*REGISTER_KERNEL(SUM, Sum)*/
#/*REGISTER_KERNEL(SELECT_V2, SelectV2)*/
#/*REGISTER_KERNEL(SVDF, SVDF)*/
REGISTER_KERNEL(SVDF, SVDF)
#/*REGISTER_KERNEL(WHILE, While)*/
#/*REGISTER_KERNEL(UNIDIRECTIONAL_SEQUENCE_LSTM, UnidirectionalSequenceLSTM)*/
#/*REGISTER_KERNEL(RESIZE_BILINEAR, ResizeBilinear)*/
Expand Down
125 changes: 125 additions & 0 deletions onert-micro/onert-micro/include/pal/cmsisnn/PALSVDF.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_PAL_SVDF_H
#define ONERT_MICRO_EXECUTE_PAL_SVDF_H

#include "PALSVDFCommon.h"
#include "core/OMRuntimeShape.h"
#include "core/OMKernelData.h"
#include "core/memory/OMMemoryManager.h"

#include <arm_nnfunctions.h>

namespace onert_micro
{
namespace execute
{
namespace pal
{

OMStatus SVDF(const core::SVDFQuantParams &params, const int8_t *input_data,
const int8_t *weights_feature_data, const int8_t *weights_time_data,
const int32_t *bias_data, int8_t *state_data, int8_t *output_data,
const core::OMRuntimeShape &input_shape,
const core::OMRuntimeShape &weights_feature_shape,
const core::OMRuntimeShape &weights_time_shape,
const core::OMRuntimeShape &bias_shape, const core::OMRuntimeShape &output_shape)
{
cmsis_nn_dims input_dims;
input_dims.n = input_shape.dims(0);
input_dims.h = input_shape.dims(1);

cmsis_nn_dims weights_feature_dims;
weights_feature_dims.n = weights_feature_shape.dims(0);
weights_feature_dims.h = weights_feature_shape.dims(1);

cmsis_nn_dims weights_time_dims;
weights_time_dims.n = weights_time_shape.dims(0);
weights_time_dims.h = weights_time_shape.dims(1);

cmsis_nn_dims bias_dims;
bias_dims.n = bias_shape.dims(0);

cmsis_nn_dims state_dims;
state_dims.n = bias_shape.dims(0);
state_dims.h = bias_shape.dims(1);

cmsis_nn_dims output_dims;
output_dims.n = output_shape.dims(0);
output_dims.h = output_shape.dims(1);

cmsis_nn_svdf_params svdf_params;
svdf_params.rank = params.rank;
svdf_params.input_offset = params.input_zero_point;
svdf_params.output_offset = params.output_zero_point;

svdf_params.input_activation.min = INT16_MIN;
svdf_params.input_activation.max = INT16_MAX;

svdf_params.output_activation.min = INT8_MIN;
svdf_params.output_activation.max = INT8_MAX;

cmsis_nn_per_tensor_quant_params in_quant_params;
in_quant_params.multiplier = params.effective_scale_1_a;
in_quant_params.shift = params.effective_scale_1_b;

cmsis_nn_per_tensor_quant_params out_quant_params;
out_quant_params.multiplier = params.effective_scale_2_a;
out_quant_params.shift = params.effective_scale_2_b;

const int batch_size = input_shape.dims(0);
const int input_size = input_shape.dims(1);
const int num_filters = weights_feature_shape.dims(0);
const int num_units = num_filters / params.rank;

uint8_t *scratch_tensor_data;
OMStatus status = core::memory::OMMemoryManager::allocateMemory(
batch_size * num_filters * sizeof(int32_t), &scratch_tensor_data);
assert(status == Ok);
if (status != Ok)
return status;

uint8_t *scratch_output_tensor_data;
status = core::memory::OMMemoryManager::allocateMemory(batch_size * num_units * sizeof(int32_t),
&scratch_output_tensor_data);
assert(status == Ok);
if (status != Ok)
return status;

cmsis_nn_context scratch_ctx;
scratch_ctx.buf = reinterpret_cast<int32_t *>(scratch_tensor_data);

cmsis_nn_context scratch_output_ctx;
scratch_output_ctx.buf = reinterpret_cast<int32_t *>(scratch_output_tensor_data);

arm_svdf_s8(&scratch_ctx, &scratch_output_ctx, &svdf_params, &in_quant_params, &out_quant_params,
&input_dims, input_data, &state_dims, state_data, &weights_feature_dims,
weights_feature_data, &weights_time_dims, weights_time_data, &bias_dims, bias_data,
&output_dims, output_data);

core::memory::OMMemoryManager::deallocateMemory(scratch_tensor_data);
core::memory::OMMemoryManager::deallocateMemory(scratch_output_tensor_data);

return Ok;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_SVDF_H
20 changes: 20 additions & 0 deletions onert-micro/onert-micro/include/pal/mcu/PALSVDF.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,24 @@

#include "PALSVDFCommon.h"

namespace onert_micro
{
namespace execute
{
namespace pal
{

OMStatus SVDF(const core::SVDFQuantParams &, const int8_t *, const int8_t *, const int8_t *,
const int32_t *, int8_t *, int8_t *, const core::OMRuntimeShape &,
const core::OMRuntimeShape &, const core::OMRuntimeShape &,
const core::OMRuntimeShape &, const core::OMRuntimeShape &)
{
// TODO: support it
return UnsupportedType;
}

} // namespace pal
} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_PAL_SVDF_H
70 changes: 61 additions & 9 deletions onert-micro/onert-micro/src/execute/kernels/SVDF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "core/memory/OMMemoryManager.h"

#include "execute/OMKernelExecutionBuilder.h"
#include "execute/OMUtils.h"
#include "execute/OMRuntimeKernel.h"

#include "PALSVDF.h"
Expand All @@ -41,6 +42,38 @@ constexpr int inputActivationStateTensorIdx =
4; // This is a variable tensor, and will be modified by this op.
constexpr int outputTensorIdx = 0;

void prepareQuantParams(core::SVDFQuantParams &params, const circle::Tensor *input,
const circle::Tensor *weights_feature, const circle::Tensor *weights_time,
const circle::Tensor *activation_state, const circle::Tensor *output)
{
assert(input->quantization() != nullptr);
assert(output->quantization() != nullptr);
assert(weights_feature->quantization() != nullptr);
assert(weights_time->quantization() != nullptr);
assert(activation_state->quantization() != nullptr);

// Write zero points
params.input_zero_point =
static_cast<int32_t>(input->quantization()->zero_point()->operator[](0));
params.output_zero_point =
static_cast<int32_t>(output->quantization()->zero_point()->operator[](0));
params.activation_state_zero_point =
static_cast<int32_t>(activation_state->quantization()->zero_point()->operator[](0));

// Calculate effective scales
const float effective_scale_1 = (input->quantization()->scale()->operator[](0) *
weights_feature->quantization()->scale()->operator[](0)) /
(activation_state->quantization()->scale()->operator[](0));
const float effective_scale_2 = (activation_state->quantization()->scale()->operator[](0) *
weights_time->quantization()->scale()->operator[](0)) /
(output->quantization()->scale()->operator[](0));

execute::quantizeMultiplier(effective_scale_1, &params.effective_scale_1_a,
&params.effective_scale_1_b);
execute::quantizeMultiplier(effective_scale_2, &params.effective_scale_2_a,
&params.effective_scale_2_b);
}

} // namespace

OMStatus onert_micro::execute::execute_kernel_CircleSVDF(const OMExecuteArgs &execute_args)
Expand Down Expand Up @@ -130,30 +163,50 @@ OMStatus onert_micro::execute::execute_kernel_CircleSVDF(const OMExecuteArgs &ex
return status;

std::memset(activation_state_data, 0, activation_state_size);
// Temporary buffer
uint8_t *scratch_buffer;
status = core::memory::OMMemoryManager::allocateMemory(
batch_size * num_filters * sizeof(core::OMDataType(output->type())), &scratch_buffer);

assert(status == Ok);
if (status != Ok)
return status;

switch (input->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{
// Temporary buffer
uint8_t *scratch_buffer;
status = core::memory::OMMemoryManager::allocateMemory(
batch_size * num_filters * sizeof(core::OMDataType(output->type())), &scratch_buffer);

assert(status == Ok);
if (status != Ok)
return status;
status = pal::SVDF(
utils::castInputData<float>(input_data), utils::castInputData<float>(weights_feature_data),
utils::castInputData<float>(weights_time_data), utils::castInputData<float>(bias_data),
utils::castOutputData<float>(activation_state_data),
utils::castOutputData<float>(scratch_buffer), utils::castOutputData<float>(output_data),
rank, input_size, batch_size, num_filters, num_units, memory_size,
options->fused_activation_function());

status = core::memory::OMMemoryManager::deallocateMemory(scratch_buffer);
}
break;
#endif // DIS_FLOAT
#ifndef DIS_QUANT
case circle::TensorType_INT8:
{
core::SVDFQuantParams params{};
prepareQuantParams(params, input, weights_feature, weights_time, activation_state, output);

params.rank = rank;

status = pal::SVDF(
params, utils::castInputData<int8_t>(input_data),
utils::castInputData<int8_t>(weights_feature_data),
utils::castInputData<int8_t>(weights_time_data), utils::castInputData<int32_t>(bias_data),
utils::castOutputData<int8_t>(activation_state_data),
utils::castOutputData<int8_t>(output_data), input_shape, weights_feature_shape,
weights_time_shape, core::OMRuntimeShape(bias), output_shape);
}
break;
#endif // DIS_QUANT
default:
{
status = UnsupportedActivation;
Expand All @@ -163,7 +216,6 @@ OMStatus onert_micro::execute::execute_kernel_CircleSVDF(const OMExecuteArgs &ex
}

status = core::memory::OMMemoryManager::deallocateMemory(activation_state_data);
status = core::memory::OMMemoryManager::deallocateMemory(scratch_buffer);

return status;
}

0 comments on commit aa86e3d

Please sign in to comment.