[Layer] add "multiply layer"
- added "multiply layer" for element-wise multiplication of two input tensors

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
baek2sm committed Nov 12, 2024
1 parent 6a06ec6 commit 22cd851
Showing 11 changed files with 246 additions and 0 deletions.
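
With this commit, the new layer is reachable from the ccapi helper added in api/ccapi/include/layer.h. A minimal usage sketch (the "name=mul_layer" property is illustrative, not part of this diff):

#include <cassert>
#include <layer.h>  // ccapi header extended by this commit

int main() {
  // Create the new element-wise multiply layer via the helper;
  // its registered type string is "multiply".
  auto mul = ml::train::layer::MultiplyLayer({"name=mul_layer"});
  assert(mul->getType() == "multiply");
  return 0;
}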
9 changes: 9 additions & 0 deletions api/ccapi/include/layer.h
@@ -39,6 +39,7 @@ enum LayerType {
LAYER_WEIGHT = ML_TRAIN_LAYER_TYPE_WEIGHT, /**< Weight Layer type */
LAYER_ADD = ML_TRAIN_LAYER_TYPE_ADD, /**< Add Layer type */
LAYER_SUBTRACT = ML_TRAIN_LAYER_TYPE_SUBTRACT, /**< Subtract Layer type */
LAYER_MULTIPLY = ML_TRAIN_LAYER_TYPE_MULTIPLY, /**< Multiply Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
@@ -323,6 +324,14 @@ SubtractLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_SUBTRACT, properties);
}

/**
* @brief Helper function to create multiply layer
*/
inline std::unique_ptr<Layer>
MultiplyLayer(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_MULTIPLY, properties);
}

/**
* @brief Helper function to create fully connected layer
*/
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
@@ -67,6 +67,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_WEIGHT = 31, /**< Weight Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_ADD = 32, /**< Add Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_SUBTRACT = 33, /**< Subtract Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_MULTIPLY = 34, /**< Multiply Layer type (Since 9.0)*/
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
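
On the C API side, the new enum value plugs into the standard ml-api training calls. A hedged sketch (assuming the usual ml_train_layer_create / ml_train_layer_set_property entry points from nntrainer.h; none of these calls are part of this diff):

#include <nntrainer.h>

int create_mul_layer(void) {
  ml_train_layer_h layer = NULL;
  // Create a layer handle of the newly added multiply type.
  int status = ml_train_layer_create(&layer, ML_TRAIN_LAYER_TYPE_MULTIPLY);
  if (status != ML_ERROR_NONE)
    return status;
  // The C API takes a NULL-terminated property list.
  ml_train_layer_set_property(layer, "name=mul_layer", NULL);
  return ml_train_layer_destroy(layer);
}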
3 changes: 3 additions & 0 deletions nntrainer/app_context.cpp
@@ -61,6 +61,7 @@
#include <mse_loss_layer.h>
#include <multi_head_attention_layer.h>
#include <multiout_layer.h>
#include <multiply_layer.h>
#include <nntrainer_error.h>
#include <permute_layer.h>
#include <plugged_layer.h>
@@ -259,6 +260,8 @@ static void add_default_object(AppContext &ac) {
LayerType::LAYER_ADD);
ac.registerFactory(nntrainer::createLayer<SubtractLayer>, SubtractLayer::type,
LayerType::LAYER_SUBTRACT);
ac.registerFactory(nntrainer::createLayer<MultiplyLayer>, MultiplyLayer::type,
LayerType::LAYER_MULTIPLY);
ac.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
FullyConnectedLayer::type, LayerType::LAYER_FC);
ac.registerFactory(nntrainer::createLayer<BatchNormalizationLayer>,
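
Registering the factory here is what makes the layer resolvable by its type string as well as by the enum. A sketch of the by-string path (using the existing ccapi createLayer overload, which this diff does not touch):

// "multiply" now resolves to MultiplyLayer through the factory
// registered above; properties are forwarded to setProperty().
auto mul = ml::train::createLayer("multiply", {"name=mul_layer"});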
1 change: 1 addition & 0 deletions nntrainer/layers/meson.build
@@ -7,6 +7,7 @@ layer_sources = [
'weight_layer.cpp',
'add_layer.cpp',
'subtract_layer.cpp',
'multiply_layer.cpp',
'addition_layer.cpp',
'attention_layer.cpp',
'mol_attention_layer.cpp',
51 changes: 51 additions & 0 deletions nntrainer/layers/multiply_layer.cpp
@@ -0,0 +1,51 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file multiply_layer.cpp
* @date 10 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the multiply layer class (operation layer)
*
*/

#include <multiply_layer.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

#include <layer_context.h>

namespace nntrainer {

void MultiplyLayer::finalize(InitLayerContext &context) {
context.setOutputDimensions({context.getInputDimensions()[0]});
}

void MultiplyLayer::forwarding_operation(const Tensor &input0,
const Tensor &input1, Tensor &hidden) {
input0.multiply(input1, hidden);
}

void MultiplyLayer::calcDerivative(RunLayerContext &context) {
context.getOutgoingDerivative(0).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.multiply(context.getInput(1)));

context.getOutgoingDerivative(1).copy(
context.getIncomingDerivative(SINGLE_INOUT_IDX)
.multiply(context.getInput(0)));
}

void MultiplyLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, multiply_props);
if (!remain_props.empty()) {
std::string msg = "[MultiplyLayer] Unknown Layer Properties count " +
std::to_string(values.size());
throw exception::not_supported(msg);
}
}
} /* namespace nntrainer */
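
The backward pass above is the element-wise product rule. Writing the layer as \( z = x_0 \odot x_1 \) with incoming derivative \( \partial L / \partial z \):

\[
\frac{\partial L}{\partial x_0} = \frac{\partial L}{\partial z} \odot x_1,
\qquad
\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial z} \odot x_0
\]

This is exactly what the two copy(... .multiply(...)) calls in calcDerivative compute: each outgoing derivative is the incoming derivative multiplied element-wise by the other input.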
102 changes: 102 additions & 0 deletions nntrainer/layers/multiply_layer.h
@@ -0,0 +1,102 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file multiply_layer.h
* @date 10 Oct 2024
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the multiply layer class (operation layer)
*
*/

#ifndef __MULTIPLY_LAYER_H__
#define __MULTIPLY_LAYER_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>
#include <operation_layer.h>

namespace nntrainer {

/**
* @class Multiply Layer
* @brief Multiply Layer
*/
class MultiplyLayer : public BinaryOperationLayer {
public:
/**
* @brief Constructor of Multiply Layer
*/
MultiplyLayer() : BinaryOperationLayer(), multiply_props(props::Print()) {}

/**
* @brief Destructor of Multiply Layer
*/
~MultiplyLayer(){};

/**
* @brief Move constructor of Multiply Layer.
* @param[in] MultiplyLayer &&
*/
MultiplyLayer(MultiplyLayer &&rhs) noexcept = default;

/**
* @brief Move assignment operator.
* @param[in] rhs MultiplyLayer to be moved.
*/
MultiplyLayer &operator=(MultiplyLayer &&rhs) = default;

/**
* @copydoc Layer::finalize(InitLayerContext &context)
*/
void finalize(InitLayerContext &context) final;

/**
* @brief forwarding operation for multiply
*
* @param input0 input tensor 0
* @param input1 input tensor 1
* @param hidden tensor to store the result of multiplication
*/
void forwarding_operation(const Tensor &input0, const Tensor &input1,
Tensor &hidden) final;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) final;

/**
* @copydoc bool supportBackwarding() const
*/
bool supportBackwarding() const final { return true; };

/**
* @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
* method)
*/
void exportTo(Exporter &exporter,
const ml::train::ExportMethods &method) const final {}

/**
* @copydoc Layer::setProperty(const std::vector<std::string> &values)
*/
void setProperty(const std::vector<std::string> &values) final;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const final { return MultiplyLayer::type; };

std::tuple<props::Print> multiply_props;

inline static const std::string type = "multiply";
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __MULTIPLY_LAYER_H__ */
3 changes: 3 additions & 0 deletions test/ccapi/unittest_ccapi.cpp
@@ -70,6 +70,9 @@ TEST(ccapi_layer, construct_02_p) {
EXPECT_NO_THROW(layer = ml::train::layer::SubtractLayer());
EXPECT_EQ(layer->getType(), "subtract");

EXPECT_NO_THROW(layer = ml::train::layer::MultiplyLayer());
EXPECT_EQ(layer->getType(), "multiply");

EXPECT_NO_THROW(layer = ml::train::layer::FullyConnected());
EXPECT_EQ(layer->getType(), "fully_connected");

23 changes: 23 additions & 0 deletions test/input_gen/genModelTests_v2.py
@@ -492,6 +492,19 @@ def forward(self, inputs, labels):
return out, loss


class MultiplyOperation(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(2, 2)
self.loss = torch.nn.MSELoss()

def forward(self, inputs, labels):
out = self.fc(inputs[0])
out = inputs[0] * out
loss = self.loss(out, labels[0])
return out, loss


if __name__ == "__main__":
record_v2(
ReduceMeanLast(),
@@ -799,6 +812,16 @@ def forward(self, inputs, labels):
name="subtract_operation",
)

multiply_operation = MultiplyOperation()
record_v2(
multiply_operation,
iteration=2,
input_dims=[(1, 2)],
input_dtype=[float],
label_dims=[(1, 2)],
name="multiply_operation",
)

# Function to check the created golden test file
inspect_file("add_operation.nnmodelgolden")
fc_mixed_training_nan_sgd = LinearMixedPrecisionNaNSGD()
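
The golden data recorded here follows the reference module's forward pass: with input \( x \in \mathbb{R}^{1 \times 2} \) and the 2-to-2 linear layer's parameters \( W, b \),

\[
\hat{y} = x \odot (x W^\top + b), \qquad L = \mathrm{MSE}(\hat{y}, y)
\]

run for two iterations and saved under the name multiply_operation, which the C++ model test below uses as its golden reference.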
1 change: 1 addition & 0 deletions test/unittest/layers/meson.build
@@ -49,6 +49,7 @@ test_target = [
'unittest_layers_addition.cpp',
'unittest_layers_add.cpp',
'unittest_layers_subtract.cpp',
'unittest_layers_multiply.cpp',
'unittest_layers_multiout.cpp',
'unittest_layers_rnn.cpp',
'unittest_layers_rnncell.cpp',
31 changes: 31 additions & 0 deletions test/unittest/layers/unittest_layers_multiply.cpp
@@ -0,0 +1,31 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 SeungBaek Hong <[email protected]>
*
* @file unittest_layers_multiply.cpp
* @date 30 August 2024
* @brief Multiply Layer Test
* @see https://github.com/nnstreamer/nntrainer
* @author SeungBaek Hong <[email protected]>
* @bug No known bugs except for NYI items
*/
#include <tuple>

#include <gtest/gtest.h>

#include <layers_common_tests.h>
#include <multiply_layer.h>

auto semantic_multiply = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::MultiplyLayer>,
nntrainer::MultiplyLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);

auto semantic_multiply_multi = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::MultiplyLayer>,
nntrainer::MultiplyLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);

GTEST_PARAMETER_TEST(Multiply, LayerSemantics,
::testing::Values(semantic_multiply,
semantic_multiply_multi));
21 changes: 21 additions & 0 deletions test/unittest/models/unittest_models.cpp
@@ -910,6 +910,25 @@ static std::unique_ptr<NeuralNetwork> makeSubtractOperation() {
return nn;
}

static std::unique_ptr<NeuralNetwork> makeMultiplyOperation() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());

auto outer_graph =
makeGraph({{"input", {"name=in", "input_shape=1:1:2"}},
{"fully_connected", {"name=fc", "unit=2", "input_layers=in"}},
{"multiply", {"name=multiply_layer", "input_layers=in,fc"}},
{"mse", {"name=loss", "input_layers=multiply_layer"}}});

for (auto &node : outer_graph) {
nn->addLayer(node);
}

nn->setProperty({"batch_size=1"});
nn->setOptimizer(ml::train::createOptimizer("sgd", {"learning_rate=0.1"}));

return nn;
}

GTEST_PARAMETER_TEST(
model, nntrainerModelTest,
::testing::ValuesIn({
@@ -984,6 +1003,8 @@ GTEST_PARAMETER_TEST(
mkModelTc_V2(makeAddOperation, "add_operation", ModelTestOption::ALL_V2),
mkModelTc_V2(makeSubtractOperation, "subtract_operation",
ModelTestOption::ALL_V2),
mkModelTc_V2(makeMultiplyOperation, "multiply_operation",
ModelTestOption::ALL_V2),
}),
[](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
-> const auto & { return std::get<1>(info.param); });
