Skip to content

Commit 458542f

Browse files
authored
[mlir][linalg] Relax structured op region filler check (#123741)
Removes assert on output type from structure op region filler to allow more graceful error handling.
1 parent f76f534 commit 458542f

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,11 @@ using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
121121
/// `regionBuilder`. The method is used by both named structured ops created by
122122
/// ods-gen and by manually defined C++ ops. It is called by both builders and
123123
/// parsers and creates a block with arguments corresponding to the elemental
124-
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
125-
/// ShapedType.
124+
/// types of `inputTypes` and `outputTypes`.
126125
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
127126
TypeRange inputTypes, TypeRange outputTypes,
128127
ArrayRef<NamedAttribute> attrs,
129128
RegionBuilderFn regionBuilder) {
130-
assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));
131-
132129
SmallVector<Type, 8> argTypes;
133130
SmallVector<Location, 8> argLocs;
134131
for (auto containers : {inputTypes, outputTypes}) {

mlir/test/Dialect/Linalg/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,24 @@ func.func @invalid_static_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>,
370370

371371
// -----
372372

373+
func.func @invalid_scalar_input_matmul(%arg0: f32, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
374+
// expected-error @+1 {{'linalg.matmul' op expected operand rank (0) to match the result rank of indexing_map #0 (2)}}
375+
linalg.matmul ins(%arg0, %arg1 : f32, memref<3x4xf32>)
376+
outs(%arg2 : memref<2x4xf32>)
377+
return
378+
}
379+
380+
// -----
381+
382+
func.func @invalid_scalar_output_matmul(%arg0: memref<2x3xf32>, %arg1: memref<3x4xf32>, %arg2: f32) {
383+
// expected-error @+1 {{'linalg.matmul' op operand #2 must be variadic of shaped of any type values, but got 'f32'}}
384+
linalg.matmul ins(%arg0, %arg1 : memref<2x3xf32>, memref<3x4xf32>)
385+
outs(%arg2 : f32)
386+
return
387+
}
388+
389+
// -----
390+
373391
func.func @invalid_indexing_maps_matmul(%arg0: memref<2x4xf32>, %arg1: memref<3x4xf32>, %arg2: memref<2x4xf32>) {
374392
// expected-error @+1 {{expected attribute value}}
375393
linalg.matmul indexing_maps = [

0 commit comments

Comments
 (0)