[LoRA] Apply Inception-LoRA
- updates the LoRA computation (applying Inception-LoRA)
- applies LoRA directly with the LoRA vectors, without building the full LoRA weight matrix (see the sketch after this list)
- revises `forwarding()`
- revises `calcGradient()`
- revises `calcDerivative()`
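
In other words, instead of materializing the low-rank update as a full weight matrix and computing y = x(W + AB), the layer now computes y = xW + (xA)B, so only a rank-sized intermediate is ever formed. Below is a minimal, self-contained sketch of that equivalence; the shapes, helper functions, and values are illustrative assumptions, not nntrainer code:

```cpp
// Sketch of the Inception-LoRA rewrite (illustrative only, not nntrainer code).
// Old path:  y = x * (W + A * B)      -> materializes a full (in x out) matrix A*B.
// New path:  y = x * W + (x * A) * B  -> only a (batch x rank) temporary is needed.
#include <cassert>
#include <cmath>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// Naive dense matmul: (n x k) * (k x m) -> (n x m).
static Mat matmul(const Mat &a, const Mat &b) {
  const size_t n = a.size(), k = a[0].size(), m = b[0].size();
  Mat c(n, std::vector<float>(m, 0.0f));
  for (size_t i = 0; i < n; ++i)
    for (size_t p = 0; p < k; ++p)
      for (size_t j = 0; j < m; ++j)
        c[i][j] += a[i][p] * b[p][j];
  return c;
}

// Element-wise sum of two equally shaped matrices.
static Mat add(const Mat &a, const Mat &b) {
  Mat c = a;
  for (size_t i = 0; i < c.size(); ++i)
    for (size_t j = 0; j < c[i].size(); ++j)
      c[i][j] += b[i][j];
  return c;
}

int main() {
  const size_t batch = 2, in = 4, out = 3, rank = 2; // illustrative sizes
  auto fill = [](Mat &m, float v) {
    for (auto &row : m)
      for (auto &val : row)
        val = (v = 0.9f * v + 0.1f); // arbitrary deterministic values
  };

  Mat x(batch, std::vector<float>(in)), W(in, std::vector<float>(out));
  Mat A(in, std::vector<float>(rank)), B(rank, std::vector<float>(out));
  fill(x, 0.3f);
  fill(W, 0.7f);
  fill(A, 0.2f);
  fill(B, 0.5f);

  // Old path: build the LoRA weight A*B, add it to W, then one big GEMM.
  Mat y_old = matmul(x, add(W, matmul(A, B)));

  // New (Inception-LoRA) path: two small GEMMs, no (in x out) temporary.
  Mat y_new = add(matmul(x, W), matmul(matmul(x, A), B));

  for (size_t i = 0; i < batch; ++i)
    for (size_t j = 0; j < out; ++j)
      assert(std::fabs(y_old[i][j] - y_new[i][j]) < 1e-4f);
  return 0;
}
```

The memory saving comes from the temporary: the combined path needs an (in x out) buffer for A*B, while the factored path only needs a (batch x rank) buffer for x*A.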

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang committed Mar 19, 2024
1 parent f72b0f2 commit 10aae55
Showing 1 changed file with 28 additions and 18 deletions.
nntrainer/layers/fc_layer.cpp (46 changes: 28 additions & 18 deletions)
@@ -37,7 +37,7 @@ namespace nntrainer {
static constexpr size_t SINGLE_INOUT_IDX = 0;

enum FCParams { weight, bias };
- enum LORAParams { loraA, loraB, loraW, loraOut };
+ enum LORAParams { loraA, loraB, loraTmp, loraOut };

FullyConnectedLayer::FullyConnectedLayer() :
LayerImpl(), fc_props(props::Unit(), props::LoraRank()) {
@@ -123,6 +123,12 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
is_nchw ? 0b0011 : 0b0101);

+ /** loraTmp: (1, lora_rank) */
+ TensorDim loraTmp_dim(
+ 1, is_nchw ? 1 : lora_rank, 1, is_nchw ? lora_rank : 1,
+ TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
+ is_nchw ? 0b0001 : 0b0100);
+
lora_idx[LORAParams::loraA] = context.requestWeight(
loraA_dim, weight_initializer, weight_regularizer,
weight_regularizer_constant, weight_decay, "loraA", true);
@@ -131,8 +137,8 @@
loraB_dim, weight_initializer, weight_regularizer,
weight_regularizer_constant, weight_decay, "loraB", true);

- lora_idx[LORAParams::loraW] = context.requestTensor(
- weight_dim, "weight_lora", Tensor::Initializer::NONE, true,
+ lora_idx[LORAParams::loraTmp] = context.requestTensor(
+ loraTmp_dim, "hidden_tmp_lora", Tensor::Initializer::NONE, true,
TensorLifespan::FORWARD_DERIV_LIFESPAN);

lora_idx[LORAParams::loraOut] =
@@ -178,11 +184,12 @@ void FullyConnectedLayer::forwarding(RunLayerContext &context, bool training) {
if (!std::get<props::LoraRank>(fc_props).empty()) {
Tensor &loraA = context.getWeight(lora_idx[LORAParams::loraA]);
Tensor &loraB = context.getWeight(lora_idx[LORAParams::loraB]);
- Tensor &weight_lora = context.getTensor(lora_idx[LORAParams::loraW]);
- Tensor &hidden_lora = context.getTensor(lora_idx[LORAParams::loraOut]);
- loraA.dot(loraB, weight_lora);
- input_.dot(weight_lora, hidden_lora, false, false);
- hidden_.add_i(hidden_lora);
+ Tensor &hidden_tmp_lora = context.getTensor(lora_idx[LORAParams::loraTmp]);
+ Tensor &hidden_out_lora = context.getTensor(lora_idx[LORAParams::loraOut]);
+
+ input_.dot(loraA, hidden_tmp_lora, false, false);
+ hidden_tmp_lora.dot(loraB, hidden_out_lora, false, false);
+ hidden_.add_i(hidden_out_lora);
}

if (auto &disable_bias = std::get<props::DisableBias>(*layer_impl_props);
@@ -240,8 +247,10 @@ void FullyConnectedLayer::calcDerivative(RunLayerContext &context) {
Tensor &ret_ = context.getOutgoingDerivative(SINGLE_INOUT_IDX);

if (!std::get<props::LoraRank>(fc_props).empty()) {
- Tensor &weight_lora = context.getTensor(lora_idx[LORAParams::loraW]);
- ret_.dot_deriv_wrt_1(weight.add(weight_lora), derivative_, false, false);
+ Tensor &lora_A = context.getWeight(lora_idx[LORAParams::loraA]);
+ Tensor &lora_B = context.getWeight(lora_idx[LORAParams::loraB]);
+ ret_.dot_deriv_wrt_1(weight.add(lora_A.dot(lora_B)), derivative_, false,
+ false);
} else {
ret_.dot_deriv_wrt_1(weight, derivative_, false, false);
}
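
A note on the `calcDerivative()` hunk above: unlike the forward pass, the outgoing-derivative path still forms the full product `lora_A.dot(lora_B)`. Writing the LoRA output as y = x(W + AB) (my notation, not from the source), the input gradient is

```math
\frac{\partial \mathcal{L}}{\partial x}
  = \frac{\partial \mathcal{L}}{\partial y} (W + AB)^{\top}
  = \frac{\partial \mathcal{L}}{\partial y} W^{\top}
    + \left(\frac{\partial \mathcal{L}}{\partial y} B^{\top}\right) A^{\top}
```

so in principle the product could be avoided here too, by multiplying by B-transpose and then A-transpose; the diff keeps the simpler combined form for this path.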
@@ -276,21 +285,22 @@ void FullyConnectedLayer::calcGradient(RunLayerContext &context) {
/** (lora) calcGradient - compute gradients of LoRA params only */
Tensor &djdla = context.getWeightGrad(lora_idx[LORAParams::loraA]);
Tensor &djdlb = context.getWeightGrad(lora_idx[LORAParams::loraB]);
- Tensor &djdlora_w = context.getTensorGrad(lora_idx[LORAParams::loraW]);
+ Tensor &djdtmp = context.getTensorGrad(lora_idx[LORAParams::loraTmp]);

const Tensor &derivative_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
Tensor &loraA = context.getWeight(lora_idx[LORAParams::loraA]);
Tensor &loraB = context.getWeight(lora_idx[LORAParams::loraB]);
+ Tensor &loraTmp = context.getTensor(lora_idx[LORAParams::loraTmp]);

- input_.dot_deriv_wrt_2(
- djdlora_w, derivative_, false, false,
- !context.isGradientFirstAccess(lora_idx[LORAParams::loraW]));
- loraA.dot_deriv_wrt_2(
- djdlb, djdlora_w, false, false,
+ loraTmp.dot_deriv_wrt_2(
+ djdlb, derivative_, false, false,
!context.isGradientFirstAccess(lora_idx[LORAParams::loraB]));
+ djdtmp.dot_deriv_wrt_1(
+ loraB, derivative_, false, false,
+ !context.isGradientFirstAccess(lora_idx[LORAParams::loraB]));
- djdla.dot_deriv_wrt_1(
- loraB, djdlora_w, false, false,
+ input_.dot_deriv_wrt_2(
+ djdla, djdtmp, false, false,
!context.isGradientFirstAccess(lora_idx[LORAParams::loraA]));
}
}
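
For reference, the chain rule behind the three gradient calls in the hunk above, written for the factored form where h = xA is the `loraTmp` tensor (notation mine, not from the source):

```math
h = xA, \quad y_{\mathrm{lora}} = hB
\;\;\Longrightarrow\;\;
\frac{\partial \mathcal{L}}{\partial B} = h^{\top} \frac{\partial \mathcal{L}}{\partial y_{\mathrm{lora}}}, \quad
\frac{\partial \mathcal{L}}{\partial h} = \frac{\partial \mathcal{L}}{\partial y_{\mathrm{lora}}} B^{\top}, \quad
\frac{\partial \mathcal{L}}{\partial A} = x^{\top} \frac{\partial \mathcal{L}}{\partial h}
```

These correspond, in order, to the `djdlb`, `djdtmp`, and `djdla` computations in the new code.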
