@@ -1478,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1478
1478
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1479
1479
iteration space.
1480
1480
1481
- The type of returned the value is automatically inferred to
1481
+ The type of returned the value is must be
1482
1482
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
1483
1483
The returned iteration space can then be iterated over by
1484
1484
`sparse_tensor.iterate` operations to visit every stored element
@@ -1489,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1489
1489
// Extracts a 1-D iteration space from a COO tensor at level 1.
1490
1490
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1491
1491
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492
+ ->!sparse_tensor.iter_space<#COO, lvls = 1>
1492
1493
```
1493
1494
}];
1494
1495
@@ -1501,16 +1502,17 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1501
1502
return getHiLvl() - getLoLvl();
1502
1503
}
1503
1504
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1504
- return getResultSpace ().getType().getLvlTypes();
1505
+ return getExtractedSpace ().getType().getLvlTypes();
1505
1506
}
1506
1507
}];
1507
1508
1508
1509
let arguments = (ins AnySparseTensor:$tensor,
1509
1510
Optional<AnySparseIterator>:$parentIter,
1510
1511
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1511
- let results = (outs AnySparseIterSpace:$resultSpace );
1512
+ let results = (outs AnySparseIterSpace:$extractedSpace );
1512
1513
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1513
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)? `->` type($resultSpace)";
1514
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515
+ "`->` qualified(type($extractedSpace))";
1514
1516
1515
1517
let hasVerifier = 1;
1516
1518
}
@@ -1567,12 +1569,14 @@ def IterateOp : SparseTensor_Op<"iterate",
1567
1569
```mlir
1568
1570
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1569
1571
// Iterates over the first level of %sp
1570
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1572
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
1571
1574
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
1572
1575
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1573
1576
// Iterates over the second level of %sp
1574
1577
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1575
1578
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579
+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
1576
1580
%r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
1577
1581
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1578
1582
vector.print %coord0 : index
0 commit comments