Skip to content

Commit 0f5774b

Browse files
Peiming LiuLukacma
Peiming Liu
authored andcommitted
[mlir][sparse] introduce sparse_tensor.iterate operation (llvm#88955)
A `sparse_tensor.iterate` iterates over a sparse iteration space extracted from `sparse_tensor.extract_iteration_space` operation introduced in llvm#88554.
1 parent 8550f02 commit 0f5774b

File tree

7 files changed

+538
-8
lines changed

7 files changed

+538
-8
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

+40
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/IR/TensorEncoding.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
22+
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

25+
#include "llvm/ADT/bit.h"
26+
2327
//===----------------------------------------------------------------------===//
2428
//
2529
// Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,42 @@ struct COOSegment {
5458
}
5559
};
5660

61+
/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
62+
/// by `sparse_tensor.iterate` operation for the set of levels on which the
63+
/// coordinates should be loaded.
64+
class LevelSet {
65+
uint64_t bits = 0;
66+
67+
public:
68+
LevelSet() = default;
69+
explicit LevelSet(uint64_t bits) : bits(bits) {}
70+
operator uint64_t() const { return bits; }
71+
72+
LevelSet &set(unsigned i) {
73+
assert(i < 64);
74+
bits |= static_cast<uint64_t>(0x01u) << i;
75+
return *this;
76+
}
77+
78+
LevelSet &operator|=(LevelSet lhs) {
79+
bits |= static_cast<uint64_t>(lhs);
80+
return *this;
81+
}
82+
83+
LevelSet &lshift(unsigned offset) {
84+
bits = bits << offset;
85+
return *this;
86+
}
87+
88+
bool operator[](unsigned i) const {
89+
assert(i < 64);
90+
return (bits & (1 << i)) != 0;
91+
}
92+
93+
unsigned count() const { return llvm::popcount(bits); }
94+
bool empty() const { return bits == 0; }
95+
};
96+
5797
} // namespace sparse_tensor
5898
} // namespace mlir
5999

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
1919
list<Trait> traits = []>
2020
: AttrDef<SparseTensor_Dialect, name, traits>;
2121

22+
//===----------------------------------------------------------------------===//
23+
// A simple bitset attribute wrapped around a single int64_t to encode a set of
24+
// sparse tensor levels.
25+
//===----------------------------------------------------------------------===//
26+
27+
def LevelSetAttr :
28+
TypedAttrBase<
29+
I64, "IntegerAttr",
30+
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
31+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
32+
"LevelSet attribute"> {
33+
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34+
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
35+
}
36+
2237
//===----------------------------------------------------------------------===//
2338
// These attributes are just like `IndexAttr` except that they clarify whether
2439
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

+108-5
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/LoopLikeInterface.td"
1820

1921
//===----------------------------------------------------------------------===//
2022
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
13041306

