@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
15
15
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
16
16
include "mlir/Interfaces/InferTypeOpInterface.td"
17
17
include "mlir/Interfaces/SideEffectInterfaces.td"
18
+ include "mlir/Interfaces/ControlFlowInterfaces.td"
19
+ include "mlir/Interfaces/LoopLikeInterface.td"
18
20
19
21
//===----------------------------------------------------------------------===//
20
22
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
1304
1306
1305
1307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
1306
1308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307
- "ForeachOp"]>]> {
1309
+ "ForeachOp", "IterateOp" ]>]> {
1308
1310
let summary = "Yield from sparse_tensor set-like operations";
1309
1311
let description = [{
1310
1312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1476
1478
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1477
1479
iteration space.
1478
1480
1479
- The type of returned the value is automatically inferred to
1481
+ The type of returned the value is must be
1480
1482
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
1481
1483
The returned iteration space can then be iterated over by
1482
1484
`sparse_tensor.iterate` operations to visit every stored element
@@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1487
1489
// Extracts a 1-D iteration space from a COO tensor at level 1.
1488
1490
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1489
1491
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492
+ ->!sparse_tensor.iter_space<#COO, lvls = 1>
1490
1493
```
1491
1494
}];
1492
1495
@@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1499
1502
return getHiLvl() - getLoLvl();
1500
1503
}
1501
1504
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1502
- return getResultSpace ().getType().getLvlTypes();
1505
+ return getExtractedSpace ().getType().getLvlTypes();
1503
1506
}
1504
1507
}];
1505
1508
1506
1509
let arguments = (ins AnySparseTensor:$tensor,
1507
1510
Optional<AnySparseIterator>:$parentIter,
1508
1511
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1509
- let results = (outs AnySparseIterSpace:$resultSpace );
1512
+ let results = (outs AnySparseIterSpace:$extractedSpace );
1510
1513
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1511
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1514
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515
+ "`->` qualified(type($extractedSpace))";
1512
1516
1513
1517
let hasVerifier = 1;
1514
1518
}
1515
1519
1520
+ def IterateOp : SparseTensor_Op<"iterate",
1521
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
1522
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
1523
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1524
+ "getYieldedValuesMutable"]>,
1525
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
1526
+ ["getEntrySuccessorOperands"]>,
1527
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1528
+
1529
+ let summary = "Iterates over a sparse iteration space";
1530
+ let description = [{
1531
+ The `sparse_tensor.iterate` operation represents a loop (nest) over
1532
+ the provided iteration space extracted from a specific sparse tensor.
1533
+ The operation defines an SSA value for a sparse iterator that points
1534
+ to the current stored element in the sparse tensor and SSA values
1535
+ for coordinates of the stored element. The coordinates are always
1536
+ converted to `index` type despite of the underlying sparse tensor
1537
+ storage. When coordinates are not used, the SSA values can be skipped
1538
+ by `_` symbols, which usually leads to simpler generated code after
1539
+ sparsification. For example:
1540
+
1541
+ ```mlir
1542
+ // The coordinate for level 0 is not used when iterating over a 2-D
1543
+ // iteration space.
1544
+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1545
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1546
+ ```
1547
+
1548
+ `sparse_tensor.iterate` can also operate on loop-carried variables.
1549
+ It returns the final values after loop termination.
1550
+ The initial values of the variables are passed as additional SSA operands
1551
+ to the iterator SSA value and used coordinate SSA values mentioned
1552
+ above. The operation region has an argument for the iterator, variadic
1553
+ arguments for specified (used) coordiates and followed by one argument
1554
+ for each loop-carried variable, representing the value of the variable
1555
+ at the current iteration.
1556
+ The body region must contain exactly one block that terminates with
1557
+ `sparse_tensor.yield`.
1558
+
1559
+ The results of an `sparse_tensor.iterate` hold the final values after
1560
+ the last iteration. If the `sparse_tensor.iterate` defines any values,
1561
+ a yield must be explicitly present.
1562
+ The number and types of the `sparse_tensor.iterate` results must match
1563
+ the initial values in the iter_args binding and the yield operands.
1564
+
1565
+
1566
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
1567
+ stored in the sparse input:
1568
+
1569
+ ```mlir
1570
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1571
+ // Iterates over the first level of %sp
1572
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573
+ : tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
1574
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
1575
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1576
+ // Iterates over the second level of %sp
1577
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1578
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579
+ -> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
1580
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
1581
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1582
+ vector.print %coord0 : index
1583
+ vector.print %coord1 : index
1584
+ }
1585
+ }
1586
+ }
1587
+
1588
+ ```
1589
+ }];
1590
+
1591
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
1592
+ Variadic<AnyType>:$initArgs,
1593
+ LevelSetAttr:$crdUsedLvls);
1594
+ let results = (outs Variadic<AnyType>:$results);
1595
+ let regions = (region SizedRegion<1>:$region);
1596
+
1597
+ let extraClassDeclaration = [{
1598
+ unsigned getSpaceDim() {
1599
+ return getIterSpace().getType().getSpaceDim();
1600
+ }
1601
+ BlockArgument getIterator() {
1602
+ return getRegion().getArguments().front();
1603
+ }
1604
+ Block::BlockArgListType getCrds() {
1605
+ // The first block argument is iterator, the remaining arguments are
1606
+ // referenced coordinates.
1607
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1608
+ }
1609
+ unsigned getNumRegionIterArgs() {
1610
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1611
+ }
1612
+ }];
1613
+
1614
+ let hasVerifier = 1;
1615
+ let hasRegionVerifier = 1;
1616
+ let hasCustomAssemblyFormat = 1;
1617
+ }
1618
+
1516
1619
//===----------------------------------------------------------------------===//
1517
1620
// Sparse Tensor Debugging and Test-Only Operations.
1518
1621
//===----------------------------------------------------------------------===//
0 commit comments