@@ -15,8 +15,6 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
15
15
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
16
16
include "mlir/Interfaces/InferTypeOpInterface.td"
17
17
include "mlir/Interfaces/SideEffectInterfaces.td"
18
- include "mlir/Interfaces/ControlFlowInterfaces.td"
19
- include "mlir/Interfaces/LoopLikeInterface.td"
20
18
21
19
//===----------------------------------------------------------------------===//
22
20
// Base class.
@@ -1279,7 +1277,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
1279
1277
1280
1278
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
1281
1279
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1282
- "ForeachOp", "IterateOp" ]>]>,
1280
+ "ForeachOp"]>]>,
1283
1281
Arguments<(ins Variadic<AnyType>:$results)> {
1284
1282
let summary = "Yield from sparse_tensor set-like operations";
1285
1283
let description = [{
@@ -1432,154 +1430,6 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
1432
1430
let hasVerifier = 1;
1433
1431
}
1434
1432
1435
- //===----------------------------------------------------------------------===//
1436
- // Sparse Tensor Iteration Operations.
1437
- //===----------------------------------------------------------------------===//
1438
-
1439
- def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
1440
- [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
1441
-
1442
- let arguments = (ins AnySparseTensor:$tensor,
1443
- Optional<AnySparseIterator>:$parentIter,
1444
- LevelAttr:$loLvl, LevelAttr:$hiLvl);
1445
-
1446
- let results = (outs AnySparseIterSpace:$resultSpace);
1447
-
1448
- let summary = "Extract an iteration space from a sparse tensor between certain levels";
1449
- let description = [{
1450
- Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
1451
- certian (consecutive) levels.
1452
-
1453
- `tensor`: the input sparse tensor that defines the iteration space.
1454
- `parentIter`: the iterator for the previous level, at which the iteration space
1455
- at the current levels will be extracted.
1456
- `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
1457
- the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
1458
- iteration space.
1459
-
1460
- Example:
1461
- ```mlir
1462
- // Extracts a 1-D iteration space from a COO tensor at level 1.
1463
- %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
1464
- : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1465
- ```
1466
- }];
1467
-
1468
-
1469
- let extraClassDeclaration = [{
1470
- std::pair<Level, Level> getLvlRange() {
1471
- return std::make_pair(getLoLvl(), getHiLvl());
1472
- }
1473
- unsigned getSpaceDim() {
1474
- return getHiLvl() - getLoLvl();
1475
- }
1476
- ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1477
- return getResultSpace().getType().getLvlTypes();
1478
- }
1479
- }];
1480
-
1481
- let hasVerifier = 1;
1482
- let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1483
- " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1484
- }
1485
-
1486
- def IterateOp : SparseTensor_Op<"iterate",
1487
- [RecursiveMemoryEffects, RecursivelySpeculatable,
1488
- DeclareOpInterfaceMethods<LoopLikeOpInterface,
1489
- ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1490
- "getYieldedValuesMutable"]>,
1491
- DeclareOpInterfaceMethods<RegionBranchOpInterface,
1492
- ["getEntrySuccessorOperands"]>,
1493
- SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1494
-
1495
- let arguments = (ins AnySparseIterSpace:$iterSpace,
1496
- Variadic<AnyType>:$initArgs,
1497
- LevelSetAttr:$crdUsedLvls);
1498
- let results = (outs Variadic<AnyType>:$results);
1499
- let regions = (region SizedRegion<1>:$region);
1500
-
1501
- let summary = "Iterate over a sparse iteration space";
1502
- let description = [{
1503
- The `sparse_tensor.iterate` operations represents a loop over the
1504
- provided iteration space extracted from a specific sparse tensor.
1505
- The operation defines an SSA value for a sparse iterator that points
1506
- to the current stored element in the sparse tensor and SSA values
1507
- for coordinates of the stored element. The coordinates are always
1508
- converted to `index` type despite of the underlying sparse tensor
1509
- storage. When coordinates are not used, the SSA values can be skipped
1510
- by `_` symbols, which usually leads to simpler generated code after
1511
- sparsification. For example:
1512
-
1513
- ```mlir
1514
- // The coordinate for level 0 is not used when iterating over a 2-D
1515
- // iteration space.
1516
- %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1517
- : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1518
- ```
1519
-
1520
- `sparse_tensor.iterate` can also operate on loop-carried variables
1521
- and returns the final values after loop termination.
1522
- The initial values of the variables are passed as additional SSA operands
1523
- to the iterator SSA value and used coordinate SSA values mentioned
1524
- above. The operation region has an argument for the iterator, variadic
1525
- arguments for specified (used) coordiates and followed by one argument
1526
- for each loop-carried variable, representing the value of the variable
1527
- at the current iteration.
1528
- The body region must contain exactly one block that terminates with
1529
- `sparse_tensor.yield`.
1530
-
1531
- `sparse_tensor.iterate` results hold the final values after the last
1532
- iteration. If the `sparse_tensor.iterate` defines any values, a yield
1533
- must be explicitly present.
1534
- The number and types of the `sparse_tensor.iterate` results must match
1535
- the initial values in the iter_args binding and the yield operands.
1536
-
1537
-
1538
- A nested `sparse_tensor.iterate` example that prints all the coordinates
1539
- stored in the sparse input:
1540
-
1541
- ```mlir
1542
- func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1543
- // Iterates over the first level of %sp
1544
- %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1545
- %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1546
- : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1547
- // Iterates over the second level of %sp
1548
- %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1549
- : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1550
- %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1551
- : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1552
- vector.print %crd0 : index
1553
- vector.print %crd1 : index
1554
- }
1555
- }
1556
- }
1557
-
1558
- ```
1559
- }];
1560
-
1561
- let extraClassDeclaration = [{
1562
- unsigned getSpaceDim() {
1563
- return getIterSpace().getType().getSpaceDim();
1564
- }
1565
- BlockArgument getIterator() {
1566
- return getRegion().getArguments().front();
1567
- }
1568
- Block::BlockArgListType getCrds() {
1569
- // The first block argument is iterator, the remaining arguments are
1570
- // referenced coordinates.
1571
- return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1572
- }
1573
- unsigned getNumRegionIterArgs() {
1574
- return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1575
- }
1576
- }];
1577
-
1578
- let hasVerifier = 1;
1579
- let hasRegionVerifier = 1;
1580
- let hasCustomAssemblyFormat = 1;
1581
- }
1582
-
1583
1433
//===----------------------------------------------------------------------===//
1584
1434
// Sparse Tensor Debugging and Test-Only Operations.
1585
1435
//===----------------------------------------------------------------------===//
0 commit comments