[TOSA] Add Torch to Tosa Legalization for torch.tril (#3678)
Change-Id: Ie5ba31a27394c3adcea00266a9d562862dbd8b08

Signed-off-by: Justin Ngo <[email protected]>
justin-ngo-arm authored Sep 5, 2024
1 parent b790061 commit d4b5e05
Showing 4 changed files with 277 additions and 97 deletions.
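For context, torch.tril keeps the lower-triangular part of its input relative to a chosen diagonal and zeroes everything above it. A small PyTorch example of the semantics this commit lowers to TOSA:

import torch

x = torch.arange(1, 13).reshape(3, 4)
print(torch.tril(x, diagonal=0))
# tensor([[ 1,  0,  0,  0],
#         [ 5,  6,  0,  0],
#         [ 9, 10, 11,  0]])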
110 changes: 110 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -22,6 +22,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
#include <optional>

@@ -5385,6 +5386,114 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
  return success();
}

// Templated function to create the tril mask tensor for aten.tril
// legalization
template <typename T>
Value createTrilMask(PatternRewriter &rewriter, Operation *op,
                     ArrayRef<int64_t> shape, int64_t h, int64_t w,
                     int64_t diagonal) {
  SmallVector<T> vec;

  for (int64_t i = 0; i < h; i++) {
    for (int64_t j = 0; j < w; j++) {
      // A positive diagonal value includes that many diagonals above the
      // main diagonal, while a negative diagonal value excludes that many
      // diagonals below the main diagonal.
      if (i >= j - diagonal) {
        vec.push_back(static_cast<T>(1));
      } else {
        vec.push_back(static_cast<T>(0));
      }
    }
  }

  return tosa::getConstTensor<T>(rewriter, op, vec, shape).value();
}

// Function to select the tril mask tensor based on the element type
// for aten.tril legalization
Value getTrilMask(PatternRewriter &rewriter, Operation *op,
                  ArrayRef<int64_t> shape, int64_t h, int64_t w,
                  int64_t diagonal, Type type) {
  return TypeSwitch<Type, Value>(type)
      .Case<mlir::FloatType>([&](auto) {
        return createTrilMask<float>(rewriter, op, shape, h, w, diagonal);
      })
      .Case<mlir::IntegerType>([&](auto intType) {
        switch (intType.getWidth()) {
        case 1:
          return createTrilMask<bool>(rewriter, op, shape, h, w, diagonal);
        case 32:
          return createTrilMask<int32_t>(rewriter, op, shape, h, w, diagonal);
        case 64:
          return createTrilMask<int64_t>(rewriter, op, shape, h, w, diagonal);
        }
        llvm_unreachable("Invalid integer width");
      });
}

// Legalization for aten.tril
template <>
LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
    AtenTrilOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto self = adaptor.getSelf();

  // Not a ranked tensor type
  auto selfType = dyn_cast<RankedTensorType>(self.getType());
  if (!selfType)
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor types are supported");

  // Rank below 2 not accepted
  auto selfRank = selfType.getRank();
  if (selfRank <= 1)
    return rewriter.notifyMatchFailure(
        op, "Rank 0 and 1 are not accepted as they cause underflow");

  if (!selfType.hasStaticShape())
    return rewriter.notifyMatchFailure(
        op, "Currently only static shapes are supported");

  const TypeConverter *typeConverter = this->getTypeConverter();
  RankedTensorType resultType = cast<RankedTensorType>(
      typeConverter->convertType(op->getResult(0).getType()));
  if (!resultType)
    return rewriter.notifyMatchFailure(op, "Result type cannot be empty");

  // Get the height and width of the input tensor and the diagonal arg to
  // create a const mask tensor to multiply with the input. The mask tensor
  // has the same height and width as the input tensor and consists of 1's
  // for the lower triangle part and 0's for the rest.
  // For example, with h=4, w=6, diagonal=1:
  // tensor([[1, 1, 0, 0, 0, 0],
  //         [1, 1, 1, 0, 0, 0],
  //         [1, 1, 1, 1, 0, 0],
  //         [1, 1, 1, 1, 1, 0]])
  auto selfShape = selfType.getShape();
  int64_t h = selfShape[selfRank - 2];
  int64_t w = selfShape[selfRank - 1];
  int64_t diagonal;

  if (!matchPattern(op.getDiagonal(), m_TorchConstantInt(&diagonal)))
    return rewriter.notifyMatchFailure(op, "Diagonal value is not an integer");

  // Define the mask tensor shape: leading unit dims let the 2-D mask
  // broadcast across any leading batch dims of the input.
  SmallVector<int64_t> constShape;
  for (auto i = 0; i < selfRank - 2; i++)
    constShape.push_back(1);
  constShape.push_back(h);
  constShape.push_back(w);

  Value trilMask = getTrilMask(rewriter, op, constShape, h, w, diagonal,
                               resultType.getElementType());

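  // Multiply the input by the mask: entries above the selected diagonal
  // become zero, which matches torch.tril semantics.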
  rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
                                           /*shift=*/0);

  return success();
}

} // namespace

// -----------------------------------------------------------------------------
@@ -5638,6 +5747,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
    INSERT_ATENOP_PATTERN(AtenSqrtOp);
    INSERT_ATENOP_PATTERN(AtenIscloseOp);
    INSERT_ATENOP_PATTERN(Aten__InterpolateSizeListScaleListOp);
    INSERT_ATENOP_PATTERN(AtenTrilOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
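The mask predicate i >= j - diagonal in createTrilMask above can be sanity-checked against PyTorch itself. A minimal standalone sketch (assuming PyTorch is available; tril_mask is a hypothetical helper, not part of this commit):

import torch

def tril_mask(h, w, diagonal):
    # Same predicate as createTrilMask: keep (i, j) iff i >= j - diagonal.
    return torch.tensor(
        [[1 if i >= j - diagonal else 0 for j in range(w)] for i in range(h)]
    )

x = torch.randn(4, 6)
for diagonal in (-2, 0, 1):
    # Multiplying by the mask reproduces torch.tril, mirroring the
    # tosa.mul emitted by the legalization.
    assert torch.equal(torch.tril(x, diagonal), x * tril_mask(4, 6, diagonal))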
6 changes: 4 additions & 2 deletions projects/pt1/e2e_testing/main.py
@@ -58,8 +58,10 @@
    FX_IMPORTER_CRASHING_SET,
    FX_IMPORTER_STABLEHLO_XFAIL_SET,
    FX_IMPORTER_STABLEHLO_CRASHING_SET,
    FX_IMPORTER_TOSA_CRASHING_SET,
    FX_IMPORTER_TOSA_XFAIL_SET,
    ONNX_TOSA_XFAIL_SET,
    ONNX_TOSA_CRASHING_SET,
)

# Import tests to register them in the global registry.
@@ -191,7 +193,7 @@ def main():
    elif args.config == "fx_importer_tosa":
        config = FxImporterTestConfig(LinalgOnTensorsTosaBackend(), "tosa")
        xfail_set = FX_IMPORTER_TOSA_XFAIL_SET
        crashing_set = set()
        crashing_set = FX_IMPORTER_TOSA_CRASHING_SET
    elif args.config == "torchdynamo":
        # TODO: Enable runtime verification and extend crashing set.
        config = TorchDynamoTestConfig(
@@ -206,7 +208,7 @@
    elif args.config == "onnx_tosa":
        config = OnnxBackendTestConfig(LinalgOnTensorsTosaBackend(), output_type="tosa")
        xfail_set = ONNX_TOSA_XFAIL_SET
        crashing_set = set()
        crashing_set = ONNX_TOSA_CRASHING_SET

    do_not_attempt = set(
        args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []
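The two sets wired in above play different roles in the e2e harness: tests in a crashing set are skipped entirely (a crash would take down the whole run), while tests in an xfail set still run but are expected to fail. A schematic sketch of that distinction (hypothetical helper, not the actual harness code):

def select_tests(all_tests, xfail_set, crashing_set):
    # Crashing tests are never attempted; xfail tests still run so that
    # unexpected passes can be reported.
    runnable = [t for t in all_tests if t not in crashing_set]
    expected_failures = {t for t in runnable if t in xfail_set}
    return runnable, expected_failures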