Implement linalg lowering of diag_embed torch op #2885

Merged: 12 commits, Mar 22, 2024
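For context, a small PyTorch example (illustrative only, not part of the patch) of the op being lowered: `aten.diag_embed` produces a result with one extra dimension and places the last dimension of the input along the diagonal selected by `offset` within the planes spanned by `dim1` and `dim2` (defaults `-2` and `-1`).

```python
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])      # shape (2, 3)

y = torch.diag_embed(x)               # shape (2, 3, 3): main diagonals
z = torch.diag_embed(x, offset=1)     # shape (2, 4, 4): diagonal above the main one

print(y[0])
# tensor([[1., 0., 0.],
#         [0., 2., 0.],
#         [0., 0., 3.]])
print(z.shape)                        # torch.Size([2, 4, 4])
```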
2 changes: 2 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
@@ -53,6 +53,8 @@ SmallVector<Value>
castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
SmallVectorImpl<Value> &indexValues);

SmallVector<Value> getDiagEmbedResultShape(OpBuilder &b, Location loc,
                                           Value tensor, int64_t offset,
                                           int64_t dim1, int64_t dim2);

Value getDimOp(OpBuilder &b, Location loc, Value v, int dim);

SmallVector<Value> getTensorSizesUntilDim(OpBuilder &b, Location loc,
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8078,6 +8078,32 @@ def Torch_AtenCosineEmbeddingLossOp : Torch_Op<"aten.cosine_embedding_loss", [
}];
}

def Torch_AtenDiagEmbedOp : Torch_Op<"aten.diag_embed", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::diag_embed : (Tensor, int, int, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$offset,
Torch_IntType:$dim1,
Torch_IntType:$dim2
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenDiagEmbedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenDiagEmbedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
}

def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
AllowsTypeRefinement,
HasValueSemantics,
104 changes: 104 additions & 0 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
@@ -2000,6 +2000,108 @@ class ConvertAtenDiagonalOp : public OpConversionPattern<AtenDiagonalOp> {
};
} // namespace

namespace {
class ConvertAtenDiagEmbedOp
: public OpConversionPattern<AtenDiagEmbedOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenDiagEmbedOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

Location loc = op->getLoc();

Value input = adaptor.getSelf();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
auto resultRank = inputRank + 1;

int64_t offset;
if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset)))
return rewriter.notifyMatchFailure(op, "offset is not constant");

int64_t dim1;
if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1)))
return rewriter.notifyMatchFailure(op, "dim1 is not constant");
dim1 = toPositiveDim(dim1, resultRank);
if (!isValidDim(dim1, resultRank))
return rewriter.notifyMatchFailure(
    op, "dim1 must be in range [" + std::to_string(-resultRank) + "," +
            std::to_string(resultRank - 1) + "]");

int64_t dim2;
if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2)))
return rewriter.notifyMatchFailure(op, "dim2 is not constant");
dim2 = toPositiveDim(dim2, resultRank);
if (!isValidDim(dim2, resultRank))
return rewriter.notifyMatchFailure(
    op, "dim2 must be in range [" + std::to_string(-resultRank) + "," +
            std::to_string(resultRank - 1) + "]");

if (dim1 == dim2)
return rewriter.notifyMatchFailure(op, "dim1 and dim2 cannot be equal");

// Create a zero-initialized tensor of the result shape (linalg.fill).
Type resultElemType = inputType.getElementType();
auto resultShape =
    getDiagEmbedResultShape(rewriter, loc, input, offset, dim1, dim2);
Value zeroTensor =
createZeroInitTensor(rewriter, loc, resultShape, resultElemType);

// Build a linalg.generic over the result tensor; its body copies input
// elements onto the diagonal selected by offset and leaves zeros elsewhere.
SmallVector<AffineMap> indexingMaps = {
rewriter.getMultiDimIdentityMap(resultRank),
};
SmallVector<utils::IteratorType> iteratorTypes(
resultRank, utils::IteratorType::parallel);
Value resultTensor =
rewriter
.create<linalg::GenericOp>(
loc, zeroTensor.getType(),
ValueRange{}, zeroTensor,
/*indexingMaps=*/indexingMaps,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value dim1Index = b.create<linalg::IndexOp>(loc, dim1);
Value dim2Index = b.create<linalg::IndexOp>(loc, dim2);

// To index the input, collect the indices of all result dimensions
// except dim1/dim2; the final input index is taken from dim1 or dim2
// depending on whether offset selects a diagonal above or below the
// main one.
SmallVector<Value> inputIndices;
for (int64_t i = 0; i < resultRank; i++) {
if (i != dim1 && i != dim2) {
inputIndices.push_back(b.create<linalg::IndexOp>(loc, i));
}
}

// Adjust the diagonal indices and the final input index based on the
// sign of the offset.
Value dim1IdxAdjusted;
Value dim2IdxAdjusted;
if (offset < 0) {
Value absOffset = b.create<arith::ConstantIndexOp>(loc, -offset);
dim1IdxAdjusted = dim1Index;
dim2IdxAdjusted = b.create<arith::AddIOp>(loc, dim2Index, absOffset);
inputIndices.push_back(b.create<linalg::IndexOp>(loc, dim2));
} else {
Value constOffset = b.create<arith::ConstantIndexOp>(loc, offset);
dim1IdxAdjusted = b.create<arith::AddIOp>(loc, dim1Index, constOffset);
dim2IdxAdjusted = dim2Index;
inputIndices.push_back(b.create<linalg::IndexOp>(loc, dim1));
}

Value isDiagonal = b.create<arith::CmpIOp>(
    loc, arith::CmpIPredicate::eq, dim1IdxAdjusted, dim2IdxAdjusted);

Value inputElem = b.create<tensor::ExtractOp>(loc, resultElemType,
                                              input, inputIndices);

Value result =
    b.create<arith::SelectOp>(loc, isDiagonal, inputElem, args[0]);
b.create<linalg::YieldOp>(loc, result);
})
.getResult(0);

RankedTensorType resultType = getTypeConverter()
->convertType(op->getResult(0).getType())
.cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, resultTensor);
return success();
}
};
} // namespace

void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2040,4 +2142,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality(
patterns.add<ConvertAtenViewAsRealOp>(typeConverter, context);
target.addIllegalOp<AtenDiagonalOp>();
patterns.add<ConvertAtenDiagonalOp>(typeConverter, context);
target.addIllegalOp<AtenDiagEmbedOp>();
patterns.add<ConvertAtenDiagEmbedOp>(typeConverter, context);
}
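To make the element-wise logic of the `linalg.generic` body above concrete, here is a NumPy sketch of the same computation: every result element stays zero unless its (dim1, dim2) indices land on the diagonal selected by `offset`, in which case the matching input element is copied in. The function name `diag_embed_ref` and the NumPy dependency are illustrative assumptions, not part of the patch.

```python
import numpy as np

def diag_embed_ref(x, offset=0, dim1=-2, dim2=-1):
    result_rank = x.ndim + 1
    dim1, dim2 = dim1 % result_rank, dim2 % result_rank
    assert dim1 != dim2
    diag_size = x.shape[-1] + abs(offset)

    # Result shape: dim1/dim2 get the diagonal size, the rest carry over.
    out_shape, src_dim = [], 0
    for i in range(result_rank):
        if i in (dim1, dim2):
            out_shape.append(diag_size)
        else:
            out_shape.append(x.shape[src_dim])
            src_dim += 1

    out = np.zeros(out_shape, dtype=x.dtype)
    for idx in np.ndindex(*out.shape):
        i1, i2 = idx[dim1], idx[dim2]
        if offset >= 0:
            # Diagonal above (or on) the main one: result[..., k, k+offset].
            on_diag, src_last = (i1 + offset == i2), i1
        else:
            # Diagonal below the main one: result[..., k+|offset|, k].
            on_diag, src_last = (i1 == i2 - offset), i2
        if on_diag:
            batch = tuple(idx[i] for i in range(result_rank)
                          if i not in (dim1, dim2))
            out[idx] = x[batch + (src_last,)]
    return out
```

For example, `diag_embed_ref(np.arange(6.0).reshape(2, 3), offset=1)` should agree with `torch.diag_embed` on the same input.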
27 changes: 27 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
@@ -166,6 +166,33 @@ castIndexVectorToInt64Vector(OpBuilder &b, Location loc,
return intValues;
}

SmallVector<Value> getDiagEmbedResultShape(OpBuilder &b, Location loc,
                                           Value tensor, int64_t offset,
                                           int64_t dim1, int64_t dim2) {
auto inputType = tensor.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
auto resultRank = inputRank + 1;

SmallVector<Value> resultShape;
Value constZero = b.create<arith::ConstantIndexOp>(loc, 0);
Value constNegOne = b.create<arith::ConstantIndexOp>(loc, -1);
Value constOffset = b.create<arith::ConstantIndexOp>(loc, offset);
Value isNegOffset = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                            constOffset, constZero);
Value mulOffsetNegOne = b.create<arith::MulIOp>(loc, constOffset, constNegOne);
Value absOffset = b.create<arith::SelectOp>(loc, isNegOffset, mulOffsetNegOne,
                                            constOffset);

auto lastInputDim = getDimOp(b, loc, tensor, inputRank - 1);
Value diagDim = b.create<arith::AddIOp>(loc, lastInputDim, absOffset);

int inputDimIdx = 0;
for (int64_t i = 0; i < resultRank; i++) {
  if (i == dim1 || i == dim2)
    resultShape.push_back(diagDim);
  else
    resultShape.push_back(getDimOp(b, loc, tensor, inputDimIdx++));
}

return resultShape;
}
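As a quick, hedged check of the shape logic above (assuming PyTorch is available; not part of the patch): the two diagonal dimensions get size `last input dim + |offset|`, and the remaining input dimensions are carried over in order.

```python
import torch

x = torch.zeros(3, 4)
print(torch.diag_embed(x, offset=1).shape)                   # torch.Size([3, 5, 5])
print(torch.diag_embed(x, offset=-2).shape)                  # torch.Size([3, 6, 6])
# Choosing other dim1/dim2 only changes where the diagonal planes land:
print(torch.diag_embed(x, offset=0, dim1=0, dim2=2).shape)   # torch.Size([4, 3, 4])
```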

Value getDimOp(OpBuilder &b, Location loc, Value v, int dim) {
return b.createOrFold<tensor::DimOp>(loc, v, dim);
}
89 changes: 89 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8133,6 +8133,91 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.new_empty_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diag_embed\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__._diag_embed_shape_helper(%arg0, %arg1, %arg2, %arg3) : (!torch.list<int>, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @__torch__._diag_embed_shape_helper(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %int-1 = torch.constant.int -1\n"
" %true = torch.constant.bool true\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int1 = torch.constant.int 1\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.add.int %0, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %2 = torch.aten.ne.int %arg2, %arg3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %2 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %3 = torch.aten.lt.int %arg2, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
" %5 = torch.aten.ge.int %arg2, %4 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.aten.lt.int %arg3, %1 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %6 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %7 = torch.aten.neg.int %1 : !torch.int -> !torch.int\n"
" %8 = torch.aten.ge.int %arg3, %7 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.lt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %10 = torch.prim.If %9 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %1, %arg2 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg2 : !torch.int\n"
" }\n"
" %11 = torch.aten.lt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %12 = torch.prim.If %11 -> (!torch.int) {\n"
" %15 = torch.aten.add.int %1, %arg3 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %15 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg3 : !torch.int\n"
" }\n"
" %13 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" %14 = torch.prim.Loop %1, %true, init(%int0) {\n"
" ^bb0(%arg4: !torch.int, %arg5: !torch.int):\n"
" %15 = torch.prim.ListConstruct %10, %12 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %16 = torch.aten.__contains__.int_list %15, %arg4 : !torch.list<int>, !torch.int -> !torch.bool\n"
" %17 = torch.prim.If %16 -> (!torch.int) {\n"
" %18 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.operator \"prim.abs.int\"(%arg1) : (!torch.int) -> !torch.int\n"
" %20 = torch.aten.add.int %18, %19 : !torch.int, !torch.int -> !torch.int\n"
" %21 = torch.aten.append.t %13, %20 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %arg5 : !torch.int\n"
" } else {\n"
" %18 = torch.aten.__getitem__.t %arg0, %arg5 : !torch.list<int>, !torch.int -> !torch.int\n"
" %19 = torch.aten.append.t %13, %18 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %20 = torch.aten.add.int %arg5, %int1 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %20 : !torch.int\n"
" }\n"
" torch.prim.Loop.condition %true, iter(%17 : !torch.int)\n"
" } : (!torch.int, !torch.bool, !torch.int) -> !torch.int\n"
" return %13 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten._to_copy\"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.bool, %arg6: !torch.optional<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12204,6 +12289,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" return %2 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.diag_embed\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.rand_like\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
@@ -53,6 +53,33 @@ def _embedding_bag_helper(weight: List[int], indices: List[int],

return output_bag_shape, offset2bag_shape, bag_size_shape, max_indices_shape

# TODO: upstream this
def _diag_embed_shape_helper(self: List[int], offset: int, dim1: int, dim2: int):
self_rank = len(self)
result_rank = self_rank + 1

assert dim1 != dim2
assert dim1 < result_rank
assert dim1 >= -(result_rank)
assert dim2 < result_rank
assert dim2 >= -(result_rank)

if dim1 < 0:
dim1 = result_rank + dim1
if dim2 < 0:
dim2 = result_rank + dim2

result_shape: List[int] = []
input_dim_idx = 0
for i in range(result_rank):
if i in (dim1, dim2):
result_shape.append(self[-1] + abs(offset))
else:
result_shape.append(self[input_dim_idx])
input_dim_idx += 1

return result_shape
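
A small, hypothetical sanity check of `_diag_embed_shape_helper` against PyTorch (assumes the helper above is importable and torch is installed; not part of the patch):

```python
import torch

assert _diag_embed_shape_helper([3, 4], 0, -2, -1) == [3, 4, 4]
assert _diag_embed_shape_helper([3, 4], 1, -2, -1) == [3, 5, 5]
assert _diag_embed_shape_helper([3, 4], -2, 0, 2) == [6, 3, 6]
assert list(torch.diag_embed(torch.zeros(3, 4), offset=-2, dim1=0, dim2=2).shape) == [6, 3, 6]
```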

def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -1006,6 +1033,9 @@ def aten〇new_empty〡shape(self: List[int], size: List[int], dtype: Optional[i
def aten〇new_empty_strided〡shape(self: List[int], size: List[int], stride: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
return size

def aten〇diag_embed〡shape(self: List[int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> List[int]:
return _diag_embed_shape_helper(self, offset, dim1, dim2)

def aten〇_to_copy〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, non_blocking: bool = False, memory_format: Optional[int] = None) -> List[int]:
return upstream_shape_functions.unary(self)

@@ -4032,6 +4062,11 @@ def aten〇new_empty_strided〡dtype(self_rank_dtype: Tuple[int, int], size: Lis
self_rank, self_dtype = self_rank_dtype
return self_dtype if dtype is None else dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1))
def aten〇diag_embed〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = -2, dim2: int = -1) -> int:
self_rank, self_dtype = self_rank_dtype
return self_dtype

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.float16) +
_check_tensors_with_the_same_dtype(num_of_tensors=1, dtype=torch.int32) +
@@ -549,6 +549,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::log_sigmoid_backward : (Tensor, Tensor, Tensor) -> (Tensor)")
emit("aten::sigmoid_backward : (Tensor, Tensor) -> (Tensor)")
emit("aten::cosine_embedding_loss : (Tensor, Tensor, Tensor, float, int) -> (Tensor)")
emit("aten::diag_embed : (Tensor, int, int, int) -> (Tensor)")

# Misc tensor ops.
emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)")