Skip to content

Commit

Permalink
[luci/lang] support RoPE Operation (#14014)
Browse files Browse the repository at this point in the history
This commit supports for RoPE operation in luci IR

ONE-DCO-1.0-Signed-off-by: youngsik kim <[email protected]>
  • Loading branch information
ys44kim authored Sep 24, 2024
1 parent 646c7f1 commit 520ef3c
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 0 deletions.
33 changes: 33 additions & 0 deletions compiler/luci/lang/include/luci/IR/AttrRoPEMode.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IR_ATTR_ROPE_MODE_H__
#define __LUCI_IR_ATTR_ROPE_MODE_H__

namespace luci
{

enum class RoPEMode
{
UNDEFINED, // This is not defined by Circle. This was added to prevent programming error.

GPT_NEOX,
GPT_J,
};

} // namespace luci

#endif // __LUCI_IR_ATTR_ROPE_MODE_H__
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
#include "Nodes/CircleGRU.h"
#include "Nodes/CircleInstanceNorm.h"
#include "Nodes/CircleRmsNorm.h"
#include "Nodes/CircleRoPE.h"
// Virtual nodes
#include "Nodes/CircleConst.h"
#include "Nodes/CircleInput.h"
Expand Down
1 change: 1 addition & 0 deletions compiler/luci/lang/include/luci/IR/CircleNodes.lst
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ CIRCLE_NODE(BCQ_FULLY_CONNECTED, CircleBCQFullyConnected)
CIRCLE_NODE(BCQ_GATHER, CircleBCQGather)
CIRCLE_NODE(INSTANCE_NORM, CircleInstanceNorm)
CIRCLE_NODE(RMS_NORM, CircleRmsNorm)
CIRCLE_NODE(ROPE, CircleRoPE)
// Virtual node(s)
CIRCLE_VNODE(CIRCLECONST, CircleConst)
CIRCLE_VNODE(CIRCLEINPUT, CircleInput)
Expand Down
55 changes: 55 additions & 0 deletions compiler/luci/lang/include/luci/IR/Nodes/CircleRoPE.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_IR_CIRCLEROPE_H__
#define __LUCI_IR_CIRCLEROPE_H__

#include "luci/IR/CircleNodeDecl.h"
#include "luci/IR/CircleOpcode.h"

#include "luci/IR/CircleNodeMixins.h"
#include "luci/IR/AttrRoPEMode.h"

namespace luci
{

/**
* @brief ROPE in Circle
*/
class CircleRoPE final : public FixedArityNode<3, CircleNodeImpl<CircleOpcode::ROPE>>
{
public:
/// @note Currently only support FLOAT32 as input node
loco::Node *input(void) const { return at(0)->node(); }
void input(loco::Node *node) { at(0)->node(node); }

loco::Node *sin_table(void) const { return at(1)->node(); }
void sin_table(loco::Node *node) { at(1)->node(node); }

loco::Node *cos_table(void) const { return at(2)->node(); }
void cos_table(loco::Node *node) { at(2)->node(node); }

public:
RoPEMode mode() const { return _mode; }
void mode(RoPEMode mode) { _mode = mode; }

private:
RoPEMode _mode{RoPEMode::GPT_NEOX};
};

} // namespace luci

#endif // __LUCI_IR_CIRCLEROPE_H__
91 changes: 91 additions & 0 deletions compiler/luci/lang/src/Nodes/CircleRoPE.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "luci/IR/Nodes/CircleRoPE.h"

#include "luci/IR/CircleDialect.h"
#include "luci/IR/CircleNodeVisitor.h"

#include <gtest/gtest.h>

TEST(CircleRoPETest, constructor)
{
luci::CircleRoPE rope;

ASSERT_EQ(luci::CircleDialect::get(), rope.dialect());
ASSERT_EQ(luci::CircleOpcode::ROPE, rope.opcode());

ASSERT_EQ(nullptr, rope.input());
ASSERT_EQ(nullptr, rope.sin_table());
ASSERT_EQ(nullptr, rope.cos_table());

ASSERT_EQ(luci::RoPEMode::GPT_NEOX, rope.mode());
}

TEST(CircleRoPETest, input_NEG)
{
luci::CircleRoPE rope;
luci::CircleRoPE node;

rope.input(&node);
rope.sin_table(&node);
rope.cos_table(&node);
ASSERT_NE(nullptr, rope.input());
ASSERT_NE(nullptr, rope.sin_table());
ASSERT_NE(nullptr, rope.cos_table());

rope.input(nullptr);
rope.sin_table(nullptr);
rope.cos_table(nullptr);
ASSERT_EQ(nullptr, rope.input());
ASSERT_EQ(nullptr, rope.sin_table());
ASSERT_EQ(nullptr, rope.cos_table());

rope.mode(luci::RoPEMode::GPT_J);
ASSERT_NE(luci::RoPEMode::GPT_NEOX, rope.mode());
}

TEST(CircleRoPETest, arity_NEG)
{
luci::CircleRoPE rope;

ASSERT_NO_THROW(rope.arg(2));
ASSERT_THROW(rope.arg(3), std::out_of_range);
}

TEST(CircleRoPETest, visit_mutable_NEG)
{
struct TestVisitor final : public luci::CircleNodeMutableVisitor<void>
{
};

luci::CircleRoPE rope;

TestVisitor tv;
ASSERT_THROW(rope.accept(&tv), std::exception);
}

TEST(CircleRoPETest, visit_NEG)
{
struct TestVisitor final : public luci::CircleNodeVisitor<void>
{
};

luci::CircleRoPE rope;

TestVisitor tv;
ASSERT_THROW(rope.accept(&tv), std::exception);
}

0 comments on commit 520ef3c

Please sign in to comment.