Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nirvedh <[email protected]>
  • Loading branch information
nirvedhmeshram committed Nov 11, 2024
1 parent 6a2dc84 commit 9523904
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 139 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,35 @@ getExpandedShape(SmallVector<ReassociationIndices> reIndices,
auto destType = dyn_cast<ShapedType>(dest.getType());
if (!destType)
return failure();
// TODO (nirvedhmeshram): Support rank reducing parallel_insert_slice.
if (reIndices.size() != destType.getShape().size())
return failure();
// Iterator to insert outer sizes.
auto outerShapeIter = expandedShape.begin();
for (auto [reassociations, destSize] :
llvm::zip_equal(reIndices, destType.getShape())) {
// Dynamic destination dims that are not getting expanded are allowed.
if (ShapedType::isDynamic(destSize) && reassociations.size() == 1) {
expandedShape.insert(outerShapeIter++, destSize);
totalInnerSizes.push_back(1);
continue;
}
// Dynamic destination dims that are expanded are currently unsupported but
// this support can be added if needed.
if (ShapedType::isDynamic(destSize)) {
return failure();
}
int64_t totalInnerSize = 1;
for (int64_t reasociation : llvm::drop_begin(reassociations)) {
int64_t expandedInnerSize = sliceStaticSizes[reasociation];
if (ShapedType::isDynamic(expandedInnerSize)) {
// It is not safe to apply this pattern if the inner dimensions are dynamic.
if (ShapedType::isDynamic(expandedInnerSize))
return failure();
}
expandedShape.push_back(expandedInnerSize);
totalInnerSize *= expandedInnerSize;
}
if (destSize % totalInnerSize != 0) {
if (destSize % totalInnerSize != 0)
return failure();
}
totalInnerSizes.push_back(totalInnerSize);
// Insert the outer size in front of any inner sizes.
expandedShape.insert(outerShapeIter, destSize / totalInnerSize);
Expand All @@ -58,29 +71,33 @@ getExpandedShape(SmallVector<ReassociationIndices> reIndices,
/// Check if the users of the expanded scf.forall destination can be updated to
/// account for the expand. If not we bail out. There are two supported users
/// which are extract_slice -> expand_shape with the same exact reassociation
/// map as the collapse op to be hoisted out or a parallel_insert_slice.
/// map as the collapse op to be hoisted out or the root parallel_insert_slice.
static LogicalResult
verifyAndCollectExpandableUsers(Value insertDest,
SmallVector<ReassociationIndices> reIndices,
tensor::ParallelInsertSliceOp parallelInsertOp,
SmallVector<Operation *> &expandableUsers) {
for (Operation *user : insertDest.getUsers()) {
if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user)) {
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
if (!expandShapeOp)
return failure();
SmallVector<ReassociationIndices> expandReIndices =
expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices) {
return failure();
}
expandableUsers.push_back(user);
} else if (auto parallelInsertOp =
dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
if (user == parallelInsertOp) {
expandableUsers.push_back(user);
} else {
return failure();
continue;
}
auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
if (!extractSliceOp)
return failure();
if (extractSliceOp.getMixedSizes() != parallelInsertOp.getMixedSizes())
return failure();
if (extractSliceOp.getMixedOffsets() != parallelInsertOp.getMixedOffsets())
return failure();
auto expandShapeOp =
dyn_cast<tensor::ExpandShapeOp>(*extractSliceOp->getUsers().begin());
if (!expandShapeOp)
return failure();
SmallVector<ReassociationIndices> expandReIndices =
expandShapeOp.getReassociationIndices();
if (reIndices != expandReIndices)
return failure();
expandableUsers.push_back(user);
}
return success();
}
Expand All @@ -93,18 +110,6 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
SmallVector<int64_t> totalInnerSizes,
SmallVector<ReassociationIndices> reIndices,
scf::ForallOp forallOp) {
// The user expands and producer collapses need to be
// unflattened e.g %collapsed = tensor.collapse_shape %transposed [[0, 1], [2,
// 3]] : tensor<8x16x8x16xf32> into tensor<128x128xf32> can be unflattened to
// %collapsed = tensor.collapse_shape %transposed [[0], [1], [2], [3]] :
// tensor<8x16x8x16xf32> into tensor<8x16x8x16xf32> and then is consumed by
// the expanded parallel_insert_slice_op.
SmallVector<ReassociationIndices> unFlattenReassociations;
for (ReassociationIndices inds : reIndices) {
for (int64_t i : inds) {
unFlattenReassociations.push_back({i});
}
}
// Compute the offsets, sizes, and strides in the expanded dimensions.
auto computeExpandedAccess = [&](ArrayRef<OpFoldResult> mixedOffsets,
ShapedType resultType)
Expand Down Expand Up @@ -132,11 +137,8 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,

expandedOffsetsIter = expandedOffsets.end();
}
ArrayRef<int64_t> expandedShape = resultType.getShape();
SmallVector<OpFoldResult> expandedSizes;
for (int64_t size : expandedShape) {
expandedSizes.push_back(getAsIndexOpFoldResult(ctx, size));
}
SmallVector<OpFoldResult> expandedSizes =
getAsIndexOpFoldResult(ctx, resultType.getShape());
SmallVector<OpFoldResult> expandedStrides(resultType.getRank(),
rewriter.getIndexAttr(1));
return {expandedOffsets, expandedSizes, expandedStrides};
Expand All @@ -149,28 +151,20 @@ static void expandVerifiedUsers(PatternRewriter &rewriter, Location loc,
RankedTensorType resultType = expandShapeOp.getResultType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(extractSliceOp.getMixedOffsets(), resultType);
rewriter.setInsertionPoint(extractSliceOp);
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSliceOp, resultType, extractSliceOp.getSource(),
expandedOffsets, expandedSizes, expandedStrides);
rewriter.setInsertionPoint(expandShapeOp);
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
expandShapeOp, resultType, expandShapeOp.getSrc(),
unFlattenReassociations);
} else if (auto parallelInsertOp =
dyn_cast<tensor::ParallelInsertSliceOp>(user)) {
auto collapseShapeOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
RankedTensorType resultType = collapseShapeOp.getSrcType();
auto [expandedOffsets, expandedSizes, expandedStrides] =
computeExpandedAccess(parallelInsertOp.getMixedOffsets(), resultType);

rewriter.setInsertionPoint(collapseShapeOp);
auto newCollapseOp = rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>(
collapseShapeOp, collapseShapeOp.getSrcType(),
collapseShapeOp.getSrc(), unFlattenReassociations);
rewriter.setInsertionPoint(parallelInsertOp);
rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
parallelInsertOp, newCollapseOp.getResult(),
parallelInsertOp, collapseShapeOp.getSrc(),
parallelInsertOp.getDest(), expandedOffsets, expandedSizes,
expandedStrides);
}
Expand All @@ -190,9 +184,8 @@ struct ExpandDestinationForallOp final
auto collapseOp =
parallelInsertOp.getSource().getDefiningOp<tensor::CollapseShapeOp>();
// No collapse op to hoist out.
if (!collapseOp) {
if (!collapseOp)
return failure();
}

// Ignore trivially foldable collapse ops.
if (collapseOp.getSrcType().getRank() ==
Expand All @@ -205,10 +198,12 @@ struct ExpandDestinationForallOp final

// Get the enclosing scf.forall op.
OpResult tiedResult = parallelInsertOp.getTiedOpResult();
int64_t tiedResultIdx = tiedResult.getResultNumber();

auto forallOp = dyn_cast<scf::ForallOp>(tiedResult.getOwner());
if (!forallOp) {
if (!forallOp)
return failure();
}

// This allows us to assume that the extract/inserts in the loop are
// disjoint and makes the application of this pattern safe.
if (!forallOpHasMappingType<IREE::Codegen::WorkgroupMappingAttr>(
Expand All @@ -218,9 +213,6 @@ struct ExpandDestinationForallOp final
// Collect the forall outputs. Only the output tied to the
// parallel_insert_slice (at tiedResultIdx) is expanded.
SmallVector<Value> forallOutputs(forallOp.getOutputs());
if (forallOutputs.size() != 1) {
return failure();
}

SmallVector<ReassociationIndices> reIndices =
collapseOp.getReassociationIndices();
Expand All @@ -239,7 +231,7 @@ struct ExpandDestinationForallOp final
// such users.
SmallVector<Operation *> expandableUsers;
if (failed(verifyAndCollectExpandableUsers(
insertDest, collapseOp.getReassociationIndices(),
insertDest, collapseOp.getReassociationIndices(), parallelInsertOp,
expandableUsers))) {
return failure();
}
Expand All @@ -250,40 +242,43 @@ struct ExpandDestinationForallOp final
reIndices, forallOp);
rewriter.setInsertionPoint(forallOp);

Operation *outOp = forallOutputs[0].getDefiningOp();
if (!outOp) {
return failure();
}

// Create the expand -> new scf.forall -> collapse chain.
Type expandedDestType = RankedTensorType::get(
expandedDestShape,
cast<ShapedType>(outOp->getResult(0).getType()).getElementType());
auto expandedDestType =
cast<RankedTensorType>(forallOutputs[tiedResultIdx].getType())
.clone(expandedDestShape);
auto expandedDest = rewriter.create<tensor::ExpandShapeOp>(
loc, expandedDestType, outOp->getResult(0), reIndices);
loc, expandedDestType, forallOutputs[tiedResultIdx], reIndices);

forallOutputs[tiedResultIdx] = expandedDest;

scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
forallOp.getMixedStep(), ValueRange{expandedDest},
forallOp.getMappingAttr());
forallOp.getMixedStep(), forallOutputs, forallOp.getMappingAttr());

auto collapsedResultOp = rewriter.create<tensor::CollapseShapeOp>(
loc, cast<ShapedType>(forallOp->getResult(0).getType()),
newForallOp->getResult(0), reIndices);
loc, cast<ShapedType>(forallOp->getResult(tiedResultIdx).getType()),
newForallOp->getResult(tiedResultIdx), reIndices);

// Merge the old scf.forall block which has the expanded users into the new
// scf.forall which has the expanded destination.
SmallVector<Value> argReplacements(newForallOp.getInductionVars());
for (auto forallIterArg : newForallOp.getRegionIterArgs()) {
argReplacements.push_back(forallIterArg);
}
argReplacements.append(newForallOp.getRegionIterArgs().begin(),
newForallOp.getRegionIterArgs().end());
scf::InParallelOp parallelTerminator = newForallOp.getTerminator();
parallelTerminator->erase();
rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
argReplacements);

// Replace the uses of the old scf.forall with the new scf.forall.
forallOp->getResult(0).replaceAllUsesWith(collapsedResultOp->getResult(0));
for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
if (idx == tiedResultIdx) {
forallOp->getResult(idx).replaceAllUsesWith(
collapsedResultOp->getResult(0));
} else {
forallOp->getResult(idx).replaceAllUsesWith(
newForallOp->getResult(idx));
}
}
return success();
}
};
Expand Down
Loading

0 comments on commit 9523904

Please sign in to comment.