Skip to content

Commit 1945de6

Browse files
author
Peiming Liu
committed
make iter_space type explicit
1 parent 8f7d4e2 commit 1945de6

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,7 +1510,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15101510
LevelAttr:$loLvl, LevelAttr:$hiLvl);
15111511
let results = (outs AnySparseIterSpace:$resultSpace);
15121512
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1513-
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1513+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
15141514

15151515
let hasVerifier = 1;
15161516
}
@@ -1543,8 +1543,8 @@ def IterateOp : SparseTensor_Op<"iterate",
15431543
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
15441544
```
15451545

1546-
`sparse_tensor.iterate` can also operate on loop-carried variables
1547-
and returns the final values after loop termination.
1546+
`sparse_tensor.iterate` can also operate on loop-carried variables.
1547+
It returns the final values after loop termination.
15481548
The initial values of the variables are passed as additional SSA operands
15491549
to the iterator SSA value and used coordinate SSA values mentioned
15501550
above. The operation region has an argument for the iterator, variadic
@@ -1554,9 +1554,9 @@ def IterateOp : SparseTensor_Op<"iterate",
15541554
The body region must contain exactly one block that terminates with
15551555
`sparse_tensor.yield`.
15561556

1557-
`sparse_tensor.iterate` results hold the final values after the last
1558-
iteration. If the `sparse_tensor.iterate` defines any values, a yield
1559-
must be explicitly present.
1557+
The results of an `sparse_tensor.iterate` hold the final values after
1558+
the last iteration. If the `sparse_tensor.iterate` defines any values,
1559+
a yield must be explicitly present.
15601560
The number and types of the `sparse_tensor.iterate` results must match
15611561
the initial values in the iter_args binding and the yield operands.
15621562

mlir/test/Dialect/SparseTensor/invalid.mlir

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,7 @@ func.func @sparse_print(%arg0: tensor<10x10xf64>) {
10251025
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 2>) {
10261026
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' expect larger level upper bound than lower bound}}
10271027
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 to 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 2>
1028+
-> !sparse_tensor.iter_space<#COO, lvls = 0 to 2>
10281029
return
10291030
}
10301031

@@ -1040,6 +1041,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
10401041
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
10411042
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
10421043
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 0 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1044+
-> !sparse_tensor.iter_space<#COO, lvls = 1>
10431045
return
10441046
}
10451047

@@ -1054,7 +1056,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
10541056

10551057
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
10561058
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be specified iff level lower bound equals 0}}
1057-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO>
1059+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 1 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 1>
10581060
return
10591061
}
10601062

@@ -1077,6 +1079,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>) {
10771079
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#CSR, lvls = 0>) {
10781080
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op mismatch in parent iterator encoding and iteration space encoding.}}
10791081
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#CSR, lvls = 0>
1082+
-> !sparse_tensor.iter_space<#COO, lvls = 1>
10801083
return
10811084
}
10821085

@@ -1092,6 +1095,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
10921095
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>) {
10931096
// expected-error@+1 {{'sparse_tensor.extract_iteration_space' op parent iterator should be used to extract an iteration space from a consecutive level.}}
10941097
%l1 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 2 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1098+
-> !sparse_tensor.iter_space<#COO, lvls = 2>
10951099
return
10961100
}
10971101

@@ -1106,7 +1110,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
11061110
}>
11071111

11081112
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
1109-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1113+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
11101114
// expected-error @+1 {{'sparse_tensor.iterate' op different number of region iter_args and yielded values: 2 != 1}}
11111115
%r1, %r2 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i, %sj = %j): !sparse_tensor.iter_space<#COO, lvls = 0> -> (index, index) {
11121116
sparse_tensor.yield %si : index
@@ -1125,7 +1129,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -
11251129

11261130
// expected-note@+1 {{prior use here}}
11271131
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
1128-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1132+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
11291133
// expected-error @+1 {{use of value '%i' expects different type than prior uses: 'f32' vs 'index'}}
11301134
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> f32 {
11311135
sparse_tensor.yield %outer : f32
@@ -1143,7 +1147,7 @@ func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index) -> f32 {
11431147
}>
11441148

11451149
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
1146-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1150+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
11471151
// expected-error @+1 {{'sparse_tensor.iterate' op 0-th region iter_arg and 0-th yielded value have different type: 'index' != 'f32'}}
11481152
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%si = %i): !sparse_tensor.iter_space<#COO, lvls = 0> -> index {
11491153
%y = arith.constant 1.0 : f32

mlir/test/Dialect/SparseTensor/roundtrip.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -758,9 +758,10 @@ func.func @sparse_has_runtime() -> i1 {
758758
func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse_tensor.iterator<#COO, lvls = 0>)
759759
-> (!sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>) {
760760
// Extracting the iteration space for the first level needs no parent iterator.
761-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
761+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
762762
// Extracting the iteration space for the second level needs a parent iterator.
763763
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1 : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
764+
-> !sparse_tensor.iter_space<#COO, lvls = 1>
764765
return %l1, %l2 : !sparse_tensor.iter_space<#COO, lvls = 0>, !sparse_tensor.iter_space<#COO, lvls = 1>
765766
}
766767

@@ -785,7 +786,7 @@ func.func @sparse_extract_iter_space(%sp : tensor<4x8xf32, #COO>, %it1 : !sparse
785786
// CHECK: return %[[VAL_4]] : index
786787
// CHECK: }
787788
func.func @sparse_iterate(%sp : tensor<4x8xf32, #COO>, %i : index, %j : index) -> index {
788-
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
789+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0>
789790
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
790791
sparse_tensor.yield %outer : index
791792
}

mlir/test/Dialect/SparseTensor/sparse_itertion_licm.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// CHECK: sparse_tensor.iterate
1616
func.func @sparse_iterate(%sp : tensor<?x?xf64, #CSR>) {
1717
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<?x?xf64, #CSR>
18+
-> !sparse_tensor.iter_space<#CSR, lvls = 0>
1819
sparse_tensor.iterate %it1 in %l1 at (%crd) : !sparse_tensor.iter_space<#CSR, lvls = 0> {
1920
%0 = sparse_tensor.values %sp : tensor<?x?xf64, #CSR> to memref<?xf64>
2021
%1 = sparse_tensor.positions %sp { level = 1 : index } : tensor<?x?xf64, #CSR> to memref<?xindex>

0 commit comments

Comments
 (0)