[Loss&Test] Implement KLD Loss & Add Unit Test on loss layers @open sesame 10/28 13:48 #2757

Merged 6 commits on Nov 7, 2024
57 changes: 57 additions & 0 deletions nntrainer/layers/loss/kld_loss_layer.cpp
@@ -0,0 +1,57 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2021 Jihoon Lee <[email protected]>
*
* @file kld_loss_layer.cpp
* @date 25 November 2021
* @brief KLD (Kullback-Leibler Divergence) loss implementation
* @see https://github.com/nnstreamer/nntrainer
* @author Jihoon Lee <[email protected]>
* @author Donghak Park <[email protected]>
* @bug No known bugs except for NYI items
*
*/
#include <kld_loss_layer.h>
#include <layer_context.h>
#include <string>
#include <vector>

namespace nntrainer {
static constexpr size_t SINGLE_INOUT_IDX = 0;

void KLDLossLayer::forwarding(RunLayerContext &context, bool training) {
// Result = (P * (P / Q).log()).sum()
// KL(P ∣∣ Q), where P denotes the distribution of the observations in the dataset
// and Q denotes the model output.

nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
nntrainer::Tensor temp; // temp output
/**
* 1. temp = label / predicted
* 2. temp = log(temp)
* 3. temp = temp * label
* 4. output = sum(temp)
*/
label.divide(predicted, temp);
temp.apply<float>(logf, temp);
temp.multiply_i(label);
output.fill(temp.sum({0, 1, 2, 3}));
}
}

void KLDLossLayer::calcDerivative(RunLayerContext &context) {
/**
* d/dQ = -P/Q
*/
nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX); // Q
nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX); // P
nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);

label.multiply_i(-1.0f);
label.divide(predicted, deriv);
}

} // namespace nntrainer
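For reference, the formulas implemented above can be cross-checked against a plain scalar version. The sketch below is illustrative only; it uses raw std::vector<float> and hypothetical helper names rather than the nntrainer Tensor API.

// Reference computation for KL(P || Q) and its derivative w.r.t. Q.
// P = label distribution, Q = predicted (model output). Hypothetical helpers,
// not part of nntrainer.
#include <cmath>
#include <vector>

// KL(P || Q) = sum_i P[i] * log(P[i] / Q[i])
float kld_reference(const std::vector<float> &p, const std::vector<float> &q) {
  float sum = 0.0f;
  for (size_t i = 0; i < p.size(); ++i)
    sum += p[i] * std::log(p[i] / q[i]);
  return sum;
}

// d KL / d Q[i] = -P[i] / Q[i], matching calcDerivative() above
std::vector<float> kld_grad_reference(const std::vector<float> &p,
                                      const std::vector<float> &q) {
  std::vector<float> grad(p.size());
  for (size_t i = 0; i < p.size(); ++i)
    grad[i] = -p[i] / q[i];
  return grad;
}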
60 changes: 60 additions & 0 deletions nntrainer/layers/loss/kld_loss_layer.h
@@ -0,0 +1,60 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2021 Jihoon Lee <[email protected]>
*
* @file kld_loss_layer.h
* @date 25 November 2021
* @brief KLD (Kullback-Leibler Divergence) loss implementation
* @see https://github.com/nnstreamer/nntrainer
* @author Jihoon Lee <[email protected]>
* @bug No known bugs except for NYI items
*
*/
#ifndef __KLD_LOSS_LAYER_H__
#define __KLD_LOSS_LAYER_H__

#ifdef __cplusplus

#include <loss_layer.h>
#include <string>
#include <vector>

namespace nntrainer {

/**
* @class KLD (Kullback-Leibler Divergence) Loss layer
* @brief kld loss layer
*/
class KLDLossLayer : public LossLayer {
public:
/**
* @brief Constructor of KLD Loss Layer
*/
KLDLossLayer() : LossLayer() {}

/**
* @brief Destructor of KLD Loss Layer
*/
~KLDLossLayer() = default;

/**
* @copydoc Layer::forwarding(RunLayerContext &context, bool training)
*/
void forwarding(RunLayerContext &context, bool training) override;

/**
* @copydoc Layer::calcDerivative(RunLayerContext &context)
*/
void calcDerivative(RunLayerContext &context) override;

/**
* @copydoc Layer::getType()
*/
const std::string getType() const override { return KLDLossLayer::type; }

inline static const std::string type = "kld";
};
} // namespace nntrainer

#endif /* __cplusplus */
#endif // __KLD_LOSS_LAYER_H__
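A minimal sanity-check sketch of the interface declared above, assuming the internal header is on the include path; the function name is hypothetical.

#include <cassert>
#include <kld_loss_layer.h>

void check_kld_type() {
  nntrainer::KLDLossLayer kld_loss;    // default-constructible loss layer
  assert(kld_loss.getType() == "kld"); // matches the static type string above
}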
3 changes: 2 additions & 1 deletion nntrainer/layers/loss/meson.build
@@ -3,7 +3,8 @@ loss_layer_sources = [
'mse_loss_layer.cpp',
'cross_entropy_sigmoid_loss_layer.cpp',
'cross_entropy_softmax_loss_layer.cpp',
'constant_derivative_loss_layer.cpp'
'constant_derivative_loss_layer.cpp',
'kld_loss_layer.cpp'
]

loss_layer_headers = []
3 changes: 2 additions & 1 deletion nntrainer/models/model_common_properties.cpp
@@ -20,7 +20,8 @@ Epochs::Epochs(unsigned int value) { set(value); }

bool LossType::isValid(const std::string &value) const {
ml_logw("Model loss property is deprecated, use loss layer directly instead");
return istrequal(value, "cross") || istrequal(value, "mse");
return istrequal(value, "cross") || istrequal(value, "mse") ||
istrequal(value, "kld");
}

TrainingBatchSize::TrainingBatchSize(unsigned int value) { set(value); }
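For context, this change means the deprecated model-level loss property now also accepts the new name; an illustrative call, mirroring the integration test below:

auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET,
                                    {"loss=kld"}); // passes LossType::isValid after this change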
118 changes: 118 additions & 0 deletions test/unittest/integration_tests/integration_test_loss.cpp
@@ -0,0 +1,118 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Donghak Park <[email protected]>
*
* @file integration_test_loss.cpp
* @date 16 Oct 2024
* @brief Loss layer integration test (cross entropy & KLD)
* @see https://github.com/nnstreamer/nntrainer
* @author Donghak Park <[email protected]>
* @bug No known bugs except for NYI items
*/

#include <tuple>

#include <gtest/gtest.h>

#include <app_context.h>
#include <layer.h>
#include <lite/core/c/common.h>
#include <model.h>
#include <optimizer.h>

template <typename T>
static std::string withKey(const std::string &key, const T &value) {
std::stringstream ss;
ss << key << "=" << value;
return ss.str();
}

template <typename T>
static std::string withKey(const std::string &key,
std::initializer_list<T> value) {
if (std::empty(value)) {
throw std::invalid_argument("empty data cannot be converted");
}

std::stringstream ss;
ss << key << "=";

auto iter = value.begin();
for (; iter != value.end() - 1; ++iter) {
ss << *iter << ',';
}
ss << *iter;

return ss.str();
}
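A brief, illustrative note on what these helpers produce:

// withKey("loss", "cross")            -> "loss=cross"
// withKey("unit", 100)                -> "unit=100"
// withKey("input_layers", {"a", "b"}) -> "input_layers=a,b"  (initializer_list overload)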

TEST(crossentropy_loss, model_pass_test) {

std::unique_ptr<ml::train::Model> model = ml::train::createModel(
ml::train::ModelType::NEURAL_NET, {withKey("loss", "cross")});

std::shared_ptr<ml::train::Layer> input_layer = ml::train::createLayer(
"input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")});

std::shared_ptr<ml::train::Layer> fc_layer = ml::train::createLayer(
"fully_connected",
{withKey("unit", 100), withKey("activation", "softmax")});

model->addLayer(input_layer);
model->addLayer(fc_layer);

model->setProperty({withKey("batch_size", 16), withKey("epochs", 10)});

auto optimizer = ml::train::createOptimizer("adam", {"learning_rate=0.001"});
model->setOptimizer(std::move(optimizer));
int status = model->compile();
EXPECT_EQ(status, ML_ERROR_NONE);
status = model->initialize();
EXPECT_EQ(status, ML_ERROR_NONE);
}

TEST(crossentropy_loss, model_fail_test) {

std::unique_ptr<ml::train::Model> model = ml::train::createModel(
ml::train::ModelType::NEURAL_NET, {withKey("loss", "cross")});

std::shared_ptr<ml::train::Layer> input_layer = ml::train::createLayer(
"input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")});
std::shared_ptr<ml::train::Layer> fc_layer = ml::train::createLayer(
"fully_connected", {withKey("unit", 100), withKey("activation", "relu")});

model->addLayer(input_layer);
model->addLayer(fc_layer);

model->setProperty({withKey("batch_size", 16), withKey("epochs", 10)});

auto optimizer = ml::train::createOptimizer("adam", {"learning_rate=0.001"});
model->setOptimizer(std::move(optimizer));
int status = model->compile();
EXPECT_FALSE(status == ML_ERROR_NONE);
}

TEST(kld_loss, compile_test) {

std::unique_ptr<ml::train::Model> model = ml::train::createModel(
ml::train::ModelType::NEURAL_NET, {withKey("loss", "kld")});

std::shared_ptr<ml::train::Layer> input_layer = ml::train::createLayer(
"input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")});
std::shared_ptr<ml::train::Layer> fc_layer = ml::train::createLayer(
"fully_connected",
{withKey("unit", 100), withKey("activation", "softmax")});

model->addLayer(input_layer);
model->addLayer(fc_layer);

model->setProperty({withKey("batch_size", 16), withKey("epochs", 1)});

auto optimizer = ml::train::createOptimizer("adam", {"learning_rate=0.001"});
model->setOptimizer(std::move(optimizer));
int status = model->compile();
EXPECT_FALSE(status == ML_ERROR_NONE);

status = model->initialize();
EXPECT_FALSE(status == ML_ERROR_NONE);
}
32 changes: 32 additions & 0 deletions test/unittest/integration_tests/integration_tests.cpp
@@ -0,0 +1,32 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Donghak Park <[email protected]>
*
* @file integration_tests.cpp
* @date 16 Oct 2024
* @brief Layer Integration Test
* @see https://github.com/nnstreamer/nntrainer
* @author Donghak Park <[email protected]>
* @bug No known bugs except for NYI items
*/

#include <gtest/gtest.h>

int main(int argc, char **argv) {
int result = -1;

try {
testing::InitGoogleTest(&argc, argv);
} catch (...) {
std::cerr << "Error during InitGoogleTest" << std::endl;
return 0;
}

try {
result = RUN_ALL_TESTS();
} catch (...) {
std::cerr << "Error during RUN_ALL_TESTS()" << std::endl;
}

return result;
}
23 changes: 23 additions & 0 deletions test/unittest/integration_tests/meson.build
@@ -0,0 +1,23 @@
test_name = 'integration_tests'

test_target = [
'integration_tests.cpp',
'integration_test_loss.cpp',
]

exe = executable(
test_name,
test_target,
dependencies: [
nntrainer_test_main_deps,
nntrainer_dep,
nntrainer_ccapi_dep,
],
install: get_option('enable-test'),
install_dir: application_install_dir
)

test(test_name, exe,
args: '--gtest_output=xml:@0@/@1@.xml'.format(meson.build_root(), test_name),
timeout: test_timeout
)
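Assuming a standard meson setup with tests enabled (-Denable-test=true), this suite should then be runnable from the build directory with "meson test integration_tests" or through the usual ninja test target; the exact invocation depends on the local build configuration.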
8 changes: 7 additions & 1 deletion test/unittest/layers/unittest_layers_loss.cpp
@@ -17,6 +17,7 @@
#include <cross_entropy_loss_layer.h>
#include <cross_entropy_sigmoid_loss_layer.h>
#include <cross_entropy_softmax_loss_layer.h>
#include <kld_loss_layer.h>
#include <layers_common_tests.h>
#include <mse_loss_layer.h>

@@ -35,6 +36,10 @@ auto semantic_loss_mse = LayerSemanticsParamType(
nntrainer::MSELossLayer::type, {},
LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);

auto semantic_loss_kld =
LayerSemanticsParamType(nntrainer::createLayer<nntrainer::KLDLossLayer>,
nntrainer::KLDLossLayer::type, {}, 0, false, 1);

auto semantic_loss_constant_derivative = LayerSemanticsParamType(
nntrainer::createLayer<nntrainer::ConstantDerivativeLossLayer>,
nntrainer::ConstantDerivativeLossLayer::type, {},
@@ -48,4 +53,5 @@ GTEST_PARAMETER_TEST(LossCross, LayerSemantics,
::testing::Values(semantic_loss_cross, semantic_loss_mse,
semantic_loss_cross_softmax,
semantic_loss_cross_sigmoid,
semantic_loss_constant_derivative));
semantic_loss_constant_derivative,
semantic_loss_kld));
1 change: 1 addition & 0 deletions test/unittest/meson.build
@@ -87,3 +87,4 @@ subdir('compiler')
subdir('layers')
subdir('datasets')
subdir('models')
subdir('integration_tests')