Skip to content

[mlir][Transforms] Dialect conversion: Add missing erasure notifications #145030

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2025

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jun 20, 2025

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)

Depends on #145018.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jun 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)


Full diff: https://github.com/llvm/llvm-project/pull/145030.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-19)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+16-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+  for (Operation &op : b)
+    notifyIRErased(listener, op);
+  listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+  for (Region &r : op.getRegions()) {
+    for (Block &b : r) {
+      notifyIRErased(listener, b);
+    }
+  }
+  listener->notifyOperationErased(&op);
+}
+
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
   }
 
   void commit(RewriterBase &rewriter) override {
-    // Erase the block.
     assert(block && "expected block");
-    assert(block->empty() && "expected empty block");
 
-    // Notify the listener that the block is about to be erased.
+    // Notify the listener that the block and its contents are being erased.
     if (auto *listener =
             dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-      listener->notifyBlockErased(block);
+      notifyIRErased(listener, *block);
   }
 
   void cleanup(RewriterBase &rewriter) override {
+    // Erase the contents of the block.
+    for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+      rewriter.eraseOp(&op);
+    assert(block->empty() && "expected empty block");
+
     // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
 
-  // Notify the listener that the operation (and its nested operations) was
-  // erased.
-  if (listener) {
-    op->walk<WalkOrder::PostOrder>(
-        [&](Operation *op) { listener->notifyOperationErased(op); });
-  }
+  // Notify the listener that the operation and its contents are being erased.
+  if (listener)
+    notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   // Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+  assert(!wasOpReplaced(block->getParentOp()) &&
+         "attempting to erase a block within a replaced/erased op");
   appendRewrite<EraseBlockRewrite>(block);
 
   // Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
+
+  // Mark all nested ops as erased.
+  block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(
+      (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+      "attempting to insert into a region within a replaced/erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed. No extra bookkeeping is needed.
+    return;
+  }
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
-
-  // Mark all ops for erasure.
-  for (Operation &op : *block)
-    eraseOp(&op);
-
   impl->eraseBlock(block);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
 
 // -----
 
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
 // CHECK-LABEL: func @circular_mapping()
 //  CHECK-NEXT:   "test.valid"() : () -> ()
 func.func @circular_mapping() {
   // Regression test that used to crash due to circular
-  // unrealized_conversion_cast ops.
-  %0 = "test.erase_op"() : () -> (i64)
+  // unrealized_conversion_cast ops. 
+  %0 = "test.erase_op"() ({
+    "test.dummy_op_lvl_1"() ({
+      "test.dummy_op_lvl_2"() : () -> ()
+    }) : () -> ()
+  }): () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
 

@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Add missing listener notifications when erasing nested blocks/operations.

This commit also moves some of the functionality from ConversionPatternRewriter to ConversionPatternRewriterImpl. This is in preparation of the One-Shot Dialect Conversion refactoring: The implementations in ConversionPatternRewriter should be as simple as possible, so that a switch between "rollback allowed" and "rollback not allowed" can be inserted at that level. (In the latter case, ConversionPatternRewriterImpl can be bypassed to some degree, and PatternRewriter::eraseBlock etc. can be used.)


Full diff: https://github.com/llvm/llvm-project/pull/145030.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+43-19)
  • (modified) mlir/test/Transforms/test-legalizer.mlir (+16-2)
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ff48647f43305..7419d79cd8856 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -274,6 +274,26 @@ struct RewriterState {
 // IR rewrites
 //===----------------------------------------------------------------------===//
 
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op);
+
+/// Notify the listener that the given block and its contents are being erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Block &b) {
+  for (Operation &op : b)
+    notifyIRErased(listener, op);
+  listener->notifyBlockErased(&b);
+}
+
+/// Notify the listener that the given operation and its contents are being
+/// erased.
+static void notifyIRErased(RewriterBase::Listener *listener, Operation &op) {
+  for (Region &r : op.getRegions()) {
+    for (Block &b : r) {
+      notifyIRErased(listener, b);
+    }
+  }
+  listener->notifyOperationErased(&op);
+}
+
 /// An IR rewrite that can be committed (upon success) or rolled back (upon
 /// failure).
 ///
@@ -422,17 +442,20 @@ class EraseBlockRewrite : public BlockRewrite {
   }
 
   void commit(RewriterBase &rewriter) override {
-    // Erase the block.
     assert(block && "expected block");
-    assert(block->empty() && "expected empty block");
 
-    // Notify the listener that the block is about to be erased.
+    // Notify the listener that the block and its contents are being erased.
     if (auto *listener =
             dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
-      listener->notifyBlockErased(block);
+      notifyIRErased(listener, *block);
   }
 
   void cleanup(RewriterBase &rewriter) override {
+    // Erase the contents of the block.
+    for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block)))
+      rewriter.eraseOp(&op);
+    assert(block->empty() && "expected empty block");
+
     // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
@@ -1147,12 +1170,9 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
 
-  // Notify the listener that the operation (and its nested operations) was
-  // erased.
-  if (listener) {
-    op->walk<WalkOrder::PostOrder>(
-        [&](Operation *op) { listener->notifyOperationErased(op); });
-  }
+  // Notify the listener that the operation and its contents are being erased.
+  if (listener)
+    notifyIRErased(listener, *op);
 
   // Do not erase the operation yet. It may still be referenced in `mapping`.
   // Just unlink it for now and erase it during cleanup.
@@ -1605,6 +1625,8 @@ void ConversionPatternRewriterImpl::replaceOp(
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+  assert(!wasOpReplaced(block->getParentOp()) &&
+         "attempting to erase a block within a replaced/erased op");
   appendRewrite<EraseBlockRewrite>(block);
 
   // Unlink the block from its parent region. The block is kept in the rewrite
@@ -1612,12 +1634,16 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
   // allows us to keep the operations in the block live and undo the removal by
   // re-inserting the block.
   block->getParent()->getBlocks().remove(block);
+
+  // Mark all nested ops as erased.
+  block->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
-  assert(!wasOpReplaced(block->getParentOp()) &&
-         "attempting to insert into a region within a replaced/erased op");
+  assert(
+      (!config.allowPatternRollback || !wasOpReplaced(block->getParentOp())) &&
+      "attempting to insert into a region within a replaced/erased op");
   LLVM_DEBUG(
       {
         Operation *parent = block->getParentOp();
@@ -1630,6 +1656,11 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
         }
       });
 
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed. No extra bookkeeping is needed.
+    return;
+  }
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1709,13 +1740,6 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  assert(!impl->wasOpReplaced(block->getParentOp()) &&
-         "attempting to erase a block within a replaced/erased op");
-
-  // Mark all ops for erasure.
-  for (Operation &op : *block)
-    eraseOp(&op);
-
   impl->eraseBlock(block);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 34948ae685f0a..204c8c1456826 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -461,12 +461,26 @@ func.func @convert_detached_signature() {
 
 // -----
 
+// CHECK: notifyOperationReplaced: test.erase_op
+// CHECK: notifyOperationErased: test.dummy_op_lvl_2
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.dummy_op_lvl_1
+// CHECK: notifyBlockErased
+// CHECK: notifyOperationErased: test.erase_op
+// CHECK: notifyOperationInserted: test.valid, was unlinked
+// CHECK: notifyOperationReplaced: test.drop_operands_and_replace_with_valid
+// CHECK: notifyOperationErased: test.drop_operands_and_replace_with_valid
+
 // CHECK-LABEL: func @circular_mapping()
 //  CHECK-NEXT:   "test.valid"() : () -> ()
 func.func @circular_mapping() {
   // Regression test that used to crash due to circular
-  // unrealized_conversion_cast ops.
-  %0 = "test.erase_op"() : () -> (i64)
+  // unrealized_conversion_cast ops. 
+  %0 = "test.erase_op"() ({
+    "test.dummy_op_lvl_1"() ({
+      "test.dummy_op_lvl_2"() : () -> ()
+    }) : () -> ()
+  }): () -> (i64)
   "test.drop_operands_and_replace_with_valid"(%0) : (i64) -> ()
 }
 

@matthias-springer matthias-springer force-pushed the users/matthias-springer/erase_notifications branch 2 times, most recently from d65161f to edb49ec Compare June 20, 2025 12:30
Copy link
Contributor

@j2kun j2kun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Base automatically changed from users/matthias-springer/rename_impl to main June 21, 2025 07:43
@matthias-springer matthias-springer force-pushed the users/matthias-springer/erase_notifications branch from edb49ec to 2fe4d14 Compare June 21, 2025 07:44
@matthias-springer matthias-springer merged commit 0921bfd into main Jun 21, 2025
7 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/erase_notifications branch June 21, 2025 08:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants