Skip to content

[mlir][linalg] Relax structured op region filler check #123741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 28, 2025

Conversation

adam-smnk
Copy link
Contributor

Removes assert on output type from structure op region filler to allow more graceful error handling.

Removes assert on output type from structure op region filler
to allow more graceful error handling.
@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-mlir-linalg

Author: Adam Siemieniuk (adam-smnk)

Changes

Removes assert on output type from structure op region filler to allow more graceful error handling.


Full diff: https://github.com/llvm/llvm-project/pull/123741.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-4)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+18)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..7c45f8805c205e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 /// `regionBuilder`. The method is used by both named structured ops created by
 /// ods-gen and by manually defined C++ ops. It is called by both builders and
 /// parsers and creates a block with arguments corresponding to the elemental
-/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
-/// ShapedType.
+/// types of `inputTypes` and `outputTypes`.
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
-
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..0853856d933035 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+  linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
+                outs(%arg2 : memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
+  // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+  linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
+                outs(%arg2 : f32)
+  return
+}
+
+// -----
+
 func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.matmul indexing_maps = [

@llvmbot
Copy link
Member

llvmbot commented Jan 21, 2025

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Removes assert on output type from structure op region filler to allow more graceful error handling.


Full diff: https://github.com/llvm/llvm-project/pull/123741.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-4)
  • (modified) mlir/test/Dialect/Linalg/invalid.mlir (+18)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..7c45f8805c205e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
 /// `regionBuilder`. The method is used by both named structured ops created by
 /// ods-gen and by manually defined C++ ops. It is called by both builders and
 /// parsers and creates a block with arguments corresponding to the elemental
-/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
-/// ShapedType.
+/// types of `inputTypes` and `outputTypes`.
 static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                    TypeRange inputTypes, TypeRange outputTypes,
                                    ArrayRef<NamedAttribute> attrs,
                                    RegionBuilderFn regionBuilder) {
-  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
-
   SmallVector<Type, 8> argTypes;
   SmallVector<Location, 8> argLocs;
   for (auto containers : {inputTypes, outputTypes}) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..0853856d933035 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
 
 // -----
 
+func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+  // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+  linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
+                outs(%arg2 : memref<2x4xf32>)
+  return
+}
+
+// -----
+
+func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
+  // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+  linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
+                outs(%arg2 : f32)
+  return
+}
+
+// -----
+
 func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
   // expected-error @+1 {{expected attribute value}}
   linalg.matmul indexing_maps = [

@banach-space
Copy link
Contributor

Makes sense, but I am surprised that this wouldn't work before. Does fillStructuredOpRegion run before Op verifier?

@adam-smnk
Copy link
Contributor Author

Makes sense, but I am surprised that this wouldn't work before. Does fillStructuredOpRegion run before Op verifier?

fillStructuredOpRegion runs as a part of parser and builder so I'd think verification comes later.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

Having these extra diagnostics is a nice addition.

@adam-smnk adam-smnk merged commit 458542f into llvm:main Jan 28, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants