diff --git a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h index 1e1adfca5f5..c40a1315b40 100644 --- a/compiler/luci/export/src/CircleBuiltinTypesExtractor.h +++ b/compiler/luci/export/src/CircleBuiltinTypesExtractor.h @@ -552,6 +552,10 @@ class BuiltinOptionsExtractor final { return circle::CreateRmsNormOptions(_builder, node->epsilon()).Union(); } + flatbuffers::Offset visit(luci::CircleRoPE *node) + { + return circle::CreateRoPEOptions(_builder, to_circle_rope(node->mode())).Union(); + } protected: flatbuffers::FlatBufferBuilder &_builder; diff --git a/compiler/luci/export/src/CircleExporterUtils.cpp b/compiler/luci/export/src/CircleExporterUtils.cpp index f6e380d7872..2656f2c2256 100644 --- a/compiler/luci/export/src/CircleExporterUtils.cpp +++ b/compiler/luci/export/src/CircleExporterUtils.cpp @@ -95,6 +95,19 @@ circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode) } } +circle::RoPEMode to_circle_rope(luci::RoPEMode mode) +{ + switch (mode) + { + case luci::RoPEMode::GPT_NEOX: + return circle::RoPEMode::RoPEMode_GPT_NEOX; + case luci::RoPEMode::GPT_J: + return circle::RoPEMode::RoPEMode_GPT_J; + default: + INTERNAL_EXN_V("trying to convert unsupported luci::RoPEMode", oops::to_uint32(mode)); + } +} + circle::FullyConnectedOptionsWeightsFormat to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format) { diff --git a/compiler/luci/export/src/CircleExporterUtils.h b/compiler/luci/export/src/CircleExporterUtils.h index 49822d5d775..c0e8a0ddf1d 100644 --- a/compiler/luci/export/src/CircleExporterUtils.h +++ b/compiler/luci/export/src/CircleExporterUtils.h @@ -35,6 +35,7 @@ namespace luci circle::ActivationFunctionType to_circle_actfunc(luci::FusedActFunc func); circle::TensorType to_circle_tensortype(loco::DataType type); circle::MirrorPadMode to_circle_mirrorpadmode(luci::MirrorPadMode mode); +circle::RoPEMode to_circle_rope(luci::RoPEMode mode); circle::FullyConnectedOptionsWeightsFormat to_circle_weightsformat(luci::CircleFullyConnected::WeightsFormat format); circle::DimensionType to_circle_dimensiontype(luci::DimensionType type); diff --git a/compiler/luci/export/src/CircleOps.lst b/compiler/luci/export/src/CircleOps.lst index 91b079ac91a..6f5787c92f4 100644 --- a/compiler/luci/export/src/CircleOps.lst +++ b/compiler/luci/export/src/CircleOps.lst @@ -142,6 +142,7 @@ CIRCLE_NODE(CircleBCQGather, BuiltinOperator_BCQ_GATHER, BuiltinOptions_BCQGathe CIRCLE_NODE(CircleGRU, BuiltinOperator_GRU, BuiltinOptions_GRUOptions) CIRCLE_NODE(CircleInstanceNorm, BuiltinOperator_INSTANCE_NORM, BuiltinOptions_InstanceNormOptions) CIRCLE_NODE(CircleRmsNorm, BuiltinOperator_RMS_NORM, BuiltinOptions_RmsNormOptions) +CIRCLE_NODE(CircleRoPE, BuiltinOperator_ROPE, BuiltinOptions_RoPEOptions) // Virtual node(s) CIRCLE_VNODE(CircleBidirectionalSequenceLSTMOut) CIRCLE_VNODE(CircleConst)