[Layer] add "div layer"
- added "div layer" for division.

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
baek2sm committed Aug 30, 2024
1 parent 4da3610 commit 075d2c6
Showing 11 changed files with 290 additions and 1 deletion.
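For orientation before the per-file diff, here is a minimal usage sketch of the layer this commit adds, pieced together from the new ccapi helper and the test graphs below. The property strings (`name=`, `input_layers=`) are illustrative assumptions mirroring `makeDivOperation()` in unittest_models.cpp, not part of this diff:

```cpp
#include <layer.h> // api/ccapi/include/layer.h

// Sketch: divide the output of layer "a" by the output of layer "b",
// elementwise. DivLayer() is the new helper added in this commit; it
// forwards to createLayer(LayerType::LAYER_DIV, properties). The "name"
// and "input_layers" keys mirror the test graph in unittest_models.cpp.
auto div_layer = ml::train::layer::DivLayer({"name=div0", "input_layers=a,b"});
// div_layer->getType() == "div" (see the ccapi unit test below)
```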
9 changes: 9 additions & 0 deletions api/ccapi/include/layer.h
@@ -40,6 +40,7 @@ enum LayerType {
LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */
LAYER_SUB = ML_TRAIN_LAYER_TYPE_SUB, /**< Subtract Layer type */
LAYER_MUL = ML_TRAIN_LAYER_TYPE_MUL, /**< Multiply Layer type */
LAYER_DIV = ML_TRAIN_LAYER_TYPE_DIV, /**< Divide Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
@@ -326,6 +327,14 @@ MulLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_MUL, properties);
}

/**
* @brief Helper function to create div layer
*/
inline std::unique_ptr<Layer>
DivLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_DIV, properties);
}

/**
* @brief Helper function to create fully connected layer
*/
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
@@ -68,6 +68,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_SUB = 33, /**< Sub Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_MUL = 34, /**< Mul Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_DIV = 35, /**< Div Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
3 changes: 3 additions & 0 deletions nntrainer/app_context.cpp
@@ -42,6 +42,7 @@
#include <conv2d_layer.h>
#include <cross_entropy_sigmoid_loss_layer.h>
#include <cross_entropy_softmax_loss_layer.h>
#include <div_layer.h>
#include <dropout.h>
#include <embedding.h>
#include <fc_layer.h>
@@ -257,6 +258,8 @@ static void add_default_object(AppContext &ac) {
LayerType::LAYER_SUB);
ac.registerFactory(nntrainer::createLayer<MulLayer>, MulLayer::type,
LayerType::LAYER_MUL);
ac.registerFactory(nntrainer::createLayer<DivLayer>, DivLayer::type,
LayerType::LAYER_DIV);
ac.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
FullyConnectedLayer::type, LayerType::LAYER_FC);
ac.registerFactory(nntrainer::createLayer<BatchNormalizationLayer>,
97 changes: 97 additions & 0 deletions nntrainer/layers/div_layer.cpp
@@ -0,0 +1,97 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file div_layer.cpp
* @date 30 August 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the div layer class (operation layer)
*
*/

#include <div_layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

#include <layer_context.h>

namespace nntrainer {

static constexpr size_t SINGLE_INOUT_IDX = 0;

void DivLayer::finalize(InitLayerContext &context) {
context.setOutputDimensions({context.getInputDimensions()[0]});
}

void DivLayer::forwarding(RunLayerContext &context, bool training) {
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);

const Tensor &input0 = context.getInput(0);
const Tensor &input1 = context.getInput(1);

input0.divide(input1, hidden_);
}

void DivLayer::incremental_forwarding(RunLayerContext &context,
unsigned int from, unsigned int to,
bool training) {
Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
TensorDim hidden_dim = hidden_.getDim();
TensorDim hidden_step_dim = hidden_dim;

if (from) {
NNTR_THROW_IF(to - from != 1, std::invalid_argument)
<< "incremental step size is not 1";
from = 0;
to = 1;
}

hidden_step_dim.batch(1);
hidden_step_dim.height(to - from);

for (unsigned int b = 0; b < hidden_.batch(); ++b) {
Tensor hidden_step = hidden_.getSharedDataTensor(
hidden_step_dim, b * hidden_dim.getFeatureLen(), true);

const Tensor &input0 = context.getInput(0);
const Tensor &input1 = context.getInput(1);

TensorDim input_dim = input0.getDim();
TensorDim input_step_dim = input_dim;
input_step_dim.batch(1);
input_step_dim.height(to - from);

Tensor input0_step = input0.getSharedDataTensor(
input_step_dim, b * input_dim.getFeatureLen(), true);

Tensor input1_step = input1.getSharedDataTensor(
input_step_dim, b * input_dim.getFeatureLen(), true);

input0_step.divide(input1_step, hidden_step);
}
}

void DivLayer::calcDerivative(RunLayerContext &context) {
context.getOutgoingDerivative(0).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.divide(context.getInput(1)));

context.getOutgoingDerivative(1).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.multiply(context.getInput(0).multiply(-1))
.divide(context.getInput(1).pow(2)));
}

void DivLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, div_props);
if (!remain_props.empty()) {
std::string msg = "[DivLayer] Unknown Layer Properties count " +
std::to_string(values.size());
throw exception::not_supported(msg);
}
}
} /* namespace nntrainer */
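Two notes on the implementation above, for the record:

- In `incremental_forwarding`, each batch element is processed through a step view of height `to - from` obtained with `getSharedDataTensor`; as in nntrainer's other operation layers, this appears to share the underlying buffer rather than copy, so `input0_step.divide(input1_step, hidden_step)` writes directly into the matching slice of the output.
- `calcDerivative` implements the elementwise quotient-rule gradients. For $z = x / y$ with incoming derivative $\partial L / \partial z$:

$$
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z} \cdot \frac{1}{y},
\qquad
\frac{\partial L}{\partial y} = -\frac{\partial L}{\partial z} \cdot \frac{x}{y^{2}},
$$

matching the `divide(input1)` chain for the first outgoing derivative and the `multiply(input0.multiply(-1)).divide(input1.pow(2))` chain for the second.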
103 changes: 103 additions & 0 deletions nntrainer/layers/div_layer.h
@@ -0,0 +1,103 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file div_layer.h
* @date 30 August 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the div layer class (operation layer)
*
*/

#ifndef __DIV_LAYER_H__
#define __DIV_LAYER_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>

namespace nntrainer {

/**
* @class Div Layer
* @brief Div Layer
*/
class DivLayer : public Layer {
public:
/**
* @brief Constructor of Div Layer
*/
DivLayer() : Layer(), div_props(props::Print()) {}

/**
* @brief Destructor of Div Layer
*/
~DivLayer(){};

/**
* @brief Move constructor of Div Layer.
* @param[in] DivLayer &&
*/
DivLayer(DivLayer &&rhs) noexcept = default;

/**
* @brief Move assignment operator.
* @param[in] rhs DivLayer to be moved.
*/
DivLayer &operator=(DivLayer &&rhs) = default;

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(InitLayerContext &context) override;

/**
* @copydoc Layer::forwarding(RunLayerContext &context, bool training)
*/
void forwarding(RunLayerContext &context, bool training) override;

/**
* @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
* int from, unsigned int to, bool training)
*/
void incremental_forwarding(RunLayerContext &context, unsigned int from,
unsigned int to, bool training) override;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) override;

/**
* @copydoc bool supportBackwarding() const
*/
bool supportBackwarding() const override { return true; };

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
*/
void exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const override {}

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
void setProperty(const std::vector<std::string> &values) override;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const override { return DivLayer::type; };

std::tuple<props::Print> div_props;

inline static const std::string type = "div";
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __DIV_LAYER_H__ */
1 change: 1 addition & 0 deletions nntrainer/layers/meson.build
@@ -8,6 +8,7 @@ layer_sources = [
'add_layer.cpp',
'sub_layer.cpp',
'mul_layer.cpp',
'div_layer.cpp',
'addition_layer.cpp',
'attention_layer.cpp',
'mol_attention_layer.cpp',
3 changes: 3 additions & 0 deletions test/ccapi/unittest_ccapi.cpp
@@ -73,6 +73,9 @@ TEST(ccapi_layer, construct_02_p) {
EXPECT_NO_THROW(layer = ml::train::layer::MulLayer());
EXPECT_EQ(layer->getType(), "mul");

EXPECT_NO_THROW(layer = ml::train::layer::DivLayer());
EXPECT_EQ(layer->getType(), "div");

EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected());
EXPECT_EQ(layer->getType(), "fully_connected");

25 changes: 24 additions & 1 deletion test/input_gen/genModelTests_v2.py
@@ -468,6 +468,19 @@ def forward(self, inputs, labels):
return out, loss


class DivOperation(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(2, 2)
self.loss = torch.nn.MSELoss()

def forward(self, inputs, labels):
out = self.fc(inputs[0])
out = inputs[0] / out
loss = self.loss(out, labels[0])
return out, loss


if __name__ == "__main__":
record_v2(
ReduceMeanLast(),
@@ -772,5 +785,15 @@ def forward(self, inputs, labels):
name="mul_operation",
)

div_operation = DivOperation()
record_v2(
div_operation,
iteration=2,
input_dims=[(1, 2)],
input_dtype=[float],
label_dims=[(1, 2)],
name="div_operation",
)

# Function to check the created golden test file
inspect_file("sub_operation.nnmodelgolden")
inspect_file("div_operation.nnmodelgolden")
1 change: 1 addition & 0 deletions test/unittest/layers/meson.build
@@ -50,6 +50,7 @@ test_target = [
'unittest_layers_add.cpp',
'unittest_layers_sub.cpp',
'unittest_layers_mul.cpp',
'unittest_layers_div.cpp',
'unittest_layers_multiout.cpp',
'unittest_layers_rnn.cpp',
'unittest_layers_rnncell.cpp',
28 changes: 28 additions & 0 deletions test/unittest/layers/unittest_layers_div.cpp
@@ -0,0 +1,28 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file unittest_layers_div.cpp
* @date 30 August 2024
* @brief Div Layer Test
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
*/
#include <tuple>

#include <gtest/gtest.h>

#include <div_layer.h>
#include <layers_common_tests.h>

auto semantic_div = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::DivLayer>, nntrainer::DivLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);

auto semantic_div_multi = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::DivLayer>, nntrainer::DivLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);

GTEST_PARAMETER_TEST(Div, LayerSemantics,
::testing::Values(semantic_div, semantic_div_multi));
20 changes: 20 additions & 0 deletions test/unittest/models/unittest_models.cpp
@@ -929,6 +929,25 @@ static std::unique_ptr<NeuralNetwork> makeMulOperation() {
return nn;
}

static std::unique_ptr<NeuralNetwork> makeDivOperation() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());

auto outer_graph =
makeGraph({{"input", {"name=in", "input_shape=1:1:2"}},
{"fully_connected", {"name=fc", "unit=2", "input_layers=in"}},
{"div", {"name=div_layer", "input_layers=in,fc"}},
{"mse", {"name=loss", "input_layers=div_layer"}}});

for (auto &node : outer_graph) {
nn->addLayer(node);
}

nn->setProperty({"batch_size=1"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));

return nn;
}

GTEST_PARAMETER_TEST(
model, nntrainerModelTest,
::testing::ValuesIn({
@@ -995,6 +1014,7 @@ GTEST_PARAMETER_TEST(
mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2),
mkModelTc_V2(makeSubOperation, "sub_operation", ModelTestOption::ALL_V2),
mkModelTc_V2(makeMulOperation, "mul_operation", ModelTestOption::ALL_V2),
mkModelTc_V2(makeDivOperation, "div_operation", ModelTestOption::ALL_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
-> const auto & { return std::get<1>(info.param); });
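How the test pieces connect: `DivOperation` in `genModelTests_v2.py` records golden values for the PyTorch graph `out = in / fc(in)` into `div_operation.nnmodelgolden`, and `makeDivOperation()` above builds the equivalent nntrainer graph (`in -> fc -> div(in, fc) -> mse`), so this V2 model test checks the new layer's forward and backward results end to end against that golden file.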