13051307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
13061308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307-
"ForeachOp"]>]> {
1309+
"ForeachOp", "IterateOp"]>]> {
13081310
let summary = "Yield from sparse_tensor set-like operations";
13091311
let description = [{
13101312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14761478
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
14771479
iteration space.
14781480

1479-
The type of returned the value is automatically inferred to
1481+
The type of returned the value is must be
14801482
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
14811483
The returned iteration space can then be iterated over by
14821484
`sparse_tensor.iterate` operations to visit every stored element
@@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14871489
// Extracts a 1-D iteration space from a COO tensor at level 1.
14881490
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
14891491
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
1492+
->!sparse_tensor.iter_space<#COO, lvls = 1>
14901493
```
14911494
}];
14921495

@@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
14991502
return getHiLvl() - getLoLvl();
15001503
}
15011504
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
1502-
return getResultSpace().getType().getLvlTypes();
1505+
return getExtractedSpace().getType().getLvlTypes();
15031506
}
15041507
}];
15051508

15061509
let arguments = (ins AnySparseTensor:$tensor,
15071510
Optional<AnySparseIterator>:$parentIter,
15081511
LevelAttr:$loLvl, LevelAttr:$hiLvl);
1509-
let results = (outs AnySparseIterSpace:$resultSpace);
1512+
let results = (outs AnySparseIterSpace:$extractedSpace);
15101513
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
1511-
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
1514+
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
1515+
"`->` qualified(type($extractedSpace))";
15121516

15131517
let hasVerifier = 1;
15141518
}
15151519

1520+
def IterateOp : SparseTensor_Op<"iterate",
1521+
[RecursiveMemoryEffects, RecursivelySpeculatable,
1522+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
1523+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1524+
"getYieldedValuesMutable"]>,
1525+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1526+
["getEntrySuccessorOperands"]>,
1527+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1528+
1529+
let summary = "Iterates over a sparse iteration space";
1530+
let description = [{
1531+
The `sparse_tensor.iterate` operation represents a loop (nest) over
1532+
the provided iteration space extracted from a specific sparse tensor.
1533+
The operation defines an SSA value for a sparse iterator that points
1534+
to the current stored element in the sparse tensor and SSA values
1535+
for coordinates of the stored element. The coordinates are always
1536+
converted to `index` type despite of the underlying sparse tensor
1537+
storage. When coordinates are not used, the SSA values can be skipped
1538+
by `_` symbols, which usually leads to simpler generated code after
1539+
sparsification. For example:
1540+
1541+
```mlir
1542+
// The coordinate for level 0 is not used when iterating over a 2-D
1543+
// iteration space.
1544+
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1545+
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1546+
```
1547+
1548+
`sparse_tensor.iterate` can also operate on loop-carried variables.
1549+
It returns the final values after loop termination.
1550+
The initial values of the variables are passed as additional SSA operands
1551+
to the iterator SSA value and used coordinate SSA values mentioned
1552+
above. The operation region has an argument for the iterator, variadic
1553+
arguments for specified (used) coordiates and followed by one argument
1554+
for each loop-carried variable, representing the value of the variable
1555+
at the current iteration.
1556+
The body region must contain exactly one block that terminates with
1557+
`sparse_tensor.yield`.
1558+
1559+
The results of an `sparse_tensor.iterate` hold the final values after
1560+
the last iteration. If the `sparse_tensor.iterate` defines any values,
1561+
a yield must be explicitly present.
1562+
The number and types of the `sparse_tensor.iterate` results must match
1563+
the initial values in the iter_args binding and the yield operands.
1564+
1565+
1566+
A nested `sparse_tensor.iterate` example that prints all the coordinates
1567+
stored in the sparse input:
1568+
1569+
```mlir
1570+
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1571+
// Iterates over the first level of %sp
1572+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
1573+
: tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
1574+
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
1575+
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1576+
// Iterates over the second level of %sp
1577+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1578+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1579+
-> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
1580+
%r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
1581+
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1582+
vector.print %coord0 : index
1583+
vector.print %coord1 : index
1584+
}
1585+
}
1586+
}
1587+
1588+
```
1589+
}];
1590+
1591+
let arguments = (ins AnySparseIterSpace:$iterSpace,
1592+
Variadic<AnyType>:$initArgs,
1593+
LevelSetAttr:$crdUsedLvls);
1594+
let results = (outs Variadic<AnyType>:$results);
1595+
let regions = (region SizedRegion<1>:$region);
1596+
1597+
let extraClassDeclaration = [{
1598+
unsigned getSpaceDim() {
1599+
return getIterSpace().getType().getSpaceDim();
1600+
}
1601+
BlockArgument getIterator() {
1602+
return getRegion().getArguments().front();
1603+
}
1604+
Block::BlockArgListType getCrds() {
1605+
// The first block argument is iterator, the remaining arguments are
1606+
// referenced coordinates.
1607+
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1608+
}
1609+
unsigned getNumRegionIterArgs() {
1610+
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1611+
}
1612+
}];
1613+
1614+
let hasVerifier = 1;
1615+
let hasRegionVerifier = 1;
1616+
let hasCustomAssemblyFormat = 1;
1617+
}
1618+
15161619
//===----------------------------------------------------------------------===//
15171620
// Sparse Tensor Debugging and Test-Only Operations.
15181621
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)