[MLIR][TORCH] Add TorchToTosa lowering for aten.where.self op
AmosLewis authored and AmosLewis committed Oct 4, 2022
1 parent 7da1af8 commit e89a71f
Showing 4 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
@@ -497,6 +497,7 @@
"ElementwiseWhereScalarOtherModule_basic",
"ElementwiseWhereScalarSelfModule_basic",
"ElementwiseWhereSelfModule_basic",
"ElementwiseAtenWhereSelfModule_basic",
"EmptyLikeMemoryFormatModule_basic",
"EmptyLikeModule_defaultDtype",
"EmptyLikeModule_falsePinMemory",
25 changes: 25 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3005,6 +3005,30 @@ LogicalResult ConvertAtenOp<AtenBroadcastToOp>::matchAndRewrite(
"unimplemented: broadcasts other than same rank or zero ranked tensor.");
}


template <>
LogicalResult ConvertAtenOp<AtenWhereSelfOp>::matchAndRewrite(
AtenWhereSelfOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {

  // The self and condition operands must be tensor types.
  auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor type inputs are currently supported");
  auto condType = adaptor.condition().getType().dyn_cast<TensorType>();
  if (!condType)
    return rewriter.notifyMatchFailure(
        op, "Only tensor type conditions are currently supported");

auto outType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, adaptor.condition(),
adaptor.self(), adaptor.other());

return success();
}


template <>
LogicalResult ConvertAtenOp<AtenArangeStartStepOp>::matchAndRewrite(
AtenArangeStartStepOp op, OpAdaptor adaptor,
@@ -3703,6 +3727,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenMaxDimOp);
INSERT_ATENOP_PATTERN(AtenSliceTensorOp);
INSERT_ATENOP_PATTERN(AtenBroadcastToOp);
INSERT_ATENOP_PATTERN(AtenWhereSelfOp);
INSERT_ATENOP_PATTERN(AtenArangeStartStepOp);
#undef INSERT_ATENOP_PATTERN

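For context, a minimal sketch (not part of the commit; it assumes only a local PyTorch install) of the semantics the new pattern relies on: aten.where.self and tosa.select take their operands in the same (condition, true_value, false_value) order and both select elementwise with broadcast, which is why the lowering can forward the adaptor operands to tosa::SelectOp unchanged.

import torch

# Sketch of the elementwise-select-with-broadcast behavior that the
# AtenWhereSelfOp -> tosa.select lowering forwards unchanged:
# where(condition, self, other) picks from `self` where the condition
# is True and from `other` where it is False, after broadcasting.
cond = torch.tensor([[True, False], [False, True]])
self_t = torch.full((2, 2), 1.0)
other = torch.tensor(0.0)  # 0-d tensor, broadcast over the result

out = torch.ops.aten.where(cond, self_t, other)
expected = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
assert torch.equal(out, expected)
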
24 changes: 24 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -221,6 +221,30 @@ llvm::Optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
return const_op.getResult();
}

// Template specialization for bool
template <>
llvm::Optional<Value> getConstTensor<bool>(PatternRewriter &rewriter,
Operation *op, ArrayRef<bool> vec,
ArrayRef<int64_t> shape) {
uint64_t num_total_elements = 1;
for (int64_t a : shape) {
num_total_elements *= a;
}

if (vec.size() != num_total_elements) {
op->emitOpError("getConstTensor(): number of elements mismatch.");
return llvm::None;
}

auto const_type =
RankedTensorType::get(shape, rewriter.getI1Type());
auto const_attr = DenseElementsAttr::get(const_type, vec);

auto const_op =
rewriter.create<tosa::ConstOp>(op->getLoc(), const_type, const_attr);
return const_op.getResult();
}

// Template instantiation
template llvm::Optional<Value> getConstTensor<int32_t>(PatternRewriter &,
Operation *,
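The new bool specialization mirrors the existing int32_t/float overloads of getConstTensor: it validates that the flat value list matches the element count implied by the requested shape, then materializes a tosa::ConstOp with i1 element type. A rough Python/NumPy analogue of that check and materialization (illustrative only; not the MLIR API, and the helper name const_bool_tensor is made up):

import numpy as np

def const_bool_tensor(vec, shape):
    # Illustrative analogue of getConstTensor<bool>: check that the
    # number of provided elements matches the requested shape, then
    # build a boolean constant of that shape.
    num_total_elements = 1
    for dim in shape:
        num_total_elements *= dim
    if len(vec) != num_total_elements:
        raise ValueError("getConstTensor(): number of elements mismatch.")
    return np.array(vec, dtype=bool).reshape(shape)

mask = const_bool_tensor([True, False, False, True], (2, 2))
assert mask.shape == (2, 2)
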
24 changes: 24 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -134,6 +134,30 @@ def ElementwiseTernaryModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseAtenWhereSelfModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([1, 1, 5, 5], torch.bool, True),
([1, 12, 5, 5], torch.float32, True),
([], torch.float32, True),
])
def forward(self, a, b, c):
return torch.ops.aten.where(a, b, c)


@register_test_case(module_factory=lambda: ElementwiseAtenWhereSelfModule())
def ElementwiseAtenWhereSelfModule_basic(module, tu: TestUtils):
module.forward(tu.zeros(1, 1, 5, 5, dtype=torch.bool), tu.rand(1, 12, 5, 5), tu.rand(()))
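
A quick sanity check of the shapes exercised above, runnable with plain PyTorch outside the e2e harness (not part of the commit): the all-False (1, 1, 5, 5) bool condition, the (1, 12, 5, 5) float tensor, and the 0-d float tensor broadcast to a (1, 12, 5, 5) result that equals the scalar everywhere.

import torch

# Shapes from ElementwiseAtenWhereSelfModule_basic, run directly.
cond = torch.zeros(1, 1, 5, 5, dtype=torch.bool)  # condition, all False
a = torch.rand(1, 12, 5, 5)                       # "self" operand
b = torch.rand(())                                # "other" operand, 0-d tensor

out = torch.ops.aten.where(cond, a, b)
assert out.shape == (1, 12, 5, 5)
assert torch.equal(out, b.expand(1, 12, 5, 5))    # all-False condition picks "other"
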


# ==============================================================================


class ElementwiseWhereSelfModule(torch.nn.Module):

def __init__(self):
