Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebAssembly] Allow try_table to target loops in AsmTypeCheck #111432

Merged
merged 3 commits into from
Oct 7, 2024

Conversation

aheejin
Copy link
Member

@aheejin aheejin commented Oct 7, 2024

No description provided.

@aheejin aheejin requested a review from dschuff October 7, 2024 20:26
@llvmbot llvmbot added backend:WebAssembly mc Machine (object) code labels Oct 7, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 7, 2024

@llvm/pr-subscribers-mc

Author: Heejin Ahn (aheejin)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/111432.diff

4 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+17-7)
  • (modified) llvm/test/MC/WebAssembly/annotations.s (+24-3)
  • (modified) llvm/test/MC/WebAssembly/eh-assembly.s (+24-1)
  • (modified) llvm/test/MC/WebAssembly/type-checker-errors.s (+17)
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index bfe6996977e690..cc8212d2c9d28d 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -371,13 +371,23 @@ bool WebAssemblyAsmTypeCheck::checkTryTable(SMLoc ErrorLoc,
     if (Level < BlockInfoStack.size()) {
       const auto &DestBlockInfo =
           BlockInfoStack[BlockInfoStack.size() - Level - 1];
-      if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
-        std::string ErrorMsg =
-            ErrorMsgBase + "type mismatch, catch tag type is " +
-            getTypesString(SentTypes) + ", but destination's type is " +
-            getTypesString(DestBlockInfo.Sig.Returns);
-        Error |= typeError(ErrorLoc, ErrorMsg);
-      }
+     if (DestBlockInfo.IsLoop) {
+        if (compareTypes(SentTypes, DestBlockInfo.Sig.Params)) {
+          std::string ErrorMsg =
+              ErrorMsgBase + "type mismatch, catch tag type is " +
+              getTypesString(SentTypes) + ", but destination's type is " +
+              getTypesString(DestBlockInfo.Sig.Params);
+          Error |= typeError(ErrorLoc, ErrorMsg);
+        }
+     } else {
+       if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
+         std::string ErrorMsg =
+             ErrorMsgBase + "type mismatch, catch tag type is " +
+             getTypesString(SentTypes) + ", but destination's type is " +
+             getTypesString(DestBlockInfo.Sig.Returns);
+         Error |= typeError(ErrorLoc, ErrorMsg);
+       }
+     }
     } else {
       Error = typeError(ErrorLoc, ErrorMsgBase + "invalid depth " +
                                       std::to_string(Level));
diff --git a/llvm/test/MC/WebAssembly/annotations.s b/llvm/test/MC/WebAssembly/annotations.s
index 2a5bea27941789..f761ef3f06b1fc 100644
--- a/llvm/test/MC/WebAssembly/annotations.s
+++ b/llvm/test/MC/WebAssembly/annotations.s
@@ -7,7 +7,7 @@
   .section .text.test_annotation,"",@
   .type    test_annotation,@function
 test_annotation:
-  .functype   test_annotation () -> ()
+  .functype   test_annotation (exnref) -> ()
   .tagtype  __cpp_exception i32
   .tagtype  __c_longjmp i32
   try
@@ -54,8 +54,18 @@ test_annotation:
     return
   end_block
   drop
-  end_function
 
+  i32.const 0
+  loop (i32) -> ()
+    local.get 0
+    loop (exnref) -> ()
+      try_table (catch __cpp_exception 1) (catch_all_ref 0)
+      end_try_table
+      drop
+    end_loop
+    drop
+  end_loop
+  end_function
 
 # CHECK:      test_annotation:
 # CHECK:        try
@@ -105,5 +115,16 @@ test_annotation:
 # CHECK-NEXT:   return
 # CHECK-NEXT:   end_block                               # label7:
 # CHECK-NEXT:   drop
-# CHECK-NEXT:   end_function
 
+# CHECK:        i32.const       0
+# CHECK-NEXT:   loop            (i32) -> ()                     # label12:
+# CHECK-NEXT:   local.get       0
+# CHECK-NEXT:   loop            (exnref) -> ()                  # label13:
+# CHECK-NEXT:   try_table        (catch __cpp_exception 1) (catch_all_ref 0) # 1: up to label12
+# CHECK-NEXT:                                 # 0: up to label13
+# CHECK-NEXT:   end_try_table                           # label14:
+# CHECK-NEXT:   drop
+# CHECK-NEXT:   end_loop
+# CHECK-NEXT:   drop
+# CHECK-NEXT:   end_loop
+# CHECK-NEXT:   end_function
diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index 38cda10a387a3b..31dfce5a3cde31 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -7,7 +7,7 @@
   .functype  foo () -> ()
 
 eh_test:
-  .functype  eh_test () -> ()
+  .functype  eh_test (exnref) -> ()
 
   # try_table with all four kinds of catch clauses
   block exnref
@@ -82,6 +82,18 @@ eh_test:
   end_try_table
   drop
   drop
+
+  # try_table targeting loops
+  i32.const 0
+  loop (i32) -> ()
+    local.get 0
+    loop (exnref) -> ()
+      try_table (catch __cpp_exception 1) (catch_all_ref 0)
+      end_try_table
+      drop
+    end_loop
+    drop
+  end_loop
   end_function
 
 eh_legacy_test:
@@ -203,6 +215,17 @@ eh_legacy_test:
 # CHECK-NEXT:    drop
 # CHECK-NEXT:    drop
 
+# CHECK:         i32.const       0
+# CHECK-NEXT:    loop            (i32) -> ()
+# CHECK-NEXT:    local.get       0
+# CHECK-NEXT:    loop            (exnref) -> ()
+# CHECK-NEXT:    try_table        (catch __cpp_exception 1) (catch_all_ref 0)
+# CHECK:         end_try_table
+# CHECK-NEXT:    drop
+# CHECK-NEXT:    end_loop
+# CHECK-NEXT:    drop
+# CHECK-NEXT:    end_loop
+
 # CHECK:       eh_legacy_test:
 # CHECK:         try
 # CHECK-NEXT:    i32.const       3
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index c1c8209e1dce0c..9aa652348c538e 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -966,4 +966,21 @@ eh_test:
     end_block
   end_block
   drop
+
+  loop
+  i32.const 0
+    loop (i32) -> ()
+      loop (i32) -> ()
+        loop
+# CHECK: :[[@LINE+4]]:11: error: try_table: catch index 0: type mismatch, catch tag type is [i32], but destination's type is []
+# CHECK: :[[@LINE+3]]:11: error: try_table: catch index 1: type mismatch, catch tag type is [i32, exnref], but destination's type is [i32]
+# CHECK: :[[@LINE+2]]:11: error: try_table: catch index 2: type mismatch, catch tag type is [], but destination's type is [i32]
+# CHECK: :[[@LINE+1]]:11: error: try_table: catch index 3: type mismatch, catch tag type is [exnref], but destination's type is []
+          try_table (catch __cpp_exception 0) (catch_ref __cpp_exception 1) (catch_all 2) (catch_all_ref 3)
+          end_try_table
+        end_loop
+        drop
+      end_loop
+    end_loop
+  end_loop
   end_function

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 7, 2024

@llvm/pr-subscribers-backend-webassembly

Author: Heejin Ahn (aheejin)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/111432.diff

4 Files Affected:

  • (modified) llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp (+17-7)
  • (modified) llvm/test/MC/WebAssembly/annotations.s (+24-3)
  • (modified) llvm/test/MC/WebAssembly/eh-assembly.s (+24-1)
  • (modified) llvm/test/MC/WebAssembly/type-checker-errors.s (+17)
diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
index bfe6996977e690..cc8212d2c9d28d 100644
--- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
+++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp
@@ -371,13 +371,23 @@ bool WebAssemblyAsmTypeCheck::checkTryTable(SMLoc ErrorLoc,
     if (Level < BlockInfoStack.size()) {
       const auto &DestBlockInfo =
           BlockInfoStack[BlockInfoStack.size() - Level - 1];
-      if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
-        std::string ErrorMsg =
-            ErrorMsgBase + "type mismatch, catch tag type is " +
-            getTypesString(SentTypes) + ", but destination's type is " +
-            getTypesString(DestBlockInfo.Sig.Returns);
-        Error |= typeError(ErrorLoc, ErrorMsg);
-      }
+     if (DestBlockInfo.IsLoop) {
+        if (compareTypes(SentTypes, DestBlockInfo.Sig.Params)) {
+          std::string ErrorMsg =
+              ErrorMsgBase + "type mismatch, catch tag type is " +
+              getTypesString(SentTypes) + ", but destination's type is " +
+              getTypesString(DestBlockInfo.Sig.Params);
+          Error |= typeError(ErrorLoc, ErrorMsg);
+        }
+     } else {
+       if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
+         std::string ErrorMsg =
+             ErrorMsgBase + "type mismatch, catch tag type is " +
+             getTypesString(SentTypes) + ", but destination's type is " +
+             getTypesString(DestBlockInfo.Sig.Returns);
+         Error |= typeError(ErrorLoc, ErrorMsg);
+       }
+     }
     } else {
       Error = typeError(ErrorLoc, ErrorMsgBase + "invalid depth " +
                                       std::to_string(Level));
diff --git a/llvm/test/MC/WebAssembly/annotations.s b/llvm/test/MC/WebAssembly/annotations.s
index 2a5bea27941789..f761ef3f06b1fc 100644
--- a/llvm/test/MC/WebAssembly/annotations.s
+++ b/llvm/test/MC/WebAssembly/annotations.s
@@ -7,7 +7,7 @@
   .section .text.test_annotation,"",@
   .type    test_annotation,@function
 test_annotation:
-  .functype   test_annotation () -> ()
+  .functype   test_annotation (exnref) -> ()
   .tagtype  __cpp_exception i32
   .tagtype  __c_longjmp i32
   try
@@ -54,8 +54,18 @@ test_annotation:
     return
   end_block
   drop
-  end_function
 
+  i32.const 0
+  loop (i32) -> ()
+    local.get 0
+    loop (exnref) -> ()
+      try_table (catch __cpp_exception 1) (catch_all_ref 0)
+      end_try_table
+      drop
+    end_loop
+    drop
+  end_loop
+  end_function
 
 # CHECK:      test_annotation:
 # CHECK:        try
@@ -105,5 +115,16 @@ test_annotation:
 # CHECK-NEXT:   return
 # CHECK-NEXT:   end_block                               # label7:
 # CHECK-NEXT:   drop
-# CHECK-NEXT:   end_function
 
+# CHECK:        i32.const       0
+# CHECK-NEXT:   loop            (i32) -> ()                     # label12:
+# CHECK-NEXT:   local.get       0
+# CHECK-NEXT:   loop            (exnref) -> ()                  # label13:
+# CHECK-NEXT:   try_table        (catch __cpp_exception 1) (catch_all_ref 0) # 1: up to label12
+# CHECK-NEXT:                                 # 0: up to label13
+# CHECK-NEXT:   end_try_table                           # label14:
+# CHECK-NEXT:   drop
+# CHECK-NEXT:   end_loop
+# CHECK-NEXT:   drop
+# CHECK-NEXT:   end_loop
+# CHECK-NEXT:   end_function
diff --git a/llvm/test/MC/WebAssembly/eh-assembly.s b/llvm/test/MC/WebAssembly/eh-assembly.s
index 38cda10a387a3b..31dfce5a3cde31 100644
--- a/llvm/test/MC/WebAssembly/eh-assembly.s
+++ b/llvm/test/MC/WebAssembly/eh-assembly.s
@@ -7,7 +7,7 @@
   .functype  foo () -> ()
 
 eh_test:
-  .functype  eh_test () -> ()
+  .functype  eh_test (exnref) -> ()
 
   # try_table with all four kinds of catch clauses
   block exnref
@@ -82,6 +82,18 @@ eh_test:
   end_try_table
   drop
   drop
+
+  # try_table targeting loops
+  i32.const 0
+  loop (i32) -> ()
+    local.get 0
+    loop (exnref) -> ()
+      try_table (catch __cpp_exception 1) (catch_all_ref 0)
+      end_try_table
+      drop
+    end_loop
+    drop
+  end_loop
   end_function
 
 eh_legacy_test:
@@ -203,6 +215,17 @@ eh_legacy_test:
 # CHECK-NEXT:    drop
 # CHECK-NEXT:    drop
 
+# CHECK:         i32.const       0
+# CHECK-NEXT:    loop            (i32) -> ()
+# CHECK-NEXT:    local.get       0
+# CHECK-NEXT:    loop            (exnref) -> ()
+# CHECK-NEXT:    try_table        (catch __cpp_exception 1) (catch_all_ref 0)
+# CHECK:         end_try_table
+# CHECK-NEXT:    drop
+# CHECK-NEXT:    end_loop
+# CHECK-NEXT:    drop
+# CHECK-NEXT:    end_loop
+
 # CHECK:       eh_legacy_test:
 # CHECK:         try
 # CHECK-NEXT:    i32.const       3
diff --git a/llvm/test/MC/WebAssembly/type-checker-errors.s b/llvm/test/MC/WebAssembly/type-checker-errors.s
index c1c8209e1dce0c..9aa652348c538e 100644
--- a/llvm/test/MC/WebAssembly/type-checker-errors.s
+++ b/llvm/test/MC/WebAssembly/type-checker-errors.s
@@ -966,4 +966,21 @@ eh_test:
     end_block
   end_block
   drop
+
+  loop
+  i32.const 0
+    loop (i32) -> ()
+      loop (i32) -> ()
+        loop
+# CHECK: :[[@LINE+4]]:11: error: try_table: catch index 0: type mismatch, catch tag type is [i32], but destination's type is []
+# CHECK: :[[@LINE+3]]:11: error: try_table: catch index 1: type mismatch, catch tag type is [i32, exnref], but destination's type is [i32]
+# CHECK: :[[@LINE+2]]:11: error: try_table: catch index 2: type mismatch, catch tag type is [], but destination's type is [i32]
+# CHECK: :[[@LINE+1]]:11: error: try_table: catch index 3: type mismatch, catch tag type is [exnref], but destination's type is []
+          try_table (catch __cpp_exception 0) (catch_ref __cpp_exception 1) (catch_all 2) (catch_all_ref 3)
+          end_try_table
+        end_loop
+        drop
+      end_loop
+    end_loop
+  end_loop
   end_function

Copy link

github-actions bot commented Oct 7, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Comment on lines 374 to 389
if (DestBlockInfo.IsLoop) {
if (compareTypes(SentTypes, DestBlockInfo.Sig.Params)) {
std::string ErrorMsg =
ErrorMsgBase + "type mismatch, catch tag type is " +
getTypesString(SentTypes) + ", but destination's type is " +
getTypesString(DestBlockInfo.Sig.Params);
Error |= typeError(ErrorLoc, ErrorMsg);
}
} else {
if (compareTypes(SentTypes, DestBlockInfo.Sig.Returns)) {
std::string ErrorMsg =
ErrorMsgBase + "type mismatch, catch tag type is " +
getTypesString(SentTypes) + ", but destination's type is " +
getTypesString(DestBlockInfo.Sig.Returns);
Error |= typeError(ErrorLoc, ErrorMsg);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is kind of repetitive. I wanted something like

ArrayRef DestTypes = DestBlockInfo.IsLoop ? DestBlockInfo.Sig.Params : DestBlockInfo.Sig.Returns;
if (compares(SentTypes, DestTypes)) {
  ...
}

But this was not possible because the types of WasmSignature::Params and WasmSignature::Returns are different:

SmallVector<ValType, 1> Returns;
SmallVector<ValType, 4> Params;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm... could you maybe upcast to SmallVectorImpl somewhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neither of

ArrayRef<wasm::ValType> DestTypes = DestBlockInfo.IsLoop ? DestBlockInfo.Sig.Params : DestBlockInfo.Sig.Returns;

and

SmallVectorImpl<wasm::ValType> DestTypes = DestBlockInfo.IsLoop ? DestBlockInfo.Sig.Params : DestBlockInfo.Sig.Returns;

works, saying things like

.../llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmTypeCheck.cpp:375:32: error: incompatible operand types ('const SmallVector<[...], 4>' and 'const SmallVector<[...], 1>')
  375 |           DestBlockInfo.IsLoop ? DestBlockInfo.Sig.Params
      |                                ^ ~~~~~~~~~~~~~~~~~~~~~~~~
  376 |                                : DestBlockInfo.Sig.Returns;
      |                                  ~~~~~~~~~~~~~~~~~~~~~~~~~

But this seems to work:

      ArrayRef<wasm::ValType> DestTypes;
      if (DestBlockInfo.IsLoop)
        DestTypes = DestBlockInfo.Sig.Params;
      else
        DestTypes = DestBlockInfo.Sig.Returns;
      if (compareTypes(SentTypes, DestTypes)) {
        ...

Replacing ArrayRef with SmallVectorImpl here doesn't work because SmallVectorImpl doesn't seem to have its own default constructor.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. I wonder if you could upcast both sides of the : in the ternary to SmallVectorImpl... but that doesn't really seem any better than what you have here.

@aheejin
Copy link
Member Author

aheejin commented Oct 7, 2024

The CI failures are irrelevant. Merging.

@aheejin aheejin merged commit 991adff into llvm:main Oct 7, 2024
6 of 8 checks passed
@aheejin aheejin deleted the eh_typecheck_loop branch October 7, 2024 22:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:WebAssembly mc Machine (object) code
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants