Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: Make materializations optional #104668

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Aug 17, 2024

This commit makes source/target/argument materializations (via the TypeConverter API) optional.

By default (ConversionConfig::buildMaterializations = true), the dialect conversion infrastructure tries to legalize all unresolved materializations right after the main transformation process has succeeded. If at least one unresolved materialization fails to resolve, the dialect conversion fails. (With an error message such as failed to legalize unresolved materialization ....) Automatic materializations through the TypeConverter API can now be deactivated. In that case, every unresolved materialization will show up as a builtin.unrealized_conversion_cast op in the output IR.

There used to be a complex and error-prone analysis in the dialect conversion that predicted the future uses of unresolved materializations. Based on that logic, some casts (that were deemed to unnecessary) were folded. This analysis was needed because folding happened at a point of time when some IR changes (e.g., op replacements) had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens after all other IR changes have been materialized and the uses can directly be queried from the IR. This simplifies the analysis significantly. And certain helper data structures such as inverseMapping are no longer needed for the analysis. The folding itself is done by reconcileUnrealizedCasts (which also exists as a standalone pass).

After casts have been folded, the remaining casts are materialized through the TypeConverter, as usual. This last step can be deactivated in the ConversionConfig.

ConversionConfig::buildMaterializations = false can be used to debug error messages such as failed to legalize unresolved materialization .... (It is also useful in case automatic materializations are not needed.) The materializations that failed to resolve can then be seen as builtin.unrealized_conversion_cast ops in the resulting IR. (This is better than running with -debug, because -debug shows IR where some IR changes have not been materialized yet.)

@matthias-springer matthias-springer force-pushed the users/matthias-springer/optional_materializations branch from 23c0a47 to d20fdc5 Compare August 17, 2024 10:46
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/move_reconcile_unre August 17, 2024 10:46
Copy link

github-actions bot commented Aug 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/optional_materializations branch 3 times, most recently from 26c00f2 to 2658d44 Compare August 17, 2024 11:24
@matthias-springer matthias-springer changed the title [mlir][Transforms][WIP] Dialect conversion: Make materializations optional [mlir][Transforms] Dialect conversion: Make materializations optional Aug 17, 2024
@matthias-springer matthias-springer force-pushed the users/matthias-springer/optional_materializations branch from 2658d44 to f72f642 Compare August 17, 2024 17:22
@matthias-springer matthias-springer marked this pull request as ready for review August 17, 2024 17:45
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:gpu mlir mlir:bufferization Bufferization infrastructure labels Aug 17, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Aug 17, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

Changes

This commit makes source/target/argument materializations (via the TypeConverter API) optional.

By default (ConversionConfig::buildMaterializations = true), the dialect conversion infrastructure tries to legalize all unresolved materializations right after the main transformation process has succeeded. If at least one unresolved materialization fails to resolve, the dialect conversion fails. (With an error message such as failed to legalize unresolved materialization ....)

Automatic materializations through the TypeConverter API can now be deactivated. In that case, every unresolved materialization will show up as a builtin.unrealized_conversion_cast op in the output IR.

There used to be a complex and error-prone analysis in the dialect conversion that predicted the future uses of unresolved materializations. Based on that logic, some casts (that were deemed to unnecessary) were folded. This analysis was needed because folding happened at a point of time when some IR changes (e.g., op replacements) had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens after all other IR changes have been materialized and the uses can directly be queried from the IR. This simplifies the analysis significantly. And certain helper data structures such as inverseMapping are no longer needed for the analysis. The folding itself is done by reconcileUnrealizedCasts (which also exists as a standalone pass).

After casts have been folded, the remaining casts are materialized through the TypeConverter, as usual. This last step can be deactivated in the ConversionConfig.

ConversionConfig::buildMaterializations = false can be used to debug error messages such as failed to legalize unresolved materialization .... (It is also useful in case automatic materializations are not needed.) The materializations that failed to resolve can then be seen as builtin.unrealized_conversion_cast ops in the resulting IR. (This is better than running with -debug, because -debug shows IR where some IR changes have not been materialized yet.)


Patch is 30.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104668.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+89-294)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+2-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+1)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+5-1)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 60113bdef16a23..5f680e8eca7559 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,6 +1124,17 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
+
+  /// If set to "true", the dialect conversion attempts to build source/target/
+  /// argument materializations through the type converter API in lieu of
+  /// builtin.unrealized_conversion_cast ops. The conversion process fails if
+  /// at least one materialization could not be built.
+  ///
+  /// If set to "false", the dialect conversion does not does not build any
+  /// custom materializations and instead inserts
+  /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
+  /// is valid.
+  bool buildMaterializations = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6238a257b2ffda..f710116186e741 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,14 +702,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
   }
 
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
-
   void rollback() override;
 
-  void cleanup(RewriterBase &rewriter) override;
-
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
@@ -766,7 +760,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), config(config) {}
+      : context(ctx), eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -834,6 +828,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
+
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -912,6 +907,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
+  /// A rewriter that keeps track of ops/block that were already erased and
+  /// skips duplicate op/block erasures. This rewriter is used during the
+  /// "cleanup" phase.
+  SingleEraseRewriter eraseRewriter;
+
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1058,10 +1058,6 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
-  rewriter.eraseOp(op);
-}
-
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1069,7 +1065,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit(rewriter);
 
   // Clean up all rewrites.
-  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2354,12 +2349,6 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// Legalize any unresolved type materializations.
-  LogicalResult legalizeUnresolvedMaterializations(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2406,6 +2395,57 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+static LogicalResult
+legalizeUnresolvedMaterialization(RewriterBase &rewriter,
+                                  UnresolvedMaterializationRewrite *rewrite) {
+  UnrealizedConversionCastOp op =
+      cast<UnrealizedConversionCastOp>(rewrite->getOperation());
+  assert(!op.use_empty() &&
+         "expected that dead materializations have already been DCE'd");
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = op.getResultTypes()[0];
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = rewrite->getConverter()) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (rewrite->getMaterializationKind()) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      rewriter.replaceOp(op, newMaterialization);
+      return success();
+    }
+  }
+
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
+                            << " that remained live after conversion";
+  diag.attachNote(op->getUsers().begin()->getLoc())
+      << "see existing live user here: " << *op->getUsers().begin();
+  return failure();
+}
+
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2447,6 +2487,37 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
+
+  // Gather all unresolved materializations.
+  SmallVector<UnrealizedConversionCastOp> allCastOps;
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    if (rewriterImpl.eraseRewriter.erased.contains(mat->getOperation()))
+      continue;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(mat->getOperation()));
+    rewriteMap[mat->getOperation()] = mat;
+  }
+
+  // Reconcile all UnrealizedConversionCastOps that were inserted by the
+  // dialect conversion frameworks. (Not the one that were inserted by
+  // patterns.)
+  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
+  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+
+  // Try to legalize all unresolved materializations.
+  if (config.buildMaterializations) {
+    IRRewriter rewriter(rewriterImpl.context, config.listener);
+    for (UnrealizedConversionCastOp castOp : remainingCastOps) {
+      auto it = rewriteMap.find(castOp.getOperation());
+      assert(it != rewriteMap.end() && "inconsistent state");
+      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+        return failure();
+    }
+  }
+
   return success();
 }
 
@@ -2460,9 +2531,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)))
-    return failure();
   return success();
 }
 
@@ -2577,279 +2645,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
-                       ResultRange matResults, ValueRange values,
-                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  matResults.replaceAllUsesWith(values);
-
-  // For each of the materialization results, update the inverse mappings to
-  // point to the replacement values.
-  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
-    auto inverseMapIt = inverseMapping.find(matResult);
-    if (inverseMapIt == inverseMapping.end())
-      continue;
-
-    // Update the reverse mapping, or remove the mapping if we couldn't update
-    // it. Not being able to update signals that the mapping would have become
-    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
-    // propagated through temporary materializations. We simply drop the
-    // mapping, and let the post-conversion replacement logic handle updating
-    // uses.
-    for (Value inverseMapVal : inverseMapIt->second)
-      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
-        rewriterImpl.mapping.erase(inverseMapVal);
-  }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping,
-    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
-  // Helper function to check if the given value or a not yet materialized
-  // replacement of the given value is live.
-  // Note: `inverseMapping` maps from replaced values to original values.
-  auto isLive = [&](Value value) {
-    auto findFn = [&](Operation *user) {
-      auto matIt = materializationOps.find(user);
-      if (matIt != materializationOps.end())
-        return !necessaryMaterializations.count(matIt->second);
-      return rewriterImpl.isOpIgnored(user);
-    };
-    // A worklist is needed because a value may have gone through a chain of
-    // replacements and each of the replaced values may have live users.
-    SmallVector<Value> worklist;
-    worklist.push_back(value);
-    while (!worklist.empty()) {
-      Value next = worklist.pop_back_val();
-      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
-        return true;
-      // This value may be replacing another value that has a live user.
-      llvm::append_range(worklist, inverseMapping.lookup(next));
-    }
-    return false;
-  };
-
-  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
-      [&](Value invalidRoot, Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type && remappedValue != invalidRoot)
-          return remappedValue;
-
-        // Check to see if the input is a materialization operation that
-        // provides an inverse conversion. We just check blindly for
-        // UnrealizedConversionCastOp here, but it has no effect on correctness.
-        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-        if (inputCastOp && inputCastOp->getNumOperands() == 1)
-          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
-                                     type);
-
-        return Value();
-      };
-
-  SetVector<UnresolvedMaterializationRewrite *> worklist;
-  for (auto &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    materializationOps.try_emplace(mat->getOperation(), mat);
-    worklist.insert(mat);
-  }
-  while (!worklist.empty()) {
-    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
-    UnrealizedConversionCastOp op = mat->getOperation();
-
-    // We currently only handle target materializations here.
-    assert(op->getNumResults() == 1 && "unexpected materialization type");
-    OpResult opResult = op->getOpResult(0);
-    Type outputType = opResult.getType();
-    Operation::operand_range inputOperands = op.getOperands();
-
-    // Try to forward propagate operands for user conversion casts that result
-    // in the input types of the current cast.
-    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
-      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-      if (!castOp)
-        continue;
-      if (castOp->getResultTypes() == inputOperands.getTypes()) {
-        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
-                               inverseMapping);
-        necessaryMaterializations.remove(materializationOps.lookup(user));
-      }
-    }
-
-    // Try to avoid materializing a resolved materialization if possible.
-    // Handle the case of a 1-1 materialization.
-    if (inputOperands.size() == 1) {
-      // Check to see if the input operation was remapped to a variant of the
-      // output.
-      Value remappedValue =
-          lookupRemappedValue(opResult, inputOperands[0], outputType);
-      if (remappedValue && remappedValue != opResult) {
-        replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                               inverseMapping);
-        necessaryMaterializations.remove(mat);
-        continue;
-      }
-    } else {
-      // TODO: Avoid materializing other types of conversions here.
-    }
-
-    // If the materialization does not have any live users, we don't need to
-    // generate a user materialization for it.
-    bool isMaterializationLive = isLive(opResult);
-    if (!isMaterializationLive)
-      continue;
-    if (!necessaryMaterializations.insert(mat))
-      continue;
-
-    // Reprocess input materializations to see if they have an updated status.
-    for (Value input : inputOperands) {
-      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
-        if (auto *mat = materializationOps.lookup(parentOp))
-          worklist.insert(mat);
-      }
-    }
-  }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
-    UnresolvedMaterializationRewrite &mat,
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  auto findLiveUser = [&](auto &&users) {
-    auto liveUserIt = llvm::find_if_not(
-        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
-    return liveUserIt == users.end() ? nullptr : *liveUserIt;
-  };
-
-  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
-      [&](Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type)
-          return remappedValue;
-        return Value();
-      };
-
-  UnrealizedConversionCastOp op = mat.getOperation();
-  if (!rewriterImpl.ignoredOps.insert(op))
-    return success();
-
-  // We currently only handle target materializations here.
-  OpResult opResult = op->getOpResult(0);
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = opResult.getType();
-
-  // If any input to this materialization is another materialization, resolve
-  // the input first.
-  for (Value value : op->getOperands()) {
-    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
-    if (!valueCast)
-      continue;
-
-    auto matIt = materializationOps.find(valueCast);
-    if (matIt != materializationOps.end())
-      if (failed(legalizeUnresolvedMaterialization(
-              *matIt->second, materializationOps, rewriter, rewriterImpl,
-              inverseMapping)))
-        return failure();
-  }
-
-  // Perform a last ditch attempt to avoid materializing a resolved
-  // materialization if possible.
-  // Handle the case of a 1-1 materialization.
-  if (inputOperands.size() == 1) {
-    // Check to see if the input operation was remapped to a variant of the
-    // output.
-    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
-    if (remappedValue && remappedValue != opResult) {
-      replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                             inverseMapping);
-      return success();
-    }
-  } else {
-    // TODO: Avoid materializing other types of conversions here.
-  }
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = mat.getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (mat.getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      replaceMaterialization(rewriterImpl, opResult, newMaterialization,
-                             inverseMapping);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  if (Operation *liveUser = findLiveUser(op->getUsers())) {
-    diag.attachNote(liveUser->getLoc())
-        << "see existing live user here: " << *liveUser;
-  }
-  return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // As an initial step, compute all of the inserted materializations that we
-  // expect to persist beyond the conversion process.
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
-  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
-  com...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Aug 17, 2024

@llvm/pr-subscribers-mlir-bufferization

Author: Matthias Springer (matthias-springer)

Changes

This commit makes source/target/argument materializations (via the TypeConverter API) optional.

By default (ConversionConfig::buildMaterializations = true), the dialect conversion infrastructure tries to legalize all unresolved materializations right after the main transformation process has succeeded. If at least one unresolved materialization fails to resolve, the dialect conversion fails. (With an error message such as failed to legalize unresolved materialization ....)

Automatic materializations through the TypeConverter API can now be deactivated. In that case, every unresolved materialization will show up as a builtin.unrealized_conversion_cast op in the output IR.

There used to be a complex and error-prone analysis in the dialect conversion that predicted the future uses of unresolved materializations. Based on that logic, some casts (that were deemed to unnecessary) were folded. This analysis was needed because folding happened at a point of time when some IR changes (e.g., op replacements) had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens after all other IR changes have been materialized and the uses can directly be queried from the IR. This simplifies the analysis significantly. And certain helper data structures such as inverseMapping are no longer needed for the analysis. The folding itself is done by reconcileUnrealizedCasts (which also exists as a standalone pass).

After casts have been folded, the remaining casts are materialized through the TypeConverter, as usual. This last step can be deactivated in the ConversionConfig.

ConversionConfig::buildMaterializations = false can be used to debug error messages such as failed to legalize unresolved materialization .... (It is also useful in case automatic materializations are not needed.) The materializations that failed to resolve can then be seen as builtin.unrealized_conversion_cast ops in the resulting IR. (This is better than running with -debug, because -debug shows IR where some IR changes have not been materialized yet.)


Patch is 30.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/104668.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Transforms/DialectConversion.h (+11)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+89-294)
  • (modified) mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir (+2-3)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/finalizing-bufferize.mlir (+1)
  • (modified) mlir/test/Transforms/test-legalize-type-conversion.mlir (+5-1)
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 60113bdef16a23..5f680e8eca7559 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1124,6 +1124,17 @@ struct ConversionConfig {
   // already been modified) and iterators into past IR state cannot be
   // represented at the moment.
   RewriterBase::Listener *listener = nullptr;
+
+  /// If set to "true", the dialect conversion attempts to build source/target/
+  /// argument materializations through the type converter API in lieu of
+  /// builtin.unrealized_conversion_cast ops. The conversion process fails if
+  /// at least one materialization could not be built.
+  ///
+  /// If set to "false", the dialect conversion does not does not build any
+  /// custom materializations and instead inserts
+  /// builtin.unrealized_conversion_cast ops to ensure that the resulting IR
+  /// is valid.
+  bool buildMaterializations = true;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6238a257b2ffda..f710116186e741 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -702,14 +702,8 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
   }
 
