Skip to content

[mlir][sparse] introduce sparse_tensor.iterate operation #88955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "llvm/ADT/bit.h"

//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
Expand Down Expand Up @@ -54,6 +58,42 @@ struct COOSegment {
}
};

/// A simple wrapper to encode a bitset of (at most 64) levels, currently used
/// by `sparse_tensor.iterate` operation for the set of levels on which the
/// coordinates should be loaded.
class LevelSet {
uint64_t bits = 0;

public:
LevelSet() = default;
explicit LevelSet(uint64_t bits) : bits(bits) {}
operator uint64_t() const { return bits; }

LevelSet &set(unsigned i) {
assert(i < 64);
bits |= static_cast<uint64_t>(0x01u) << i;
return *this;
}

LevelSet &operator|=(LevelSet lhs) {
bits |= static_cast<uint64_t>(lhs);
return *this;
}

LevelSet &lshift(unsigned offset) {
bits = bits << offset;
return *this;
}

bool operator[](unsigned i) const {
assert(i < 64);
return (bits & (1 << i)) != 0;
}

unsigned count() const { return llvm::popcount(bits); }
bool empty() const { return bits == 0; }
};

} // namespace sparse_tensor
} // namespace mlir

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;

//===----------------------------------------------------------------------===//
// A simple bitset attribute wrapped around a single int64_t to encode a set of
// sparse tensor levels.
//===----------------------------------------------------------------------===//

def LevelSetAttr :
TypedAttrBase<
I64, "IntegerAttr",
And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
"LevelSet attribute"> {
let returnType = [{::mlir::sparse_tensor::LevelSet}];
let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
}

//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
Expand Down
113 changes: 108 additions & 5 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"

//===----------------------------------------------------------------------===//
// Base class.
Expand Down Expand Up @@ -1304,7 +1306,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu

def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
"ForeachOp"]>]> {
"ForeachOp", "IterateOp"]>]> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
Yields a value from within a `binary`, `unary`, `reduce`,
Expand Down Expand Up @@ -1476,7 +1478,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
iteration space.

The type of returned the value is automatically inferred to
The type of returned the value is must be
`!sparse_tensor.iter_space<#INPUT_ENCODING, lvls = $loLvl to $hiLvl>`.
The returned iteration space can then be iterated over by
`sparse_tensor.iterate` operations to visit every stored element
Expand All @@ -1487,6 +1489,7 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
// Extracts a 1-D iteration space from a COO tensor at level 1.
%space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
->!sparse_tensor.iter_space<#COO, lvls = 1>
```
}];

Expand All @@ -1499,20 +1502,120 @@ def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
return getHiLvl() - getLoLvl();
}
ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
return getResultSpace().getType().getLvlTypes();
return getExtractedSpace().getType().getLvlTypes();
}
}];

let arguments = (ins AnySparseTensor:$tensor,
Optional<AnySparseIterator>:$parentIter,
LevelAttr:$loLvl, LevelAttr:$hiLvl);
let results = (outs AnySparseIterSpace:$resultSpace);
let results = (outs AnySparseIterSpace:$extractedSpace);
let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
" attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
" attr-dict `:` type($tensor) (`,` type($parentIter)^)? "
"`->` qualified(type($extractedSpace))";

let hasVerifier = 1;
}

def IterateOp : SparseTensor_Op<"iterate",
[RecursiveMemoryEffects, RecursivelySpeculatable,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getInitsMutable", "getLoopResults", "getRegionIterArgs",
"getYieldedValuesMutable"]>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands"]>,
SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {

let summary = "Iterates over a sparse iteration space";
let description = [{
The `sparse_tensor.iterate` operation represents a loop (nest) over
the provided iteration space extracted from a specific sparse tensor.
The operation defines an SSA value for a sparse iterator that points
to the current stored element in the sparse tensor and SSA values
for coordinates of the stored element. The coordinates are always
converted to `index` type despite of the underlying sparse tensor
storage. When coordinates are not used, the SSA values can be skipped
by `_` symbols, which usually leads to simpler generated code after
sparsification. For example:

```mlir
// The coordinate for level 0 is not used when iterating over a 2-D
// iteration space.
%sparse_tensor.iterate %iterator in %space at(_, %crd_1)
: !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
```

`sparse_tensor.iterate` can also operate on loop-carried variables.
It returns the final values after loop termination.
The initial values of the variables are passed as additional SSA operands
to the iterator SSA value and used coordinate SSA values mentioned
above. The operation region has an argument for the iterator, variadic
arguments for specified (used) coordiates and followed by one argument
for each loop-carried variable, representing the value of the variable
at the current iteration.
The body region must contain exactly one block that terminates with
`sparse_tensor.yield`.

The results of an `sparse_tensor.iterate` hold the final values after
the last iteration. If the `sparse_tensor.iterate` defines any values,
a yield must be explicitly present.
The number and types of the `sparse_tensor.iterate` results must match
the initial values in the iter_args binding and the yield operands.


A nested `sparse_tensor.iterate` example that prints all the coordinates
stored in the sparse input:

```mlir
func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
// Iterates over the first level of %sp
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
: tensor<4x8xf32, #COO> -> !sparse_tensor.iter_space<#COO, lvls = 0 to 1>
%r1 = sparse_tensor.iterate %it1 in %l1 at (%coord0)
: !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
// Iterates over the second level of %sp
%l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
: tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
-> !sparse_tensor.iter_space<#COO, lvls = 1 to 2>
%r2 = sparse_tensor.iterate %it2 in %l2 at (coord1)
: !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
vector.print %coord0 : index
vector.print %coord1 : index
}
}
}

```
}];

let arguments = (ins AnySparseIterSpace:$iterSpace,
Variadic<AnyType>:$initArgs,
LevelSetAttr:$crdUsedLvls);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);

let extraClassDeclaration = [{
unsigned getSpaceDim() {
return getIterSpace().getType().getSpaceDim();
}
BlockArgument getIterator() {
return getRegion().getArguments().front();
}
Block::BlockArgListType getCrds() {
// The first block argument is iterator, the remaining arguments are
// referenced coordinates.
return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs() {
return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
}
}];

let hasVerifier = 1;
let hasRegionVerifier = 1;
let hasCustomAssemblyFormat = 1;
}

//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
Expand Down
Loading
Loading