Skip to content

Commit

Permalink
[compiler] add params for rope mode
Browse files Browse the repository at this point in the history
add param to support two modes(neox, gpt) of rope

ONE-DCO-1.0-Signed-off-by: youngsik kim <[email protected]>
  • Loading branch information
ys44kim committed Sep 12, 2024
1 parent 19d7e7e commit 8dbd1df
Show file tree
Hide file tree
Showing 26 changed files with 217 additions and 37 deletions.
4 changes: 4 additions & 0 deletions compiler/circlechef/circle/src/Op/RoPE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,13 @@ circlechef::Operation *CircleOpRoPE::build(const circle::Operator *op, CircleImp

operation->set_type("RoPE");

auto op_options = operation->mutable_rope_options();

auto op_params = op->builtin_options_as_RoPEOptions();
assert(op_params != nullptr);

op_options->set_mode(op_params->mode());

return operation;
}

Expand Down
2 changes: 2 additions & 0 deletions compiler/circlechef/core/src/Op/RoPE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@
flatbuffers::Offset<void> RoPEChef::value(flatbuffers::FlatBufferBuilder &fbb) const
{
auto &operation = (*_operation);
assert(operation.has_rope_options());

circle::RoPEOptionsBuilder options_builder{fbb};
options_builder.add_mode(static_cast<circle::RoPEMode>(operation.rope_options().mode()));

return options_builder.Finish().Union();
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/circlechef/proto/circlechef.proto
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ message BCQGatherOptions {
}

message RoPEOptions {
optional int32 mode = 1 [default = 0];;
}

message Operation {
Expand All @@ -115,6 +116,7 @@ message Operation {
optional BCQGatherOptions bcq_gather_options = 103;
optional GRUOptions gru_options = 104;
optional FullyConnectedOptions fullyconnected_options = 105;
optional RoPEOptions rope_options = 106;
}

// For additional subgraphs
Expand Down
15 changes: 15 additions & 0 deletions compiler/circledump/src/OpPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,20 @@ class InstanceNormPrinter : public OpPrinter
}
};

class RoPEPrinter : public OpPrinter
{
public:
void options(const circle::Operator *op, std::ostream &os) const override
{
if (auto *params = op->builtin_options_as_RoPEOptions())
{
os << " ";
os << "mode(" << EnumNameRoPEMode(params->mode()) << ") ";
os << std::endl;
}
}
};

OpPrinterRegistry::OpPrinterRegistry()
{
_op_map[circle::BuiltinOperator_ADD] = make_unique<AddPrinter>();
Expand Down Expand Up @@ -912,6 +926,7 @@ OpPrinterRegistry::OpPrinterRegistry()
_op_map[circle::BuiltinOperator_BCQ_GATHER] = make_unique<BCQGatherPrinter>();
_op_map[circle::BuiltinOperator_GRU] = make_unique<GRUPrinter>();
_op_map[circle::BuiltinOperator_INSTANCE_NORM] = make_unique<InstanceNormPrinter>();
_op_map[circle::BuiltinOperator_ROPE] = make_unique<RoPEPrinter>();
}

} // namespace circledump
7 changes: 7 additions & 0 deletions compiler/luci-interpreter/src/core/KernelParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <luci/IR/AttrPadding.h>
#include <luci/IR/AttrFusedActFunc.h>
#include <luci/IR/AttrMirrorPadMode.h>
#include <luci/IR/AttrRoPEMode.h>
#include <luci_interpreter/core/DataType.h>

#include <cstdint>
Expand All @@ -32,6 +33,7 @@ namespace luci_interpreter
using Activation = luci::FusedActFunc;
using Padding = luci::Padding;
using MirrorPadMode = luci::MirrorPadMode;
using RoPEMode = luci::RoPEMode;

struct AddParams
{
Expand Down Expand Up @@ -115,6 +117,11 @@ struct InstanceNormParams
Activation activation;
};

struct RoPEParams
{
RoPEMode mode;
};

struct L2NormParams
{
Activation activation;
Expand Down
7 changes: 4 additions & 3 deletions compiler/luci-interpreter/src/kernels/RoPE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ namespace kernels
{

RoPE::RoPE(const Tensor *input, const Tensor *sin_table, const Tensor *cos_table,
Tensor *output) : Kernel({input, sin_table, cos_table}, {output})
Tensor *output, const RoPEParams &params)
: KernelWithParams<RoPEParams>({input, sin_table, cos_table}, {output}, params)
{
}

Expand Down Expand Up @@ -59,7 +60,7 @@ void RoPE::evalFloat() const
const float *cos_table_data = getTensorData<float>(cos_table());
float *output_data = getTensorData<float>(output());

if (input_shape.DimensionsCount() == 4)
if (params().mode == RoPEMode::NEOX)
{
const int32_t i0_n = input_shape.Dims(0);
const int32_t multihead_n = input_shape.Dims(1);
Expand All @@ -86,7 +87,7 @@ void RoPE::evalFloat() const
}
}
else
throw std::runtime_error("luci-intp RoPE unsupported rank.");
throw std::runtime_error("luci-intp RoPE unsupported mode.");
}

} // namespace kernels
Expand Down
5 changes: 3 additions & 2 deletions compiler/luci-interpreter/src/kernels/RoPE.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
#define LUCI_INTERPRETER_KERNELS_ROPE_H

