Skip to content

Commit

Permalink
[Layer] add "add layer"
Browse files Browse the repository at this point in the history
- added "add layer"
- added a "model unit test" for add layer. Since 'model-level unit test'
havn't been run for a long time, I diabled some test cases that causing
issues when running model unit test.

Many people gave great feedback, so I've improved the structure accordingly.
- An upper class called "OperationLayer" was added to reduce redundant code.
- Based on the number of input tensors, the behavior of "OperationLayer" has been classified into two types: unary and binary operations.
- Various additional code cleanups have also taken place.

- there is an issue where committing compressed files containing golden data for unit test prevents pushing changes to the remote server. (I've confirmed that all unit tests pass locally using that golden data.)

**Self evaluation:**
1. Build test:   [X]Passed [X]Failed [ ]Skipped
2. Run test:     [X]Passed [X]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
  • Loading branch information
baek2sm committed Oct 10, 2024
1 parent f222ecf commit d9416db
Show file tree
Hide file tree
Showing 14 changed files with 713 additions and 159 deletions.
9 changes: 9 additions & 0 deletions api/ccapi/include/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ namespace train {
enum LayerType {
LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */
LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */
LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
Expand Down Expand Up @@ -305,6 +306,14 @@ WeightLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_WEIGHT, properties);
}

/**
* @brief Helper function to create add layer
*/
inline std::unique_ptr<Layer>
AddLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_ADD, properties);
}

/**
* @brief Helper function to create fully connected layer
*/
Expand Down
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */
ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */
ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
Expand Down
1 change: 1 addition & 0 deletions debian/nntrainer-dev.install
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
/usr/include/nntrainer/layer_context.h
/usr/include/nntrainer/layer_devel.h
/usr/include/nntrainer/layer_impl.h
/usr/include/nntrainer/operation_layer.h
/usr/include/nntrainer/acti_func.h
# custom layer kits
/usr/include/nntrainer/app_context.h
Expand Down
3 changes: 3 additions & 0 deletions nntrainer/app_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <sgd.h>

#include <activation_layer.h>
#include <add_layer.h>
#include <addition_layer.h>
#include <attention_layer.h>
#include <bn_layer.h>
Expand Down Expand Up @@ -248,6 +249,8 @@ static void add_default_object(AppContext &ac) {
LayerType::LAYER_IN);
ac.registerFactory(nntrainer::createLayer<WeightLayer>, WeightLayer::type,
LayerType::LAYER_WEIGHT);
ac.registerFactory(nntrainer::createLayer<AddLayer>, AddLayer::type,
LayerType::LAYER_ADD);
ac.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
FullyConnectedLayer::type, LayerType::LAYER_FC);
ac.registerFactory(nntrainer::createLayer<BatchNormalizationLayer>,
Expand Down
52 changes: 52 additions & 0 deletions nntrainer/layers/add_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file add_layer.cpp
* @date 7 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is add layer class (operation layer)
*
*/

#include <add_layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

#include <layer_context.h>

namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;

void AddLayer::finalize(InitLayerContext &context) {
op_type = OperationType::BINARY;
context.setOutputDimensions({context.getInputDimensions()[0]});
}

void AddLayer::forwarding_operation(const Tensor &input0, const Tensor &input1,
Tensor &hidden) {
input0.add(input1, hidden);
}

void AddLayer::calcDerivative(RunLayerContext &context) {
context.getOutgoingDerivative(0).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX));

context.getOutgoingDerivative(1).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX));
}

void AddLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, add_props);
if (!remain_props.empty()) {
std::string msg = "[AddLayer] Unknown Layer Properties count " +
std::to_string(values.size());
throw exception::not_supported(msg);
}
}
} /* namespace nntrainer */
108 changes: 108 additions & 0 deletions nntrainer/layers/add_layer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file add_layer.h
* @date 7 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is add layer class (operation layer)
*
*/

#ifndef __ADD_LAYER_H__
#define __ADD_LAYER_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>
#include <operation_layer.h>

