-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][linalg] Relax structured op region filler check #123741
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg] Relax structured op region filler check #123741
Conversation
Removes assert on output type from structure op region filler to allow more graceful error handling.
@llvm/pr-subscribers-mlir-linalg Author: Adam Siemieniuk (adam-smnk) ChangesRemoves assert on output type from structure op region filler to allow more graceful error handling. Full diff: https://github.com/llvm/llvm-project/pull/123741.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..7c45f8805c205e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
-/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
-/// ShapedType.
+/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) {
- assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
-
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
for (auto containers : {inputTypes, outputTypes}) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..0853856d933035 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
// -----
+func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+ linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
+ outs(%arg2 : memref<2x4xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
+ // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+ linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
+ outs(%arg2 : f32)
+ return
+}
+
+// -----
+
func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
// expected-error @+1 {{expected attribute value}}
linalg.matmul indexing_maps = [
|
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesRemoves assert on output type from structure op region filler to allow more graceful error handling. Full diff: https://github.com/llvm/llvm-project/pull/123741.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..7c45f8805c205e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
-/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
-/// ShapedType.
+/// types of `inputTypes` and `outputTypes`.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region ®ion,
TypeRange inputTypes, TypeRange outputTypes,
ArrayRef<NamedAttribute> attrs,
RegionBuilderFn regionBuilder) {
- assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
-
SmallVector<Type, 8> argTypes;
SmallVector<Location, 8> argLocs;
for (auto containers : {inputTypes, outputTypes}) {
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index a59472377a732c..0853856d933035 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
// -----
+func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
+ // expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
+ linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
+ outs(%arg2 : memref<2x4xf32>)
+ return
+}
+
+// -----
+
+func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
+ // expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
+ linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
+ outs(%arg2 : f32)
+ return
+}
+
+// -----
+
func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
// expected-error @+1 {{expected attribute value}}
linalg.matmul indexing_maps = [
|
Makes sense, but I am surprised that this wouldn't work before. Does |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
Having these extra diagnostics is a nice addition.
Removes assert on output type from structure op region filler to allow more graceful error handling.