diff --git a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp index fa3384c57ea5..f61d44400823 100644 --- a/iree/compiler/Dialect/Flow/IR/FlowOps.cpp +++ b/iree/compiler/Dialect/Flow/IR/FlowOps.cpp @@ -827,12 +827,11 @@ static bool canDispatchRegionContainOp(Operation *op) { constantValueAttr.dyn_cast()) { // TODO(GH-4897): Non-splat constants seems to have an issue on the LLLVM // side. Uncomment after that is fixed. - // auto shapedType = constantOp.getType().cast(); - // uint64_t estimatedByteLength = - // (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) - // / 8; - return denseAttr - .isSplat(); // || estimatedByteLength <= 256; // or whatever + auto shapedType = constantOp.getType().cast(); + uint64_t estimatedByteLength = + (shapedType.getNumElements() * shapedType.getElementTypeBitWidth()) / + 8; + return denseAttr.isSplat() || estimatedByteLength <= 256; // or whatever } else if (constantType.isIntOrIndexOrFloat()) { return true; } diff --git a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp index 483c0f3b002a..273d48abadd6 100644 --- a/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp +++ b/iree/compiler/Dialect/HAL/Target/LLVM/LLVMAOTTarget.cpp @@ -103,6 +103,35 @@ class LLVMAOTTargetBackend final : public TargetBackend { llvm::to_vector<8>(moduleOp.getOps()); if (sourceExecutableOps.size() <= 1) return success(); + // Private symbols (i.e. llvm dialect private symbols) get deduped + // incorrectly by the link executables pass even though they should be + // treated as different symbols. For now just change the names of the + // private symbols to avoid conflicts. + unsigned moduleNumber = 0; + for (auto sourceExecutableOp : enumerate(sourceExecutableOps)) { + auto targetOps = llvm::to_vector<4>( + sourceExecutableOp.value().getOps()); + for (auto targetOp : targetOps) { + if (!matchPattern(targetOp.target_backend_filter(), filter_pattern())) { + continue; + } + + auto sourceModuleOp = targetOp.getInnerModule(); + for (auto globalOp : sourceModuleOp.getOps()) { + if (globalOp.linkage() != LLVM::Linkage::Private) { + continue; + } + auto disambiguateName = + llvm::formatv("{0}_{1}", globalOp.sym_name(), moduleNumber).str(); + SymbolTableCollection symbolTable; + SymbolUserMap symbolUsers(symbolTable, sourceModuleOp); + symbolUsers.replaceAllUsesWith(globalOp, disambiguateName); + SymbolTable::setSymbolName(globalOp, disambiguateName); + } + moduleNumber++; + } + } + // Guess a module name, if needed, to make the output files readable. auto moduleName = guessModuleName(moduleOp); diff --git a/iree/test/e2e/linalg_tensor_ops/add.mlir b/iree/test/e2e/linalg_tensor_ops/add.mlir index a9e31dd68ffe..6d5265676d8c 100644 --- a/iree/test/e2e/linalg_tensor_ops/add.mlir +++ b/iree/test/e2e/linalg_tensor_ops/add.mlir @@ -87,4 +87,4 @@ func @cst_plus_tensor() attributes { iree.module.export } { check.expect_eq_const(%1, dense< [[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]]> : tensor<2x2x3xi32>) : tensor<2x2x3xi32> return -} \ No newline at end of file +}