-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[mlir][sparse] introduce sparse_tensor.iterate operation #88807
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extracts a sparse iteration space to iterate over.
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Peiming Liu (PeimingLiu) ChangesA DO NOT MERGE before #88554 Patch is 37.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88807.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 5e523ec428aefb..081a9b8cad8d62 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -17,9 +17,13 @@
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TensorEncoding.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "llvm/ADT/bit.h"
+
//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
@@ -41,6 +45,40 @@ using Level = uint64_t;
/// including the value `ShapedType::kDynamic` (for shapes).
using Size = int64_t;
+/// A simple wrapper to encode a bitset of defined (at most 64) levels.
+class LevelSet {
+ uint64_t bits = 0;
+
+public:
+ LevelSet() = default;
+ explicit LevelSet(uint64_t bits) : bits(bits) {}
+ operator uint64_t() const { return bits; }
+
+ LevelSet &set(unsigned i) {
+ assert(i < 64);
+ bits |= 1 << i;
+ return *this;
+ }
+
+ LevelSet &operator|=(LevelSet lhs) {
+ bits |= static_cast<uint64_t>(lhs);
+ return *this;
+ }
+
+ LevelSet &lshift(unsigned offset) {
+ bits = bits << offset;
+ return *this;
+ }
+
+ bool operator[](unsigned i) const {
+ assert(i < 64);
+ return (bits & (1 << i)) != 0;
+ }
+
+ unsigned count() const { return llvm::popcount(bits); }
+ bool empty() const { return bits == 0; }
+};
+
} // namespace sparse_tensor
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
index 4a9b9169ae4b86..d5398a98f5b171 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -19,6 +19,21 @@ class SparseTensor_Attr<string name,
list<Trait> traits = []>
: AttrDef<SparseTensor_Dialect, name, traits>;
+//===----------------------------------------------------------------------===//
+// A simple bitset attribute wrapped over a single int64_t to encode a set of
+// sparse tensor levels.
+//===----------------------------------------------------------------------===//
+
+def LevelSetAttr :
+ TypedAttrBase<
+ I64, "IntegerAttr",
+ And<[CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
+ CPred<"::llvm::cast<::mlir::IntegerAttr>($_self).getType().isInteger(64)">]>,
+ "LevelSet attribute"> {
+ let returnType = [{::mlir::sparse_tensor::LevelSet}];
+ let convertFromStorage = [{::mlir::sparse_tensor::LevelSet($_self.getValue().getZExtValue())}];
+}
+
//===----------------------------------------------------------------------===//
// These attributes are just like `IndexAttr` except that they clarify whether
// the index refers to a dimension (an axis of the semantic tensor) or a level
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 0cfc64f9988a0a..b43d716d5e8642 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -15,6 +15,8 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td"
include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/LoopLikeInterface.td"
//===----------------------------------------------------------------------===//
// Base class.
@@ -1277,7 +1279,7 @@ def SparseTensor_SelectOp : SparseTensor_Op<"select", [Pure, SameOperandsAndResu
def SparseTensor_YieldOp : SparseTensor_Op<"yield", [Pure, Terminator,
ParentOneOf<["BinaryOp", "UnaryOp", "ReduceOp", "SelectOp",
- "ForeachOp"]>]>,
+ "ForeachOp", "IterateOp"]>]>,
Arguments<(ins Variadic<AnyType>:$results)> {
let summary = "Yield from sparse_tensor set-like operations";
let description = [{
@@ -1430,6 +1432,154 @@ def SparseTensor_ForeachOp : SparseTensor_Op<"foreach",
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+def ExtractIterSpaceOp : SparseTensor_Op<"extract_iteration_space",
+ [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+
+ let arguments = (ins AnySparseTensor:$tensor,
+ Optional<AnySparseIterator>:$parentIter,
+ LevelAttr:$loLvl, LevelAttr:$hiLvl);
+
+ let results = (outs AnySparseIterSpace:$resultSpace);
+
+ let summary = "Extract an iteration space from a sparse tensor between certain levels";
+ let description = [{
+ Extracts a `!sparse_tensor.iter_space` from a sparse tensor between
+ certian (consecutive) levels.
+
+ `tensor`: the input sparse tensor that defines the iteration space.
+ `parentIter`: the iterator for the previous level, at which the iteration space
+ at the current levels will be extracted.
+ `loLvl`, `hiLvl`: the level range between [loLvl, hiLvl) in the input tensor that
+ the returned iteration space covers. `hiLvl - loLvl` defines the dimension of the
+ iteration space.
+
+ Example:
+ ```mlir
+ // Extracts a 1-D iteration space from a COO tensor at level 1.
+ %space = sparse_tensor.iteration.extract_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0>
+ ```
+ }];
+
+
+ let extraClassDeclaration = [{
+ std::pair<Level, Level> getLvlRange() {
+ return std::make_pair(getLoLvl(), getHiLvl());
+ }
+ unsigned getSpaceDim() {
+ return getHiLvl() - getLoLvl();
+ }
+ ArrayRef<::mlir::sparse_tensor::LevelType> getSpaceLvlTypes() {
+ return getResultSpace().getType().getLvlTypes();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let assemblyFormat = "$tensor (`at` $parentIter^)? `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) "
+ " attr-dict `:` type($tensor) (`,` type($parentIter)^)?";
+}
+
+def IterateOp : SparseTensor_Op<"iterate",
+ [RecursiveMemoryEffects, RecursivelySpeculatable,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface,
+ ["getInitsMutable", "getLoopResults", "getRegionIterArgs",
+ "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands"]>,
+ SingleBlockImplicitTerminator<"sparse_tensor::YieldOp">]> {
+
+ let arguments = (ins AnySparseIterSpace:$iterSpace,
+ Variadic<AnyType>:$initArgs,
+ LevelSetAttr:$crdUsedLvls);
+ let results = (outs Variadic<AnyType>:$results);
+ let regions = (region SizedRegion<1>:$region);
+
+ let summary = "Iterate over a sparse iteration space";
+ let description = [{
+ The `sparse_tensor.iterate` operations represents a loop over the
+ provided iteration space extracted from a specific sparse tensor.
+ The operation defines an SSA value for a sparse iterator that points
+ to the current stored element in the sparse tensor and SSA values
+ for coordinates of the stored element. The coordinates are always
+ converted to `index` type despite of the underlying sparse tensor
+ storage. When coordinates are not used, the SSA values can be skipped
+ by `_` symbols, which usually leads to simpler generated code after
+ sparsification. For example:
+
+ ```mlir
+ // The coordinate for level 0 is not used when iterating over a 2-D
+ // iteration space.
+ %sparse_tensor.iterate %iterator in %space at(_, %crd_1)
+ : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
+ ```
+
+ `sparse_tensor.iterate` can also operate on loop-carried variables
+ and returns the final values after loop termination.
+ The initial values of the variables are passed as additional SSA operands
+ to the iterator SSA value and used coordinate SSA values mentioned
+ above. The operation region has an argument for the iterator, variadic
+ arguments for specified (used) coordiates and followed by one argument
+ for each loop-carried variable, representing the value of the variable
+ at the current iteration.
+ The body region must contain exactly one block that terminates with
+ `sparse_tensor.yield`.
+
+ `sparse_tensor.iterate` results hold the final values after the last
+ iteration. If the `sparse_tensor.iterate` defines any values, a yield
+ must be explicitly present.
+ The number and types of the `sparse_tensor.iterate` results must match
+ the initial values in the iter_args binding and the yield operands.
+
+
+ A nested `sparse_tensor.iterate` example that prints all the coordinates
+ stored in the sparse input:
+
+ ```mlir
+ func.func @nested_iterate(%sp : tensor<4x8xf32, #COO>) {
+ // Iterates over the first level of %sp
+ %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 : tensor<4x8xf32, #COO>
+ %r1 = sparse_tensor.iterate %it1 in %l1 at (%crd0)
+ : !sparse_tensor.iter_space<#COO, lvls = 0 to 1> {
+ // Iterates over the second level of %sp
+ %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+ : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+ %r2 = sparse_tensor.iterate %it2 in %l2 at (crd1)
+ : !sparse_tensor.iter_space<#COO, lvls = 1 to 2> {
+ vector.print %crd0 : index
+ vector.print %crd1 : index
+ }
+ }
+ }
+
+ ```
+ }];
+
+ let extraClassDeclaration = [{
+ unsigned getSpaceDim() {
+ return getIterSpace().getType().getSpaceDim();
+ }
+ BlockArgument getIterator() {
+ return getRegion().getArguments().front();
+ }
+ Block::BlockArgListType getCrds() {
+ // The first block argument is iterator, the remaining arguments are
+ // referenced coordinates.
+ return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+ }
+ unsigned getNumRegionIterArgs() {
+ return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
+ }
+ }];
+
+ let hasVerifier = 1;
+ let hasRegionVerifier = 1;
+ let hasCustomAssemblyFormat = 1;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Debugging and Test-Only Operations.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
index 185cff46ae25d5..264a0a5b3bee6c 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorTypes.td
@@ -72,4 +72,99 @@ def SparseTensorStorageSpecifier
: Type<CPred<"::llvm::isa<::mlir::sparse_tensor::StorageSpecifierType>($_self)">, "metadata",
"::mlir::sparse_tensor::StorageSpecifierType">;
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Types.
+//===----------------------------------------------------------------------===//
+
+def SparseTensor_IterSpace : SparseTensor_Type<"IterSpace"> {
+ let mnemonic = "iter_space";
+
+ let description = [{
+ A sparse iteration space that represents an abstract N-D (sparse) iteration space
+ extracted from a sparse tensor.
+
+ Examples:
+
+ ```mlir
+ // An iteration space extracted from a CSR tensor between levels [0, 2).
+ !iter_space<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// The the dimension of the iteration space.
+ unsigned getSpaceDim() const {
+ return getHiLvl() - getLoLvl();
+ }
+
+ /// Get the level types for the iteration space.
+ ArrayRef<LevelType> getLvlTypes() const {
+ return getEncoding().getLvlTypes().slice(getLoLvl(), getSpaceDim());
+ }
+
+ /// Whether the iteration space is unique (i.e., no duplicated coordinate).
+ bool isUnique() {
+ return !getLvlTypes().back().isa<LevelPropNonDefault::Nonunique>();
+ }
+
+ /// Get the corresponding iterator type.
+ ::mlir::sparse_tensor::IteratorType getIteratorType() const;
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def SparseTensor_Iterator : SparseTensor_Type<"Iterator"> {
+ let mnemonic = "iterator";
+
+ let description = [{
+ An iterator that points to the current element in the corresponding iteration space.
+
+ Examples:
+
+ ```mlir
+ // An iterator that iterates over a iteration space of type `!iter_space<#CSR, lvls = 0 to 2>`
+ !iterator<#CSR, lvls = 0 to 2>
+ ```
+ }];
+
+ let parameters = (ins
+ SparseTensorEncodingAttr : $encoding,
+ "Level" : $loLvl,
+ "Level" : $hiLvl
+ );
+
+ let extraClassDeclaration = [{
+ /// Get the corresponding iteration space type.
+ ::mlir::sparse_tensor::IterSpaceType getIterSpaceType() const;
+
+ unsigned getSpaceDim() const { return getIterSpaceType().getSpaceDim(); }
+ ArrayRef<LevelType> getLvlTypes() const { return getIterSpaceType().getLvlTypes(); }
+ bool isUnique() { return getIterSpaceType().isUnique(); }
+ }];
+
+ let assemblyFormat="`<` $encoding `,` `lvls` `=` custom<LevelRange>($loLvl, $hiLvl) `>`";
+}
+
+def IsSparseSparseIterSpaceTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IterSpaceType>($_self)">;
+
+def IsSparseSparseIteratorTypePred
+ : CPred<"::llvm::isa<::mlir::sparse_tensor::IteratorType>($_self)">;
+
+def AnySparseIterSpace
+ : Type<IsSparseSparseIterSpaceTypePred, "sparse iteration space",
+ "::mlir::sparse_tensor::IterSpaceType">;
+
+def AnySparseIterator
+ : Type<IsSparseSparseIteratorTypePred, "sparse iterator",
+ "::mlir::sparse_tensor::IteratorType">;
+
+
#endif // SPARSETENSOR_TYPES
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e9058394d33da5..36908def09f403 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
+// Forward declarations, following custom print/parsing methods are referenced
+// by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+ mlir::sparse_tensor::Level &,
+ mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+ mlir::sparse_tensor::Level);
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
@@ -1953,6 +1961,363 @@ LogicalResult SortOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+ return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+ getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+ return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+ Level &lvlHi) {
+ if (parser.parseInteger(lvlLo))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ if (parser.parseInteger(lvlHi))
+ return failure();
+ } else {
+ lvlHi = lvlLo + 1;
+ }
+
+ if (lvlHi <= lvlLo)
+ parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
+ return success();
+}
+
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+ IntegerAttr &lvlHiAttr) {
+ Level lvlLo, lvlHi;
+ if (parseLevelRange(parser, lvlLo, lvlHi))
+ return failure();
+
+ lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+ lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+ return success();
+}
+
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+ if (lo + 1 == hi)
+ p << lo;
+ else
+ p << lo << " to " << hi;
+}
+
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ printLevelRange(p, lo, hi);
+}
+
+static ParseResult
+parseSparseSpaceLoop(OpAsmParser &parser, OperationState &state,
+ SmallVectorImpl<OpAsmParser::Argument> &iterators,
+ SmallVectorImpl<OpAsmParser::Argument> &iterArgs) {
+ SmallVector<OpAsmParser::UnresolvedOperand> spaces;
+ SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
+
+ // Parses "%iters, ... in %spaces, ..."
+ if (parser.parseArgumentList(iterators) || parser.parseKeyword("in") ||
+ parser.parseOperandList(spaces))
+ return failure();
+
+ if (iterators.size() != spaces.size())
+ return parser.emitError(
+ parser.getNameLoc(),
+ "mismatch in number of sparse iterators and sparse spaces");
+
+ // Parse "at(%crd0, _, ...)"
+ LevelSet crdUsedLvlSet;
+ bool hasUsedCrds = succeeded(parser.parseOptionalKeyword("at"));
+ unsigned lvlCrdCnt = 0;
+ if (hasUsedCrds) {
+ ParseResult crdList = parser.parseCommaSeparatedList(
+ OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
+ if (parser.parseOptionalKeyword("_")) {
+ if (parser.parseArgument(iterArgs.emplace_back()))
+ return failure();
+ // Always use IndexType for the coordinate.
+ crdUsedLvlSet.set(lvlCrdCnt);
+ iterArgs.back().type = parser.getBuilder().getIndexType();
+ }
+ lvlCrdCnt += 1;
+ return success();
+ });
+ if (failed(crdList)) {
+ return parser.emitError(
+ parser.getNameLoc(),
+ "expecting SSA value or \"_\" for level coordinates");
+ }
+ }
+ // Set the CrdUsedLvl bitset.
+ state.addAttribute("crdUsedLvls",
+ parser.getBuilder().getI64IntegerAttr(crdUsedLvlSet));
+
+ // Parse "iter_args(%arg = %init, ...)"
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs)
+ if (parser.parseAssignmentList(iterArgs, initArgs))
+ return failure();
+
+ SmallVector<Type> iterSpaceTps;
+ // parse ": sparse_tensor.iter_space -> ret"
+ if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
+ return failure();
+ if (iterSpaceTps.size() != spaces.size())
+ return parser.emitError(parser.getNameLoc(),
+ "mismatch in number of iteration space operands "
+ "and iteration space types");
+
+ for (auto [it, tp] : llvm::zip_equal(iterators, iterSpaceTps)) {
+ IterSpaceType spaceTp = llvm::dyn...
[truncated]
|
PeimingLiu
pushed a commit
that referenced
this pull request
Apr 16, 2024
)" This reverts commit 8debcf0.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
A
sparse_tensor.iterate
iterates over a sparse iteration space extracted fromsparse_tensor.extract_iteration_space
operation introduced in #88554.DO NOT MERGE before #88554