-  UnrealizedConversionCastOp getOperation() const {
-    return cast<UnrealizedConversionCastOp>(op);
-  }
-
   void rollback() override;
 
-  void cleanup(RewriterBase &rewriter) override;
-
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
     return converterAndKind.getPointer();
@@ -766,7 +760,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), config(config) {}
+      : context(ctx), eraseRewriter(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -834,6 +828,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   //===--------------------------------------------------------------------===//
   // Materializations
   //===--------------------------------------------------------------------===//
+
   /// Build an unresolved materialization operation given an output type and set
   /// of input operands.
   Value buildUnresolvedMaterialization(MaterializationKind kind,
@@ -912,6 +907,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
+  /// A rewriter that keeps track of ops/block that were already erased and
+  /// skips duplicate op/block erasures. This rewriter is used during the
+  /// "cleanup" phase.
+  SingleEraseRewriter eraseRewriter;
+
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1058,10 +1058,6 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
-  rewriter.eraseOp(op);
-}
-
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
   IRRewriter rewriter(context, config.listener);
@@ -1069,7 +1065,6 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrite->commit(rewriter);
 
   // Clean up all rewrites.
-  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2354,12 +2349,6 @@ struct OperationConverter {
       ConversionPatternRewriterImpl &rewriterImpl,
       DenseMap<Value, SmallVector<Value>> &inverseMapping);
 
-  /// Legalize any unresolved type materializations.
-  LogicalResult legalizeUnresolvedMaterializations(
-      ConversionPatternRewriter &rewriter,
-      ConversionPatternRewriterImpl &rewriterImpl,
-      DenseMap<Value, SmallVector<Value>> &inverseMapping);
-
   /// Legalize an operation result that was marked as "erased".
   LogicalResult
   legalizeErasedResult(Operation *op, OpResult result,
@@ -2406,6 +2395,57 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
   return success();
 }
 
+static LogicalResult
+legalizeUnresolvedMaterialization(RewriterBase &rewriter,
+                                  UnresolvedMaterializationRewrite *rewrite) {
+  UnrealizedConversionCastOp op =
+      cast<UnrealizedConversionCastOp>(rewrite->getOperation());
+  assert(!op.use_empty() &&
+         "expected that dead materializations have already been DCE'd");
+  Operation::operand_range inputOperands = op.getOperands();
+  Type outputType = op.getResultTypes()[0];
+
+  // Try to materialize the conversion.
+  if (const TypeConverter *converter = rewrite->getConverter()) {
+    rewriter.setInsertionPoint(op);
+    Value newMaterialization;
+    switch (rewrite->getMaterializationKind()) {
+    case MaterializationKind::Argument:
+      // Try to materialize an argument conversion.
+      newMaterialization = converter->materializeArgumentConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      if (newMaterialization)
+        break;
+      // If an argument materialization failed, fallback to trying a target
+      // materialization.
+      [[fallthrough]];
+    case MaterializationKind::Target:
+      newMaterialization = converter->materializeTargetConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    case MaterializationKind::Source:
+      newMaterialization = converter->materializeSourceConversion(
+          rewriter, op->getLoc(), outputType, inputOperands);
+      break;
+    }
+    if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
+      rewriter.replaceOp(op, newMaterialization);
+      return success();
+    }
+  }
+
+  InFlightDiagnostic diag = op->emitError()
+                            << "failed to legalize unresolved materialization "
+                               "from ("
+                            << inputOperands.getTypes() << ") to " << outputType
+                            << " that remained live after conversion";
+  diag.attachNote(op->getUsers().begin()->getLoc())
+      << "see existing live user here: " << *op->getUsers().begin();
+  return failure();
+}
+
 LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   if (ops.empty())
     return success();
@@ -2447,6 +2487,37 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   } else {
     rewriterImpl.applyRewrites();
   }
+
+  // Gather all unresolved materializations.
+  SmallVector<UnrealizedConversionCastOp> allCastOps;
+  DenseMap<Operation *, UnresolvedMaterializationRewrite *> rewriteMap;
+  for (auto &rewrite : rewriterImpl.rewrites) {
+    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
+    if (!mat)
+      continue;
+    if (rewriterImpl.eraseRewriter.erased.contains(mat->getOperation()))
+      continue;
+    allCastOps.push_back(cast<UnrealizedConversionCastOp>(mat->getOperation()));
+    rewriteMap[mat->getOperation()] = mat;
+  }
+
+  // Reconcile all UnrealizedConversionCastOps that were inserted by the
+  // dialect conversion frameworks. (Not the one that were inserted by
+  // patterns.)
+  SmallVector<UnrealizedConversionCastOp> remainingCastOps;
+  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
+
+  // Try to legalize all unresolved materializations.
+  if (config.buildMaterializations) {
+    IRRewriter rewriter(rewriterImpl.context, config.listener);
+    for (UnrealizedConversionCastOp castOp : remainingCastOps) {
+      auto it = rewriteMap.find(castOp.getOperation());
+      assert(it != rewriteMap.end() && "inconsistent state");
+      if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
+        return failure();
+    }
+  }
+
   return success();
 }
 
@@ -2460,9 +2531,6 @@ OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
   if (failed(legalizeConvertedOpResultTypes(rewriter, rewriterImpl,
                                             inverseMapping)))
     return failure();
-  if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl,
-                                                inverseMapping)))
-    return failure();
   return success();
 }
 
@@ -2577,279 +2645,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
   return success();
 }
 
