Skip to content

Commit 221c7a8

Browse files
author
Peiming Liu
committed
[mlir][sparse] introduce sparse_tensor.iterate operation
1 parent 253c28f commit 221c7a8

File tree

7 files changed

+519
-1
lines changed

7 files changed

+519
-1
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
#include "mlir/IR/OpDefinition.h"
1818
#include "mlir/IR/OpImplementation.h"
1919
#include "mlir/IR/TensorEncoding.h"
20+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2021
#include "mlir/Interfaces/InferTypeOpInterface.h"
22+
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
2224

25+
#include "llvm/ADT/bit.h"
26+
2327
//===----------------------------------------------------------------------===//
2428
//
2529
// Type aliases to help code be more self-documenting. Unfortunately
@@ -54,6 +58,40 @@ struct COOSegment {
5458
}
5559
};
5660

61+
/// A simple wrapper to encode a bitset of defined (at most 64) levels.
62+
class LevelSet {
63+
uint64_t bits = 0;
64+
65+
public:
66+
LevelSet() = default;
67+
explicit LevelSet(uint64_t bits) : bits(bits) {}
68+
operator uint64_t() const { return bits; }
69+
70+
LevelSet &set(unsigned i) {
71+
assert(i < 64);
72+
bits |= 1 << i;
73+
return *this;
74+
}
75+
76+
LevelSet &operator|=(LevelSet lhs) {
77+
bits |= static_cast<uint64_t>(lhs);
78+
return *this;
79+
}
80+
81+
LevelSet &lshift(unsigned offset) {
82+
bits = bits << offset;
83+
return *this;
84+
}
85+
86+
bool operator[](unsigned i) const {
87+
assert(i < 64);
88+
return (bits & (1 << i)) != 0;
89+
}
90+
91+
unsigned count() const { return llvm::popcount(bits); }
92+
bool empty() const { return bits == 0; }
93+
};
94+
5795
} // namespace sparse_tensor
5896
} // namespace mlir
5997

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
1919
list<Trait> traits = []>
2020
: AttrDef<SparseTensor_Dialect, name, traits>;
2121

22+
//===----------------------------------------------------------------------===//
23+
// A simple bitset attribute wrapped over a single int64_t to encode a set of
24+
// sparse tensor levels.
25+
//===----------------------------------------------------------------------===//
26+
27+
def LevelSetAttr :
28+
TypedAttrBase<
29+
I64, "IntegerAttr",
30+
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
31+
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
32+
"LevelSet attribute"> {
33+
let returnType = [{::mlir::sparse_tensor::LevelSet}];
34+
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
35+
}
36+
2237
//===----------------------------------------------------------------------===//
2338
// These attributes are just like `IndexAttr` except that they clarify whether
2439
// the index refers to a dimension (an axis of the semantic tensor) or a level

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
1515
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
1616
include "mlir/Interfaces/InferTypeOpInterface.td"
1717
include "mlir/Interfaces/SideEffectInterfaces.td"
18+
include "mlir/Interfaces/ControlFlowInterfaces.td"
19+
include "mlir/Interfaces/LoopLikeInterface.td"
1820

1921
//===----------------------------------------------------------------------===//
2022
// Base class.
@@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
13041306

13051307
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
13061308
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
1307-
"ForeachOp"]>]> {
1309+
"ForeachOp", "IterateOp"]>]> {
13081310
let summary = "Yield from sparse_tensor set-like operations";
13091311
let description = [{
13101312
Yields a value from within a `binary`, `unary`, `reduce`,
@@ -1513,6 +1515,103 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
15131515
let hasVerifier = 1;
15141516
}
15151517

1518+
def IterateOp : SparseTensor_Op<"iterate",
1519+
[RecursiveMemoryEffects, RecursivelySpeculatable,
1520+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
1521+
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
1522+
"getYieldedValuesMutable"]>,
1523+
DeclareOpInterfaceMethods<RegionBranchOpInterface,
1524+
["getEntrySuccessorOperands"]>,
1525+
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
1526+
1527+
let summary = "Iterate over a sparse iteration space";
1528+
let description = [{
1529+
The `sparse_tensor.iterate` operations represents a loop over the
1530+
provided iteration space extracted from a specific sparse tensor.
1531+
The operation defines an SSA value for a sparse iterator that points
1532+
to the current stored element in the sparse tensor and SSA values
1533+
for coordinates of the stored element. The coordinates are always
1534+
converted to `index` type despite of the underlying sparse tensor
1535+
storage. When coordinates are not used, the SSA values can be skipped
1536+
by `_` symbols, which usually leads to simpler generated code after
1537+
sparsification. For example:
1538+
1539+
```mlir
1540+
// The coordinate for level 0 is not used when iterating over a 2-D
1541+
// iteration space.
1542+
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
1543+
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
1544+
```
1545+
1546+
`sparse_tensor.iterate` can also operate on loop-carried variables
1547+
and returns the final values after loop termination.
1548+
The initial values of the variables are passed as additional SSA operands
1549+
to the iterator SSA value and used coordinate SSA values mentioned
1550+
above. The operation region has an argument for the iterator, variadic
1551+
arguments for specified (used) coordiates and followed by one argument
1552+
for each loop-carried variable, representing the value of the variable
1553+
at the current iteration.
1554+
The body region must contain exactly one block that terminates with
1555+
`sparse_tensor.yield`.
1556+
1557+
`sparse_tensor.iterate` results hold the final values after the last
1558+
iteration. If the `sparse_tensor.iterate` defines any values, a yield
1559+
must be explicitly present.
1560+
The number and types of the `sparse_tensor.iterate` results must match
1561+
the initial values in the iter_args binding and the yield operands.
1562+
1563+
1564+
A nested `sparse_tensor.iterate` example that prints all the coordinates
1565+
stored in the sparse input:
1566+
1567+
```mlir
1568+
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
1569+
// Iterates over the first level of %sp
1570+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
1571+
%r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
1572+
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
1573+
// Iterates over the second level of %sp
1574+
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
1575+
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
1576+
%r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
1577+
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
1578+
vector.print %crd0 : index
1579+
vector.print %crd1 : index
1580+
}
1581+
}
1582+
}
1583+
1584+
```
1585+
}];
1586+
1587+
let arguments = (ins AnySparseIterSpace:$iterSpace,
1588+
Variadic<AnyType>:$initArgs,
1589+
LevelSetAttr:$crdUsedLvls);
1590+
let results = (outs Variadic<AnyType>:$results);
1591+
let regions = (region SizedRegion<1>:$region);
1592+
1593+
let extraClassDeclaration = [{
1594+
unsigned getSpaceDim() {
1595+
return getIterSpace().getType().getSpaceDim();
1596+
}
1597+
BlockArgument getIterator() {
1598+
return getRegion().getArguments().front();
1599+
}
1600+
Block::BlockArgListType getCrds() {
1601+
// The first block argument is iterator, the remaining arguments are
1602+
// referenced coordinates.
1603+
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
1604+
}
1605+
unsigned getNumRegionIterArgs() {
1606+
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
1607+
}
1608+
}];
1609+
1610+
let hasVerifier = 1;
1611+
let hasRegionVerifier = 1;
1612+
let hasCustomAssemblyFormat = 1;
1613+
}
1614+
15161615
//===----------------------------------------------------------------------===//
15171616
// Sparse Tensor Debugging and Test-Only Operations.
15181617
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)