#include "core/Kernel.h"
#include "core/KernelParams.h"

namespace luci_interpreter
{
namespace kernels
{

class RoPE : public Kernel
class RoPE : public KernelWithParams<RoPEParams>
{
public:
RoPE(const Tensor *input, const Tensor *sin_table, const Tensor *cos_table,
Tensor *output);
Tensor *output, const RoPEParams &params);

const Tensor *input() const { return _inputs[0]; }
const Tensor *sin_table() const { return _inputs[1]; }
Expand Down
10 changes: 8 additions & 2 deletions compiler/luci-interpreter/src/kernels/RoPE.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ TEST_F(RoPETest, floatTest)
Tensor sin_table = makeInputTensor<DataType::FLOAT32>(sin_shape, sin_data, _memory_manager.get());
Tensor cos_table = makeInputTensor<DataType::FLOAT32>(cos_shape, cos_data, _memory_manager.get());

RoPE kernel(&input_tensor, &sin_table, &cos_table, &output_tensor);
RoPEParams params{};
params.mode = RoPEMode::NEOX;

RoPE kernel(&input_tensor, &sin_table, &cos_table, &output_tensor, params);
kernel.configure();
_memory_manager->allocate_memory(output_tensor);
kernel.execute();
Expand Down Expand Up @@ -85,7 +88,10 @@ TEST_F(RoPETest, Unsupported_dims_NEG)
Tensor sin_table = makeInputTensor<DataType::FLOAT32>(sin_shape, sin_data, _memory_manager.get());
Tensor cos_table = makeInputTensor<DataType::FLOAT32>(cos_shape, cos_data, _memory_manager.get());

RoPE kernel(&input_tensor, &sin_table, &cos_table, &output_tensor);
RoPEParams params{};
params.mode = RoPEMode::NEOX;

RoPE kernel(&input_tensor, &sin_table, &cos_table, &output_tensor, params);
EXPECT_ANY_THROW(kernel.configure());
}

Expand Down
5 changes: 4 additions & 1 deletion compiler/luci-interpreter/src/loader/nodes/RoPE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ std::unique_ptr<Kernel> build_kernel_CircleRoPE(const luci::CircleNode *circle_n
const Tensor *cos_table = helper.getInputTensor(node->cos_table());

Tensor *output = helper.getOutputTensor(node);

RoPEParams params{};
params.mode = node->mode();

return std::make_unique<kernels::RoPE>(input, sin_table, cos_table, output);
return std::make_unique<kernels::RoPE>(input, sin_table, cos_table, output, params);
}

} // namespace luci_interpreter
4 changes: 2 additions & 2 deletions compiler/luci/export/src/CircleBuiltinTypesExtractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -548,9 +548,9 @@ class BuiltinOptionsExtractor final
to_circle_actfunc(node->fusedActivationFunction()))
.Union();
}
flatbuffers::Offset<void> visit(luci::CircleRoPE *)
flatbuffers::Offset<void> visit(luci::CircleRoPE *node)
{
return circle::CreateRoPEOptions(_builder).Union();
return circle::CreateRoPEOptions(_builder, to_circle_rope(node->mode())).Union();
}

protected:
Expand Down
13 changes: 13 additions & 0 deletions compiler/luci/export/src/CircleExporterUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,19 @@ circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode)
}
}

circle::RoPEMode to_circle_rope(luci::RoPEMode mode)
{
switch (mode)
{
case luci::RoPEMode::NEOX:
return circle::RoPEMode::RoPEMode_NEOX;
case luci::RoPEMode::GPT:
return circle::RoPEMode::RoPEMode_GPT;
default:
INTERNAL_EXN_V("trying to convert unsupported luci::RoPEMode", oops::to_uint32(mode));
}
}

circle::FullyConnectedOptionsWeightsFormat
to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format)
{
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/export/src/CircleExporterUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace luci
circle::ActivationFunctionType to_circle_actfunc(luci::FusedActFunc func);
circle::TensorType to_circle_tensortype(loco::DataType type);
circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode);
circle::RoPEMode to_circle_rope(luci::RoPEMode mode);
circle::FullyConnectedOptionsWeightsFormat
to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format);
circle::DimensionType to_circle_dimensiontype(luci::DimensionType type);
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/import/include/luci/Import/CircleReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ loco::DataType luci_datatype(circle::TensorType type);
FusedActFunc luci_actfunc(const circle::ActivationFunctionType type);
Padding luci_padding(const circle::Padding padding);
MirrorPadMode luci_mirrorpad_mode(const circle::MirrorPadMode mode);
RoPEMode luci_rope_mode(const circle::RoPEMode mode);
luci::CircleFullyConnected::WeightsFormat
luci_weights_format(const circle::FullyConnectedOptionsWeightsFormat weights_format);
std::unique_ptr<CircleQuantParam>
Expand Down
33 changes: 33 additions & 0 deletions compiler/luci/lang/include/luci/IR/AttrRoPEMode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IR_ATTR_ROPE_MODE_H__
#define __LUCI_IR_ATTR_ROPE_MODE_H__