namespace nntrainer {

/**
* @class Add Layer
* @brief Add Layer
*/
class AddLayer : public OperationLayer {
public:
/**
* @brief Destructor of Add Layer
*/
~AddLayer() {}

/**
* @brief Constructor of Add Layer
*/
AddLayer() : OperationLayer(), add_props(props::Print()) {}

/**
* @brief Move constructor of Add Layer.
* @param[in] AddLayer &&
*/
AddLayer(AddLayer &&rhs) noexcept = default;

/**
* @brief Move assignment operator.
* @parma[in] rhs AddLayer to be moved.
*/
AddLayer &operator=(AddLayer &&rhs) = default;

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(InitLayerContext &context) final;

/**
* @copydoc OperationLayer::forwarding_operation(const Tensor &input, Tensor
* &hidden)
*/
void forwarding_operation(const Tensor &input, Tensor &hidden) final{};

/**
* @brief forwarding operation for add
*
* @param input0 input tensor 0
* @param input1 input tensor 1
* @param hidden tensor to store the result of addition
*/
void forwarding_operation(const Tensor &input0, const Tensor &input1,
Tensor &hidden) final;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) final;

/**
* @copydoc bool supportBackwarding() const
*/
bool supportBackwarding() const final { return true; };

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
*/
void exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const final {}

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
void setProperty(const std::vector<std::string> &values) final;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const final { return AddLayer::type; }

std::tuple<props::Print> add_props;

inline static const std::string type = "add";
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __ADD_LAYER_H__ */
2 changes: 2 additions & 0 deletions nntrainer/layers/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ nntrainer_inc_abs += meson.current_source_dir() / 'loss'
layer_sources = [
'activation_layer.cpp',
'weight_layer.cpp',
'add_layer.cpp',
'addition_layer.cpp',
'attention_layer.cpp',
'mol_attention_layer.cpp',
Expand Down Expand Up @@ -52,6 +53,7 @@ layer_headers = [
'layer_devel.h',
'layer_impl.h',
'acti_func.h',
'operation_layer.h',
'common_properties.h',
'layer_node.h',
]
Expand Down
136 changes: 136 additions & 0 deletions nntrainer/layers/operation_layer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file operation_layer.h
* @date 4 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is common class for operation layers
*
*/

#include <layer_context.h>
#include <layer_devel.h>

namespace nntrainer {

/**
* @brief Type of layer depending on the number of inputs.
*
*/
enum class OperationType { NONE, UNARY, BINARY };

/**
* @brief Base class for Tensor Operation Layer
*
*/
class OperationLayer : public Layer {
public:
/**
* @brief forwarding operation for unary input
*
*/
virtual void forwarding_operation(const Tensor &input, Tensor &hidden) = 0;

/**
* @brief forwarding operation for binary inputs
*
*/
virtual void forwarding_operation(const Tensor &input0, const Tensor &input1,
Tensor &hidden) = 0;

/**
* @brief copydoc Layer::forwarding(RunLayerContext &context, bool training)
*
*/
void forwarding(RunLayerContext &context, bool training) override {
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);

if (op_type == OperationType::UNARY) {
const Tensor input = context.getInput(0);
forwarding_operation(input, hidden_);
} else if (op_type == OperationType::BINARY) {
const Tensor &input0 = context.getInput(0);
const Tensor &input1 = context.getInput(1);
forwarding_operation(input0, input1, hidden_);
} else {
throw std::invalid_argument("Operation type is not defined");
}
}

/**
* @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
* int from, unsigned int to, bool training)
*
*/
void incremental_forwarding(RunLayerContext &context, unsigned int from,
unsigned int to, bool training) override {
if (from) {
NNTR_THROW_IF(to - from != 1, std::invalid_argument)
<< "incremental step size is not 1";
from = 0;
to = 1;
}

Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
TensorDim hidden_dim = hidden_.getDim();
TensorDim hidden_step_dim = hidden_dim;

hidden_step_dim.batch(1);
hidden_step_dim.height(to - from);

if (op_type == OperationType::UNARY) {
const Tensor &input = context.getInput(0);
TensorDim input_dim = input.getDim();
TensorDim input_step_dim = input_dim;
input_step_dim.batch(1);
input_step_dim.height(to - from);

for (unsigned int b = 0; b < hidden_.batch(); ++b) {
Tensor hidden_step = hidden_.getSharedDataTensor(
hidden_step_dim, b * hidden_dim.getFeatureLen(), true);

Tensor input_step = input.getSharedDataTensor(
input_step_dim, b * input_dim.getFeatureLen(), true);

forwarding_operation(input_step, hidden_step);
}
} else if (op_type == OperationType::BINARY) {
const Tensor &input0 = context.getInput(0);
const Tensor &input1 = context.getInput(1);

TensorDim input0_dim = input0.getDim();
TensorDim input1_dim = input1.getDim();
if (input0_dim != input1_dim) {
throw std::invalid_argument(
"If the two input dimensions are different, the incremental "
"forwarding implementation must be overridden.");
}

TensorDim input_step_dim = input0_dim;
input_step_dim.batch(1);
input_step_dim.height(to - from);

for (unsigned int b = 0; b < hidden_.batch(); ++b) {
Tensor hidden_step = hidden_.getSharedDataTensor(
hidden_step_dim, b * hidden_dim.getFeatureLen(), true);

Tensor input0_step = input0.getSharedDataTensor(
input_step_dim, b * input0_dim.getFeatureLen(), true);

Tensor input1_step = input1.getSharedDataTensor(
input_step_dim, b * input1_dim.getFeatureLen(), true);

forwarding_operation(input0_step, input1_step, hidden_step);
}
} else {
throw std::invalid_argument("Operation type is not defined");
}
}

OperationType op_type = OperationType::NONE; /**< type of operation */
static constexpr size_t SINGLE_INOUT_IDX = 0;
};
} // namespace nntrainer
1 change: 1 addition & 0 deletions packaging/nntrainer.spec
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/
%{_includedir}/nntrainer/layer_context.h
%{_includedir}/nntrainer/layer_devel.h
%{_includedir}/nntrainer/layer_impl.h
%{_includedir}/nntrainer/operation_layer.h
# custom layer kits
%{_includedir}/nntrainer/app_context.h
# optimizer headers
Expand Down
3 changes: 3 additions & 0 deletions test/ccapi/unittest_ccapi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ TEST(ccapi_layer, construct_02_p) {
EXPECT_NO_THROW(layer = ml::train::layer::WeightLayer());
EXPECT_EQ(layer->getType(), "weight");

EXPECT_NO_THROW(layer = ml::train::layer::AddLayer());
EXPECT_EQ(layer->getType(), "add");

EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected());
EXPECT_EQ(layer->getType(), "fully_connected");

Expand Down
Loading

0 comments on commit d9416db

Please sign in to comment.