[torch] Add support for aten.selu #2640

Merged 1 commit on Dec 14, 2023
45 changes: 45 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -346,6 +346,51 @@ def Torch_AtenLog_Op : Torch_Op<"aten.log_", [
  }];
}

def Torch_AtenSeluOp : Torch_Op<"aten.selu", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::selu : (Tensor) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$self
  );
  let results = (outs
    AnyTorchTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSeluOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenSeluOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenSelu_Op : Torch_Op<"aten.selu_", [
    IsTrailingUnderscoreInplaceVariant,
    AllowsTypeRefinement
  ]> {
  let summary = "Generated op for `aten::selu_ : (Tensor) -> (Tensor)`";
  let arguments = (ins
    Torch_NonValueTensorType:$self
  );
  let results = (outs
    Torch_NonValueTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSelu_Op::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 1, 1);
    }
    void AtenSelu_Op::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 1, 1);
    }
  }];
}

def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [
    AllowsTypeRefinement,
    HasValueSemantics,
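To see the new op end to end, one can push a one-op module through the pt1 Python frontend. This is a sketch, not part of the PR: it assumes the `torch_mlir.compile` entry point that lived under `projects/pt1` at the time, and that string output types are accepted. With `output_type="torch"` the backend-contract pipeline already applies the decomposition added later in this PR, so the printed IR shows the max/min/exp expansion; `output_type="raw"` would show the graph before any simplification.

```python
import torch
import torch_mlir  # pt1 frontend (projects/pt1/python), an assumption


class Selu(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.selu(x)


# "torch" output runs the torch-backend-contract pipeline, which includes
# DecomposeComplexOps; the SELU decomposition from this PR fires here.
module = torch_mlir.compile(Selu(), torch.randn(2, 3), output_type="torch")
print(module)
```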
26 changes: 26 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6746,6 +6746,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>, %arg3: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -10434,6 +10438,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.selu\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n"
" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
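A note for readers decoding the generated dtype function: this C++ string is regenerated from the Python `aten〇selu〡dtype` function that appears later in this diff (via `build_tools/update_abstract_interp_lib.sh`), and `%int11` is `torch.bool` in PyTorch's `ScalarType` enum (Bool = 11). The two `torch.prim.If` guards therefore mirror `assert self_dtype != torch.bool` and `assert not is_integer_dtype(self_dtype)`. A small eager-mode sketch of the same rule (assuming eager PyTorch rejects bool and integer inputs with a RuntimeError, which is what these assertions encode):

```python
import torch

for dtype in (torch.bool, torch.int64):
    try:
        torch.ops.aten.selu(torch.zeros(2, 3, dtype=dtype))
    except RuntimeError as err:
        print(f"{dtype}: rejected ({err})")

# For floating-point input the dtype function simply returns self_dtype.
print(torch.ops.aten.selu(torch.zeros(2, 3)).dtype)  # torch.float32
```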
50 changes: 50 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1937,6 +1937,55 @@ class DecomposeAtenEluOp : public OpRewritePattern<AtenEluOp> {
};
} // namespace

// SELU = scale * (max(0, x) + min(0, alpha * (exp(x) - 1)))
namespace {
class DecomposeAtenSeluOp : public OpRewritePattern<AtenSeluOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSeluOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value input = op.getSelf();
    auto resType = op.getType().cast<BaseTensorType>();
    if (!resType.hasDtype()) {
      return rewriter.notifyMatchFailure(op, "result should have dtype");
    }

    // Define λ and α
    double scale = 1.0507009873554804934193349852946;
    double alpha = 1.6732632423543772848170429916717;

    // Create constants for λ and α
    Value scaleVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(scale));
    Value alphaVal = rewriter.create<Torch::ConstantFloatOp>(loc, rewriter.getF64FloatAttr(alpha));

    // Create zero tensor for comparison
    Value constantZero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero);

    // Calculate positive and negative parts
    Value constantOne =
        rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
    Value positiveOutput = rewriter.create<AtenMaximumOp>(loc, resType, zeroTensor, input);
    Value minZeroX =
        rewriter.create<AtenMinimumOp>(loc, resType, zeroTensor, input);
    Value expInput = rewriter.create<AtenExpOp>(loc, resType, minZeroX);
    Value expInputMinusOne = rewriter.create<AtenSubScalarOp>(loc, resType, expInput, constantOne, constantOne);
    Value negativeOutput = rewriter.create<AtenMulScalarOp>(loc, resType, expInputMinusOne, alphaVal);
Review comment (Collaborator):
nit: these lines look a bit long. Make sure to format changes using git clang-format. See:

4. Use `git clang-format HEAD~1` to automatically format your commit.


    // Multiply the result by λ
    Value seluOutput = rewriter.create<AtenAddTensorOp>(
        loc, resType, positiveOutput, negativeOutput, constantOne);
    seluOutput = rewriter.create<AtenMulScalarOp>(loc, resType, seluOutput, scaleVal);

    // Replace the original operation
    rewriter.replaceOp(op, seluOutput);
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenTOp : public OpRewritePattern<AtenTOp> {
public:
@@ -6460,6 +6509,7 @@ class DecomposeComplexOpsPass
    addPatternIfTargetOpIsIllegal<DecomposeAtenRandnLikeOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSeluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenLeakyReluBackwardOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
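One detail worth calling out: the pattern evaluates `exp` on `min(0, x)` rather than on `x`, which avoids overflow for large positive inputs and is equivalent to the formula in the comment (for x ≥ 0 the negative term is exactly zero). A quick numerical check of the rewrite in plain PyTorch (a sketch, not part of the PR):

```python
import torch

SCALE = 1.0507009873554804934193349852946
ALPHA = 1.6732632423543772848170429916717

x = torch.linspace(-10.0, 10.0, steps=101)
zero = torch.zeros_like(x)

# Mirrors the decomposition: scale * (max(0, x) + alpha * (exp(min(0, x)) - 1))
decomposed = SCALE * (torch.maximum(zero, x) +
                      ALPHA * (torch.exp(torch.minimum(zero, x)) - 1.0))

assert torch.allclose(decomposed, torch.nn.functional.selu(x), atol=1e-6)
print("decomposition matches torch.nn.functional.selu")
```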
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
  target.addIllegalOp<AtenRelu6Op>();
  target.addIllegalOp<AtenEluOp>();
  target.addIllegalOp<AtenGluOp>();
  target.addIllegalOp<AtenSeluOp>();
  target.addIllegalOp<AtenHardswishOp>();
  target.addIllegalOp<AtenSoftplusOp>();
  target.addIllegalOp<AtenSiluOp>();
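Marking `AtenSeluOp` illegal here is what forces the decomposition to run before the backend contract is checked. A backend that lowers SELU natively could in principle opt out through the backend-legal-ops mechanism; the sketch below assumes the pt1 `torch_mlir.compile` API's `backend_legal_ops` parameter (only meaningful with the "torch" output type), and the exact op-name spelling is an assumption:

```python
import torch
import torch_mlir


class Selu(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.selu(x)


# Hypothetical: keep aten.selu intact for a backend that lowers it natively.
module = torch_mlir.compile(Selu(), torch.randn(2, 3),
                            output_type="torch",
                            backend_legal_ops=["aten.selu"])
print(module)  # torch.aten.selu should survive if the knob applies as assumed
```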
2 changes: 2 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -486,6 +486,7 @@
"ElementwiseLeakyReluModule_basic",
"ElementwiseEluModule_basic",
"ElementwiseEluNonDefaultModule_basic",
"ElementwiseSeluModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
@@ -1115,6 +1116,7 @@
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSeluModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSignModule_basic",
"ElementwiseSqrtIntModule_basic",
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -373,6 +373,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu
def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇selu〡shape(self: List[int]) -> List[int]:
    return upstream_shape_functions.unary(self)

def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]:
    return upstream_shape_functions.unary(index)

@@ -3066,6 +3069,14 @@ def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float
    assert not is_integer_dtype(self_dtype)
    return self_dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64}))
def aten〇selu〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
    self_rank, self_dtype = self_rank_dtype
    assert self_dtype != torch.bool
    assert not is_integer_dtype(self_dtype)
    return self_dtype

@check_dtype_function(
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) +
    _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0))
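The shape side is the trivial elementwise case: `upstream_shape_functions.unary` simply echoes the input shape, for static and dynamic dims alike. A sketch (assuming `torch.jit._shape_functions` is importable, which is the call target named in the generated library above):

```python
from torch.jit._shape_functions import unary

print(unary([5, 3]))    # [5, 3] -- SELU preserves the input shape
print(unary([-1, -1]))  # dynamic dims pass through unchanged
```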
projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -262,6 +262,7 @@ def emit_with_mutating_variants(key, **kwargs):
"aten::relu6 : (Tensor) -> (Tensor)",
"aten::leaky_relu : (Tensor, Scalar) -> (Tensor)",
"aten::log : (Tensor) -> (Tensor)",
"aten::selu : (Tensor) -> (Tensor)",
"aten::sigmoid : (Tensor) -> (Tensor)",
"aten::sign : (Tensor) -> (Tensor)",
"aten::sgn : (Tensor) -> (Tensor)",
21 changes: 21 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -564,6 +564,27 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseSeluModule(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x):
        return torch.ops.aten.selu(x)

@register_test_case(module_factory=lambda: ElementwiseSeluModule())
def ElementwiseSeluModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 3, low=-1, high=1))
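The test draws inputs uniformly from [-1, 1], so both SELU branches (λ·x for x ≥ 0, λ·α·(exp(x) - 1) for x < 0) are exercised. A quick eager smoke test outside the harness (a sketch; it assumes `ElementwiseSeluModule` from above is in scope):

```python
import torch

m = ElementwiseSeluModule()
x = torch.empty(5, 3).uniform_(-1.0, 1.0)  # mirrors tu.rand(5, 3, low=-1, high=1)
out = m.forward(x)
print(out.shape, out.dtype)  # torch.Size([5, 3]) torch.float32
```

Within the repo the case presumably runs through the pt1 e2e entry point (something like `python -m e2e_testing.main --filter ElementwiseSeluModule_basic` from `projects/pt1`; the exact invocation is an assumption).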


# ==============================================================================


class ElementwiseSigmoidModule(torch.nn.Module):

    def __init__(self):