-/// Replace the results of a materialization operation with the given values.
-static void
-replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl,
-                       ResultRange matResults, ValueRange values,
-                       DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  matResults.replaceAllUsesWith(values);
-
-  // For each of the materialization results, update the inverse mappings to
-  // point to the replacement values.
-  for (auto [matResult, newValue] : llvm::zip(matResults, values)) {
-    auto inverseMapIt = inverseMapping.find(matResult);
-    if (inverseMapIt == inverseMapping.end())
-      continue;
-
-    // Update the reverse mapping, or remove the mapping if we couldn't update
-    // it. Not being able to update signals that the mapping would have become
-    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
-    // propagated through temporary materializations. We simply drop the
-    // mapping, and let the post-conversion replacement logic handle updating
-    // uses.
-    for (Value inverseMapVal : inverseMapIt->second)
-      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
-        rewriterImpl.mapping.erase(inverseMapVal);
-  }
-}
-
-/// Compute all of the unresolved materializations that will persist beyond the
-/// conversion process, and require inserting a proper user materialization for.
-static void computeNecessaryMaterializations(
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping,
-    SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
-  // Helper function to check if the given value or a not yet materialized
-  // replacement of the given value is live.
-  // Note: `inverseMapping` maps from replaced values to original values.
-  auto isLive = [&](Value value) {
-    auto findFn = [&](Operation *user) {
-      auto matIt = materializationOps.find(user);
-      if (matIt != materializationOps.end())
-        return !necessaryMaterializations.count(matIt->second);
-      return rewriterImpl.isOpIgnored(user);
-    };
-    // A worklist is needed because a value may have gone through a chain of
-    // replacements and each of the replaced values may have live users.
-    SmallVector<Value> worklist;
-    worklist.push_back(value);
-    while (!worklist.empty()) {
-      Value next = worklist.pop_back_val();
-      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
-        return true;
-      // This value may be replacing another value that has a live user.
-      llvm::append_range(worklist, inverseMapping.lookup(next));
-    }
-    return false;
-  };
-
-  llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
-      [&](Value invalidRoot, Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type && remappedValue != invalidRoot)
-          return remappedValue;
-
-        // Check to see if the input is a materialization operation that
-        // provides an inverse conversion. We just check blindly for
-        // UnrealizedConversionCastOp here, but it has no effect on correctness.
-        auto inputCastOp = value.getDefiningOp<UnrealizedConversionCastOp>();
-        if (inputCastOp && inputCastOp->getNumOperands() == 1)
-          return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0),
-                                     type);
-
-        return Value();
-      };
-
-  SetVector<UnresolvedMaterializationRewrite *> worklist;
-  for (auto &rewrite : rewriterImpl.rewrites) {
-    auto *mat = dyn_cast<UnresolvedMaterializationRewrite>(rewrite.get());
-    if (!mat)
-      continue;
-    materializationOps.try_emplace(mat->getOperation(), mat);
-    worklist.insert(mat);
-  }
-  while (!worklist.empty()) {
-    UnresolvedMaterializationRewrite *mat = worklist.pop_back_val();
-    UnrealizedConversionCastOp op = mat->getOperation();
-
-    // We currently only handle target materializations here.
-    assert(op->getNumResults() == 1 && "unexpected materialization type");
-    OpResult opResult = op->getOpResult(0);
-    Type outputType = opResult.getType();
-    Operation::operand_range inputOperands = op.getOperands();
-
-    // Try to forward propagate operands for user conversion casts that result
-    // in the input types of the current cast.
-    for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) {
-      auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-      if (!castOp)
-        continue;
-      if (castOp->getResultTypes() == inputOperands.getTypes()) {
-        replaceMaterialization(rewriterImpl, user->getResults(), inputOperands,
-                               inverseMapping);
-        necessaryMaterializations.remove(materializationOps.lookup(user));
-      }
-    }
-
-    // Try to avoid materializing a resolved materialization if possible.
-    // Handle the case of a 1-1 materialization.
-    if (inputOperands.size() == 1) {
-      // Check to see if the input operation was remapped to a variant of the
-      // output.
-      Value remappedValue =
-          lookupRemappedValue(opResult, inputOperands[0], outputType);
-      if (remappedValue && remappedValue != opResult) {
-        replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                               inverseMapping);
-        necessaryMaterializations.remove(mat);
-        continue;
-      }
-    } else {
-      // TODO: Avoid materializing other types of conversions here.
-    }
-
-    // If the materialization does not have any live users, we don't need to
-    // generate a user materialization for it.
-    bool isMaterializationLive = isLive(opResult);
-    if (!isMaterializationLive)
-      continue;
-    if (!necessaryMaterializations.insert(mat))
-      continue;
-
-    // Reprocess input materializations to see if they have an updated status.
-    for (Value input : inputOperands) {
-      if (auto parentOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
-        if (auto *mat = materializationOps.lookup(parentOp))
-          worklist.insert(mat);
-      }
-    }
-  }
-}
-
-/// Legalize the given unresolved materialization. Returns success if the
-/// materialization was legalized, failure otherise.
-static LogicalResult legalizeUnresolvedMaterialization(
-    UnresolvedMaterializationRewrite &mat,
-    DenseMap<Operation *, UnresolvedMaterializationRewrite *>
-        &materializationOps,
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  auto findLiveUser = [&](auto &&users) {
-    auto liveUserIt = llvm::find_if_not(
-        users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); });
-    return liveUserIt == users.end() ? nullptr : *liveUserIt;
-  };
-
-  llvm::unique_function<Value(Value, Type)> lookupRemappedValue =
-      [&](Value value, Type type) {
-        // Check to see if the input operation was remapped to a variant of the
-        // output.
-        Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type);
-        if (remappedValue.getType() == type)
-          return remappedValue;
-        return Value();
-      };
-
-  UnrealizedConversionCastOp op = mat.getOperation();
-  if (!rewriterImpl.ignoredOps.insert(op))
-    return success();
-
-  // We currently only handle target materializations here.
-  OpResult opResult = op->getOpResult(0);
-  Operation::operand_range inputOperands = op.getOperands();
-  Type outputType = opResult.getType();
-
-  // If any input to this materialization is another materialization, resolve
-  // the input first.
-  for (Value value : op->getOperands()) {
-    auto valueCast = value.getDefiningOp<UnrealizedConversionCastOp>();
-    if (!valueCast)
-      continue;
-
-    auto matIt = materializationOps.find(valueCast);
-    if (matIt != materializationOps.end())
-      if (failed(legalizeUnresolvedMaterialization(
-              *matIt->second, materializationOps, rewriter, rewriterImpl,
-              inverseMapping)))
-        return failure();
-  }
-
-  // Perform a last ditch attempt to avoid materializing a resolved
-  // materialization if possible.
-  // Handle the case of a 1-1 materialization.
-  if (inputOperands.size() == 1) {
-    // Check to see if the input operation was remapped to a variant of the
-    // output.
-    Value remappedValue = lookupRemappedValue(inputOperands[0], outputType);
-    if (remappedValue && remappedValue != opResult) {
-      replaceMaterialization(rewriterImpl, opResult, remappedValue,
-                             inverseMapping);
-      return success();
-    }
-  } else {
-    // TODO: Avoid materializing other types of conversions here.
-  }
-
-  // Try to materialize the conversion.
-  if (const TypeConverter *converter = mat.getConverter()) {
-    rewriter.setInsertionPoint(op);
-    Value newMaterialization;
-    switch (mat.getMaterializationKind()) {
-    case MaterializationKind::Argument:
-      // Try to materialize an argument conversion.
-      newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      if (newMaterialization)
-        break;
-      // If an argument materialization failed, fallback to trying a target
-      // materialization.
-      [[fallthrough]];
-    case MaterializationKind::Target:
-      newMaterialization = converter->materializeTargetConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    case MaterializationKind::Source:
-      newMaterialization = converter->materializeSourceConversion(
-          rewriter, op->getLoc(), outputType, inputOperands);
-      break;
-    }
-    if (newMaterialization) {
-      assert(newMaterialization.getType() == outputType &&
-             "materialization callback produced value of incorrect type");
-      replaceMaterialization(rewriterImpl, opResult, newMaterialization,
-                             inverseMapping);
-      return success();
-    }
-  }
-
-  InFlightDiagnostic diag = op->emitError()
-                            << "failed to legalize unresolved materialization "
-                               "from ("
-                            << inputOperands.getTypes() << ") to " << outputType
-                            << " that remained live after conversion";
-  if (Operation *liveUser = findLiveUser(op->getUsers())) {
-    diag.attachNote(liveUser->getLoc())
-        << "see existing live user here: " << *liveUser;
-  }
-  return failure();
-}
-
-LogicalResult OperationConverter::legalizeUnresolvedMaterializations(
-    ConversionPatternRewriter &rewriter,
-    ConversionPatternRewriterImpl &rewriterImpl,
-    DenseMap<Value, SmallVector<Value>> &inverseMapping) {
-  // As an initial step, compute all of the inserted materializations that we
-  // expect to persist beyond the conversion process.
-  DenseMap<Operation *, UnresolvedMaterializationRewrite *> materializationOps;
-  SetVector<UnresolvedMaterializationRewrite *> necessaryMaterializations;
-  com...
[truncated]

Copy link
Member

@zero9178 zero9178 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Slight nit in the PR description: Initially reading paragraph two and three had me confused whether the ConversionConfig::buildMaterializations knob and the "automatic materilization" as you called it were referring to the same thing or not. Maybe merging these two paragraphs would make this clearer.

mlir/lib/Transforms/Utils/DialectConversion.cpp Outdated Show resolved Hide resolved
Base automatically changed from users/matthias-springer/move_reconcile_unre to main August 23, 2024 15:46
@matthias-springer matthias-springer force-pushed the users/matthias-springer/optional_materializations branch from f72f642 to 41949c9 Compare August 23, 2024 15:48
…ional

Build all source/target/argument materializations after the conversion has succeeded. Provide a new configuration option for users to opt out of all automatic materializations. In that case, the resulting IR will have `builtin.unrealized_conversion_cast` ops.
@matthias-springer matthias-springer force-pushed the users/matthias-springer/optional_materializations branch from 41949c9 to b72f956 Compare August 23, 2024 19:18
@matthias-springer matthias-springer merged commit d7073c5 into main Aug 23, 2024
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/optional_materializations branch August 23, 2024 21:03
5chmidti pushed a commit that referenced this pull request Aug 24, 2024
…#104668)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)
MaheshRavishankar added a commit to iree-org/llvm-project that referenced this pull request Aug 29, 2024
@MaheshRavishankar
Copy link
Contributor

MaheshRavishankar commented Aug 29, 2024

I think this change broke some of the passes downstream. I definitely know that the TypePropagationPass is broken by this change. This could be because of an error in the pass itself in how it was using the dialect conversion (specifically the region signature conversion) in the wrong way. I dont have an MLIR only repro, but I have an IREE repro

func.func @_select_dispatch_0_elementwise_4_i1xi1xi32xi32xi32() {
  %c0 = arith.constant 0 : index
  %0 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4xi8>>
  %1 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4xi8>>
  %2 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4xi32>>
  %3 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(3) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor<readonly:tensor<4xi32>>
  %4 = hal.interface.binding.subspan layout(<bindings = [#hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, "ReadOnly|Indirect">, #hal.pipeline.binding<storage_buffer, Indirect>], flags = Indirect>) binding(4) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor<writeonly:tensor<4xi32>>
  %5 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4xi8>> -> tensor<4xi8>
  %6 = arith.trunci %5 : tensor<4xi8> to tensor<4xi1>
  %7 = flow.dispatch.tensor.load %1, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4xi8>> -> tensor<4xi8>
  %8 = arith.trunci %7 : tensor<4xi8> to tensor<4xi1>
  %9 = flow.dispatch.tensor.load %2, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4xi32>> -> tensor<4xi32>
  %10 = flow.dispatch.tensor.load %3, offsets = [0], sizes = [4], strides = [1] : !flow.dispatch.tensor<readonly:tensor<4xi32>> -> tensor<4xi32>
  %11 = tensor.empty() : tensor<4xi32>
  %12 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %8, %9, %10 : tensor<4xi1>, tensor<4xi1>, tensor<4xi32>, tensor<4xi32>) outs(%11 : tensor<4xi32>) {
  ^bb0(%in: i1, %in_0: i1, %in_1: i32, %in_2: i32, %out: i32):
    %13 = arith.cmpi ugt, %in, %in_0 : i1
    %14 = arith.select %13, %in_1, %in_2 : i32
    linalg.yield %14 : i32
  } -> tensor<4xi32>
  flow.dispatch.tensor.store %12, %4, offsets = [0], sizes = [4], strides = [1] : tensor<4xi32> -> !flow.dispatch.tensor<writeonly:tensor<4xi32>>
  return
}

fails when this change is used in IREE and compiled with

iree-opt --pass-pipeline="builtin.module(func.func(iree-flow-type-propagation))" repro.mlir

I think it also breaks some EmitC passes (I am checking that in iree-org/iree#18384).

Leaving a note here. For now I am trying to revert this change locally to IREE and will come back to triage this further.

@matthias-springer
Copy link
Member Author

@MaheshRavishankar I'm looking into this. Can you give this a try? matthias-springer/iree@28f0aea (This could be useful even without this PR.) I still have to look into the emitc failure.

@MaheshRavishankar
Copy link
Contributor

Awesome! Thanks! I am off for best part of a week. .

Let me tap someone in.

matthias-springer added a commit that referenced this pull request Aug 30, 2024
…optional" (#106778)

Reverts #104668

This commit triggers an edge case that can cause circular
`unrealized_conversion_cast` ops.
#106760 may fix it, but it is
has other issues. Reverting this PR for now, until I find a solution for
that problem.
@matthias-springer
Copy link
Member Author

matthias-springer commented Sep 1, 2024

There is a more direct way to fix the issue (#106760 is still desirable because it simplifies the code base). In the code that was deleted by this PR, there is the following:

    // Update the reverse mapping, or remove the mapping if we couldn't update
    // it. Not being able to update signals that the mapping would have become
    // circular (i.e. %foo -> newValue -> %foo), which may occur as values are
    // propagated through temporary materializations. We simply drop the
    // mapping, and let the post-conversion replacement logic handle updating
    // uses.
    for (Value inverseMapVal : inverseMapIt->second)
      if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue))
        rewriterImpl.mapping.erase(inverseMapVal);

I saw such circular mappings in the resulting IR in IREE and thought it is a problem with this PR. But the above code snippet shows that these are actually expected to appear. #106760 changes the implementation so that we never generate them. That's one way to handle the problem. Alternatively, we can do something similar as the old code and handle them post-conversion during reconcileUnrealizedCasts. I'm going to give that a try. (Basically doing the same as the old code, but operating on concrete IR instead of an IRMapping.)

dmpolukhin pushed a commit to dmpolukhin/llvm-project that referenced this pull request Sep 2, 2024
…llvm#104668)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)
matthias-springer added a commit that referenced this pull request Sep 3, 2024
…#104668)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)

Note: This is a reupload of #104668, but with correct handling of cyclic unrealized_conversion_casts that may be generated by the dialect conversion.
matthias-springer added a commit that referenced this pull request Sep 3, 2024
…#104668)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)

Note: This is a reupload of #104668, but with correct handling of cyclic unrealized_conversion_casts that may be generated by the dialect conversion.
matthias-springer added a commit that referenced this pull request Sep 5, 2024
…#104668)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)

Note: This is a reupload of #104668, but with correct handling of cyclic unrealized_conversion_casts that may be generated by the dialect conversion.
matthias-springer added a commit that referenced this pull request Sep 5, 2024
…#107109)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)

Note: This is a reupload of #104668, but with correct handling of cyclic
unrealized_conversion_casts that may be generated by the dialect
conversion.
MaheshRavishankar pushed a commit to MaheshRavishankar/iree that referenced this pull request Sep 6, 2024
MaheshRavishankar added a commit to iree-org/iree that referenced this pull request Sep 6, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
Co-authored-by: Matthias Springer <[email protected]>
IanWood1 pushed a commit to IanWood1/iree that referenced this pull request Sep 8, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
Co-authored-by: Matthias Springer <[email protected]>
VitaNuo pushed a commit to VitaNuo/llvm-project that referenced this pull request Sep 12, 2024
…llvm#107109)

This commit makes source/target/argument materializations (via the
`TypeConverter` API) optional.

By default (`ConversionConfig::buildMaterializations = true`), the
dialect conversion infrastructure tries to legalize all unresolved
materializations right after the main transformation process has
succeeded. If at least one unresolved materialization fails to resolve,
the dialect conversion fails. (With an error message such as `failed to
legalize unresolved materialization ...`.) Automatic materializations
through the `TypeConverter` API can now be deactivated. In that case,
every unresolved materialization will show up as a
`builtin.unrealized_conversion_cast` op in the output IR.

There used to be a complex and error-prone analysis in the dialect
conversion that predicted the future uses of unresolved
materializations. Based on that logic, some casts (that were deemed to
unnecessary) were folded. This analysis was needed because folding
happened at a point of time when some IR changes (e.g., op replacements)
had not materialized yet.

This commit removes that analysis. Any folding of cast ops now happens
after all other IR changes have been materialized and the uses can
directly be queried from the IR. This simplifies the analysis
significantly. And certain helper data structures such as
`inverseMapping` are no longer needed for the analysis. The folding
itself is done by `reconcileUnrealizedCasts` (which also exists as a
standalone pass).

After casts have been folded, the remaining casts are materialized
through the `TypeConverter`, as usual. This last step can be deactivated
in the `ConversionConfig`.

`ConversionConfig::buildMaterializations = false` can be used to debug
error messages such as `failed to legalize unresolved materialization
...`. (It is also useful in case automatic materializations are not
needed.) The materializations that failed to resolve can then be seen as
`builtin.unrealized_conversion_cast` ops in the resulting IR. (This is
better than running with `-debug`, because `-debug` shows IR where some
IR changes have not been materialized yet.)

Note: This is a reupload of llvm#104668, but with correct handling of cyclic
unrealized_conversion_casts that may be generated by the dialect
conversion.
josemonsalve2 pushed a commit to josemonsalve2/iree that referenced this pull request Sep 14, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
Co-authored-by: Matthias Springer <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:core MLIR Core Infrastructure mlir:gpu mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants