-
Notifications
You must be signed in to change notification settings - Fork 603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tile and distribute linalg.generic in DispatchLinalgOnTensor #5159
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -312,6 +312,50 @@ static LogicalResult getConfigForCooperativeMatmul( | |
return success(); | ||
} | ||
|
||
/// Launch config for element-wise linalg.generic. | ||
template <> | ||
LogicalResult getOpLaunchConfig(linalg::GenericOp op, | ||
const spirv::TargetEnv &targetEnv, | ||
const SPIRVCodegenOptions &options, | ||
TileSizesListType &tileSizes, | ||
LaunchConfigInfo &config) { | ||
int64_t subgroupSize = | ||
targetEnv.getResourceLimits().subgroup_size().getValue().getSExtValue(); | ||
config.workgroupSize[0] = subgroupSize; | ||
config.workgroupSize[1] = 1; | ||
config.workgroupSize[2] = 1; | ||
ShapedType outputShape = op.getOutputShapedType(0); | ||
|
||
SmallVector<int64_t, 4> sizes; | ||
// When Vectororization is not enabled we skil the second level of tiling and | ||
// fall back to convertToGPU which will map one element to one thread. To | ||
// avoid a mismatch in the number of workgroup dispatched, we pick a tile size | ||
// to have one element per thread. | ||
if (options.enableVectorization) { | ||
sizes.append({4 * subgroupSize, 2 * subgroupSize}); | ||
} | ||
sizes.push_back(subgroupSize); | ||
// Use the first tile size that can divide the shape. If the shape is not | ||
// aligned on any of the tile sizes pick the smallest tile of one element per | ||
// thread. | ||
int64_t lowerTs = config.workgroupSize[0]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add comments to whats happening here. Hard to parse this a bit. |
||
for (int64_t size : sizes) { | ||
if (outputShape.getShape().back() % size != 0) continue; | ||
lowerTs = size; | ||
break; | ||
} | ||
SmallVector<int64_t, 4> ts; | ||
size_t numLoops = getNumOuterParallelLoops(op); | ||
ts.resize(numLoops, 1); | ||
ts.back() = lowerTs; | ||
tileSizes.emplace_back(ts); | ||
tileSizes.emplace_back(); | ||
ts.back() = lowerTs / subgroupSize; | ||
tileSizes.emplace_back(ts); | ||
config.vectorize = options.enableVectorization; | ||
return success(); | ||
} | ||
|
||
/// Launch configuration for different known GPU configuration. | ||
static LogicalResult getTargetSpecificConfig( | ||
linalg::MatmulOp op, const spirv::TargetEnv &targetEnv, | ||
|
@@ -708,6 +752,27 @@ Optional<LaunchConfig> initGPULaunchConfig( | |
#undef DISPATCH | ||
} | ||
|
||
if (!rootOperation) { | ||
for (linalg::LinalgOp linalgOp : linalgOps) { | ||
if (auto op = dyn_cast<linalg::GenericOp>(linalgOp.getOperation())) { | ||
if (getNumOuterParallelLoops(linalgOp) == 0 || | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When we drop the old path, this will have to go through the new path. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We still need this even in the new path, right? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, bad comment. I was saying that this is basically saying don't do anything with 0-rank operations. But those need to be handled as well (i.e. lowered to loops, albeit a no-op) at some point. I am guessing this isn't tested on the new path. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point, I can add a test for that. |
||
llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It might just be getting mixed with the old paths. This should not be needed actually. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I still need this check, otherwise some linalg ops with weird affine maps crash during tiling. I can look at why this happens afterwards if you think this is not expected. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yeah, this is not expected. I don't see why it would crash (at least it would be good to know the affine map). Is this because of There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Mahesh, yes the problem is with |
||
return !map.isProjectedPermutation(); | ||
})) { | ||
continue; | ||
} | ||
TileSizesListType tileSizesInfo; | ||
if (failed(getOpLaunchConfig(op, targetEnv, options, tileSizesInfo, | ||
config))) { | ||
continue; | ||
} | ||
launchConfig.setTileSizes(op, tileSizesInfo); | ||
launchConfig.setRootOperation(op); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
launchConfig.setWorkgroupSize(config.workgroupSize); | ||
launchConfig.setNumSubgroups(config.numSubgroups); | ||
launchConfig.setVectorize(config.vectorize); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -612,13 +612,13 @@ struct TileAndDistributeOnTensorsPattern | |
SmallVector<Value, 4> count = llvm::to_vector<4>( | ||
llvm::map_range(linalgOp.createLoopRanges(rewriter, loc), | ||
[](Range r) { return r.size; })); | ||
// NOTE: Special treatment for convolution, which have more than 3 parallel | ||
// dimensions. We want to ignore the batch dimension and tile along the | ||
// next three. | ||
// TODO(#5048): figure out a better way to avoid this special case. | ||
if (isa<linalg::ConvInputNHWCFilterHWCFOp, | ||
linalg::DepthwiseConvInputNHWCFilterHWCOp>(op)) { | ||
count.erase(count.begin()); | ||
size_t numParrallelLoops = getNumOuterParallelLoops(op); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh Nice! This was exactly what I was trying to do. Should work well with Hanhan's PR #5136 |
||
// Flow currently allows only 3 level of tiling. If there are more parallel | ||
// dimension drop the higher dimensions. | ||
if (numParrallelLoops > kNumMaxParallelDims) { | ||
count.erase( | ||
count.begin(), | ||
std::next(count.begin(), numParrallelLoops - kNumMaxParallelDims)); | ||
} | ||
count.resize(getNumTilableLoops(op)); | ||
auto workload = convertToWorkload(rewriter, loc, count); | ||
|
@@ -849,6 +849,23 @@ static void decideFusableLinalgOps(FuncOp funcOp) { | |
builder.getI64ArrayAttr(fusionGroups)); | ||
} | ||
} | ||
|
||
// As a second step mark all the element-wise linalg ops not fused as roots | ||
// so that they get tiled and distributed. | ||
for (linalg::LinalgOp linalgOp : linalgOps) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. So the only reason we have the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Correct, there are still a few cases that don't tile; I can look at those at some point. Yes, for all those cases the |
||
Operation *op = linalgOp.getOperation(); | ||
if (!isa<linalg::GenericOp>(op) || | ||
getNumOuterParallelLoops(linalgOp) == 0 || | ||
llvm::any_of(linalgOp.getIndexingMaps(), [](AffineMap &map) { | ||
return !map.isProjectedPermutation(); | ||
})) { | ||
continue; | ||
} | ||
|
||
if (op->hasAttr(kRootOpAttr) || op->hasAttr(kFusionGroupsAttr)) continue; | ||
unsigned currGroupNum = numRootOps++; | ||
op->setAttr(kRootOpAttr, builder.getI64IntegerAttr(currGroupNum)); | ||
} | ||
} | ||
} | ||
|
||
|
@@ -906,11 +923,8 @@ void DispatchLinalgOnTensorsPass::runOnOperation() { | |
// parallel dimensions. We want to ignore the batch dimension and tile | ||
// along the next three. That means setting the first position to zero. | ||
// TODO(#5048): figure out a better way to avoid this special case. | ||
bool isConvOp = isa<linalg::ConvInputNHWCFilterHWCFOp, | ||
linalg::DepthwiseConvInputNHWCFilterHWCOp>(op); | ||
|
||
for (size_t dim = 0; dim < numTiledLoops; ++dim) { | ||
useTileSizes[(isConvOp ? numParallelDims : numTiledLoops) - dim - 1] = | ||
useTileSizes[numParallelDims - dim - 1] = | ||
buildFlowWorkgroupInfoOp<Flow::DispatchWorkgroupSizeOp>(builder, dim); | ||
} | ||
return useTileSizes; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a TODO here. FOr the linalg on tensors path this could actually be anything other than 1. It can be even on the old path with a few changes, but why bother when its going away. For now, leave a TODO and come back to it when we deprecate the old path.