From 39fefb4b9c0d4c0690fff96f46060237a102fa17 Mon Sep 17 00:00:00 2001
From: Donghak PARK
Date: Thu, 17 Oct 2024 16:47:28 +0900
Subject: [PATCH] [Unittest & Loss] Update KLD Loss & Fix Unittest

Update the KLD loss function and fix the related unit test.
Reflect review comments.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK
---
 nntrainer/layers/loss/kld_loss_layer.cpp      | 95 ++++++-------------
 nntrainer/layers/loss/kld_loss_layer.h        | 27 +-----
 nntrainer/layers/loss/meson.build             |  3 +-
 nntrainer/models/model_common_properties.cpp  |  3 +-
 ...sentropy.cpp => integration_test_loss.cpp} | 25 +++++
 .../integration_tests/integration_tests.cpp   |  4 +-
 test/unittest/integration_tests/meson.build   |  2 +-
 7 files changed, 66 insertions(+), 93 deletions(-)
 rename test/unittest/integration_tests/{unittest_loss_crossentropy.cpp => integration_test_loss.cpp} (76%)

diff --git a/nntrainer/layers/loss/kld_loss_layer.cpp b/nntrainer/layers/loss/kld_loss_layer.cpp
index 7e8300cfe3..7fc6e95691 100644
--- a/nntrainer/layers/loss/kld_loss_layer.cpp
+++ b/nntrainer/layers/loss/kld_loss_layer.cpp
@@ -17,75 +17,42 @@ #include
 
 namespace nntrainer {
-KLDLossLayer::KLDLossLayer() {}
-
-KLDLossLayer::~KLDLossLayer() {}
-
-void KLDLossLayer::finalize(nntrainer::InitLayerContext &context) {
-  if (context.getNumInputs() != 2) {
-    throw std::invalid_argument("kld loss requires two input");
-  }
-  const auto &input_dims = context.getInputDimensions();
-
-  if (input_dims.front() != input_dims.back()) {
-    throw std::invalid_argument("dimension of mu and log_var is different");
-  }
-
-  auto &input_dim = input_dims.front();
-
-  temp_idx = context.requestTensor(input_dim, "temp");
-  before_sum_idx = context.requestTensor(
-    input_dim, "before_sum", nntrainer::Initializer::NONE, false,
-    nntrainer::TensorLifespan::FORWARD_FUNC_LIFESPAN);
-
-  /// output is a scaler-like tensor
-  context.setOutputDimensions({{input_dim.batch(), 1, 1, 1}});
-}
-
-void KLDLossLayer::setProperty(const std::vector<std::string> &values) {
-  if (values.size()) {
-    throw std::invalid_argument(
-      "kld loss does not take any properties, but values given");
-  }
-}
+static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void KLDLossLayer::forwarding(RunLayerContext &context, bool training) {
-  // -0.5 * sum(1 + log_std - pow(mu, 2) - exp(log_std))
-  auto &mu = context.getInput(0);
-  auto &log_std = context.getInput(1);
-  auto &ret = context.getOutput(0);
-  auto &temp = context.getTensor(temp_idx);
-  auto &before_sum = context.getTensor(before_sum_idx);
-
-  mu.pow(2.0f, temp);                   // 1. temp = mu ^ 2
-  log_std.subtract(temp, before_sum);   // 2. before_sum = log_std - temp
-  log_std.apply(expf, temp);            // 3. temp = exp(log_std) - 1
-  temp.subtract_i(1.0f);
-  before_sum.subtract_i(temp);          // 4. before_sum = before_sum - temp
-  before_sum.sum({1, 2, 3}, ret, -0.5); // 5. sum * 0.5
+  // Result = (P * (P / Q).log()).sum()
+  // KL(P || Q), where P denotes the distribution of the observations in the
+  // dataset and Q denotes the model output.
+
+  nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
+    nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
+    nntrainer::Tensor temp; // temp output
+    /**
+     * 1. Output = label / predicted
+     * 2. Output = output * label
+     * 3. Output = log(output)
+     * 4. Output = sum(output)
+     */
+    label.divide(predicted, temp);
+    temp.multiply_i(label);
+    temp.apply(logf, temp);
+    output.fill(temp.sum({0, 1, 2, 3}));
+  }
 }
 
 void KLDLossLayer::calcDerivative(RunLayerContext &context) {
-  auto &d_incoming = context.getIncomingDerivative(0);
-  auto &mu = context.getInput(0);
-
-  auto &temp = context.getTensor(temp_idx);
-
-  auto &d_mu = context.getOutgoingDerivative(0);
-  auto &d_var = context.getOutgoingDerivative(1);
-
-  // d_mu = d_incoming * mu
-  mu.multiply(d_incoming, d_mu);
-
-  // temp is exp(log_std) - 1;
-  // d_var = d_incoming * (-0.5) * ( 1 - exp(log_std) )
-  //       = d_incoming * (0.5) * ( temp )
-  temp.multiply(d_incoming.multiply(0.5), d_var);
+  /**
+   * d/dQ = -P/Q
+   */
+  nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX); // Q
+  nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);     // P
+  nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+
+  label.multiply_i(-1.0f);
+  label.divide_i(predicted);
+  deriv.fill(label);
 }
 
-void KLDLossLayer::setBatch(nntrainer::RunLayerContext &context,
-                            unsigned int batch) {
-  context.updateTensor(temp_idx, batch);
-  context.updateTensor(before_sum_idx, batch);
-}
 } // namespace nntrainer
diff --git a/nntrainer/layers/loss/kld_loss_layer.h b/nntrainer/layers/loss/kld_loss_layer.h
index b8c597840f..43d38e2dda 100644
--- a/nntrainer/layers/loss/kld_loss_layer.h
+++ b/nntrainer/layers/loss/kld_loss_layer.h
@@ -25,27 +25,17 @@ namespace nntrainer {
  * @class KLD (Kullback-Leibler Divergence) Loss layer
  * @brief kld loss layer
  */
-class KLDLossLayer final : public LossLayer {
+class KLDLossLayer : public LossLayer {
 public:
   /**
    * @brief Constructor of Constant Loss Layer
   */
-  KLDLossLayer();
+  KLDLossLayer() : LossLayer() {}
 
   /**
    * @brief Destructor of MSE Loss Layer
   */
-  ~KLDLossLayer();
-
-  /**
-   * @copydoc Layer::finalize(InitLayerContext &context)
-   */
-  void finalize(nntrainer::InitLayerContext &context) override;
-
-  /**
-   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
-   */
-  void setProperty(const std::vector<std::string> &values) override;
+  ~KLDLossLayer() = default;
 
   /**
    * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
@@ -62,20 +52,9 @@ class KLDLossLayer final : public LossLayer {
   */
   const std::string getType() const override { return KLDLossLayer::type; }
 
-  /**
-   * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch)
-   */
-  void setBatch(nntrainer::RunLayerContext &context,
-                unsigned int batch) override;
-
   inline static const std::string type = "kld";
-
-private:
-  unsigned before_sum_idx;
-  unsigned temp_idx;
 };
 } // namespace nntrainer
 
 #endif /* __cplusplus */
-
 #endif // __KLD_LOSS_LAYER_H__
diff --git a/nntrainer/layers/loss/meson.build b/nntrainer/layers/loss/meson.build
index 2370b613ad..9678ff045b 100644
--- a/nntrainer/layers/loss/meson.build
+++ b/nntrainer/layers/loss/meson.build
@@ -3,7 +3,8 @@ loss_layer_sources = [
   'mse_loss_layer.cpp',
   'cross_entropy_sigmoid_loss_layer.cpp',
   'cross_entropy_softmax_loss_layer.cpp',
-  'constant_derivative_loss_layer.cpp'
+  'constant_derivative_loss_layer.cpp',
+  'kld_loss_layer.cpp'
 ]
 
 loss_layer_headers = []
diff --git a/nntrainer/models/model_common_properties.cpp b/nntrainer/models/model_common_properties.cpp
index 984cad662a..58aaa94f86 100644
--- a/nntrainer/models/model_common_properties.cpp
+++ b/nntrainer/models/model_common_properties.cpp
@@ -20,7 +20,8 @@ Epochs::Epochs(unsigned int value) { set(value); }
 
 bool LossType::isValid(const std::string &value) const {
   ml_logw("Model loss property is deprecated, use loss layer directly instead");
-  return istrequal(value, "cross") || istrequal(value, "mse");
+  return istrequal(value, "cross") || istrequal(value, "mse") ||
+         istrequal(value, "kld");
 }
 
 TrainingBatchSize::TrainingBatchSize(unsigned int value) { set(value); }
diff --git a/test/unittest/integration_tests/unittest_loss_crossentropy.cpp b/test/unittest/integration_tests/integration_test_loss.cpp
similarity index 76%
rename from test/unittest/integration_tests/unittest_loss_crossentropy.cpp
rename to test/unittest/integration_tests/integration_test_loss.cpp
index 7407d6ac95..ac5571f75d 100644
--- a/test/unittest/integration_tests/unittest_loss_crossentropy.cpp
+++ b/test/unittest/integration_tests/integration_test_loss.cpp
@@ -91,3 +91,28 @@ TEST(crossentropy_loss, model_fail_test) {
   int status = model->compile();
   EXPECT_FALSE(status == ML_ERROR_NONE);
 }
+
+TEST(kld_loss, compile_test) {
+
+  std::unique_ptr<ml::train::Model> model = ml::train::createModel(
+    ml::train::ModelType::NEURAL_NET, {withKey("loss", "kld")});
+
+  std::shared_ptr<ml::train::Layer> input_layer = ml::train::createLayer(
+    "input", {withKey("name", "input0"), withKey("input_shape", "3:32:32")});
+  std::shared_ptr<ml::train::Layer> fc_layer = ml::train::createLayer(
+    "fully_connected",
+    {withKey("unit", 100), withKey("activation", "softmax")});
+
+  model->addLayer(input_layer);
+  model->addLayer(fc_layer);
+
+  model->setProperty({withKey("batch_size", 16), withKey("epochs", 1)});
+
+  auto optimizer = ml::train::createOptimizer("adam", {"learning_rate=0.001"});
+  model->setOptimizer(std::move(optimizer));
+  int status = model->compile();
+  EXPECT_FALSE(status == ML_ERROR_NONE);
+
+  status = model->initialize();
+  EXPECT_FALSE(status == ML_ERROR_NONE);
+}
diff --git a/test/unittest/integration_tests/integration_tests.cpp b/test/unittest/integration_tests/integration_tests.cpp
index ef2ab82aee..d8491f76e8 100644
--- a/test/unittest/integration_tests/integration_tests.cpp
+++ b/test/unittest/integration_tests/integration_tests.cpp
@@ -2,9 +2,9 @@
 /**
  * Copyright (C) 2024 Donghak Park
  *
- * @file unittest_loss_crossentropy.cpp
+ * @file integration_tests.cpp
  * @date 16 Oct 2024
- * @brief CrossEntropy loss Layer Test
+ * @brief Layer Integration Test
  * @see https://github.com/nnstreamer/nntrainer
  * @author Donghak Park
  * @bug No known bugs except for NYI items
diff --git a/test/unittest/integration_tests/meson.build b/test/unittest/integration_tests/meson.build
index cbc6523cad..9981db30c0 100644
--- a/test/unittest/integration_tests/meson.build
+++ b/test/unittest/integration_tests/meson.build
@@ -2,7 +2,7 @@ test_name = 'integration_tests'
 
 test_target = [
   'integration_tests.cpp',
-  'unittest_loss_crossentropy.cpp',
+  'integration_test_loss.cpp',
 ]
 
 exe = executable(
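
The comments added to forwarding() and calcDerivative() in this patch state the loss as KL(P || Q) = sum(P * log(P / Q)) and its gradient with respect to the prediction as dKL/dQ = -P / Q. The snippet below is a minimal standalone sketch of those two formulas only; it does not use nntrainer's Tensor API, and the helper names kl_forward and kl_grad are invented for this illustration.

```cpp
// Standalone sketch: evaluates the formulas quoted in the patch comments,
// KL(P || Q) = sum(P * log(P / Q)) and dKL/dQ = -P / Q, on a tiny example.
#include <cmath>
#include <cstdio>
#include <vector>

// Forward pass: KL divergence of the label distribution P from the prediction Q.
static float kl_forward(const std::vector<float> &p, const std::vector<float> &q) {
  float sum = 0.0f;
  for (size_t i = 0; i < p.size(); ++i)
    sum += p[i] * std::log(p[i] / q[i]);
  return sum;
}

// Backward pass: gradient of the loss with respect to each predicted element Q[i].
static std::vector<float> kl_grad(const std::vector<float> &p, const std::vector<float> &q) {
  std::vector<float> grad(p.size());
  for (size_t i = 0; i < p.size(); ++i)
    grad[i] = -p[i] / q[i];
  return grad;
}

int main() {
  std::vector<float> label = {0.7f, 0.2f, 0.1f};     // P: target distribution
  std::vector<float> predicted = {0.5f, 0.3f, 0.2f}; // Q: softmax output of the model

  std::printf("KL(P || Q) = %f\n", kl_forward(label, predicted));
  for (float g : kl_grad(label, predicted))
    std::printf("dKL/dQ = %f\n", g);
  return 0;
}
```

In the layer itself these quantities are built from Tensor operations (divide, multiply_i, apply with logf, and a sum over all axes), and calcDerivative writes the -P / Q result into the outgoing derivative with fill().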