namespace luci
{

enum class RoPEMode
{
UNDEFINED, // This is not defined by Circle. This was added to prevent programming error.

NEOX,
GPT,
};

} // namespace luci

#endif // __LUCI_IR_ATTR_ROPE_MODE_H__
8 changes: 8 additions & 0 deletions compiler/luci/lang/include/luci/IR/Nodes/CircleRoPE.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "luci/IR/CircleOpcode.h"

#include "luci/IR/CircleNodeMixins.h"
#include "luci/IR/AttrRoPEMode.h"

namespace luci
{
Expand All @@ -40,6 +41,13 @@ class CircleRoPE final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::R

loco::Node *cos_table(void) const { return at(2)->node(); }
void cos_table(loco::Node *node) { at(2)->node(node); }

public:
RoPEMode mode() const { return _mode; }
void mode(RoPEMode mode) { _mode = mode; }

private:
RoPEMode _mode{RoPEMode::NEOX};
};

} // namespace luci
Expand Down
7 changes: 6 additions & 1 deletion compiler/luci/lang/src/Nodes/CircleRoPE.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ TEST(CircleRoPETest, constructor)

ASSERT_EQ(nullptr, rope.input());
ASSERT_EQ(nullptr, rope.sin_table());
ASSERT_EQ(nullptr, rope.cos_table());
ASSERT_EQ(nullptr, rope.cos_table());

ASSERT_EQ(luci::RoPEMode::NEOX, rope.mode());
}

TEST(CircleRoPETest, input_NEG)
Expand All @@ -51,6 +53,9 @@ TEST(CircleRoPETest, input_NEG)
ASSERT_EQ(nullptr, rope.input());
ASSERT_EQ(nullptr, rope.sin_table());
ASSERT_EQ(nullptr, rope.cos_table());

rope.mode(luci::RoPEMode::GPT);
ASSERT_NE(luci::RoPEMode::NEOX, rope.mode());
}

TEST(CircleRoPETest, arity_NEG)
Expand Down
14 changes: 14 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilder.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,20 @@ TEST_F(CircleNodeSummaryBuilderTest, MirrorPad_validate_mirror_padding_NEG)
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, RoPE_validate)
{
luci::CircleRoPE node;
node.mode(luci::RoPEMode::NEOX);
EXPECT_TRUE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, RoPE_validate_NEG)
{
luci::CircleRoPE node;
node.mode(luci::RoPEMode::UNDEFINED);
EXPECT_FALSE(mock_build(&node));
}

TEST_F(CircleNodeSummaryBuilderTest, Mul_validate)
{
luci::CircleMul node;
Expand Down
22 changes: 22 additions & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,19 @@ std::string to_str(luci::MirrorPadMode mode)
}
}

std::string to_str(luci::RoPEMode mode)
{
switch (mode)
{
case luci::RoPEMode::NEOX:
return "NEOX";
case luci::RoPEMode::GPT:
return "GPT";
default:
return "Error";
}
}

} // namespace

namespace luci
Expand Down Expand Up @@ -1195,6 +1208,15 @@ std::vector<std::string> CircleWhileOutSummaryBuilder::get_input_names(const luc
return {"while"};
}

bool CircleRoPESummaryBuilder::validate(const luci::CircleNode *node)
{
auto rope = loco::must_cast<const luci::CircleRoPE *>(node);
if (rope->mode() == luci::RoPEMode::UNDEFINED)
return false;

return true;
}

std::vector<std::string> CircleRoPESummaryBuilder::get_input_names(const luci::CircleNode *)
{
return {"input", "sin_table", "cos_table"};
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/logex/src/CircleNodeSummaryBuilders.h
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,7 @@ class CircleWhileOutSummaryBuilder final : public CircleNodeSummaryBuilder
class CircleRoPESummaryBuilder final : public CircleNodeSummaryBuilder
{
private:
bool validate(const luci::CircleNode *node);
std::vector<std::string> get_input_names(const luci::CircleNode *);
};

Expand Down
8 changes: 7 additions & 1 deletion compiler/luci/partition/src/Nodes/CircleRoPE.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ class NodeGraphlet : public NodeGraphletT<luci::CircleRoPE>
NodeGraphlet() = default;

public:
void init(loco::Graph *g) override { NodeGraphletT<luci::CircleRoPE>::init(g); }
void init(loco::Graph *g) override
{
NodeGraphletT<luci::CircleRoPE>::init(g);

_node->mode(luci::RoPEMode::NEOX);
}
};

class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
Expand All @@ -51,6 +56,7 @@ class TestNodeGraph : public TestIsOGraph<3>, public NodeGraphlet
node()->sin_table(input(1));
node()->cos_table(input(2));


output()->from(node());
}
};
Expand Down
Loading

0 comments on commit 8dbd1df

Please sign in to comment.