Skip to content

Commit a43d79a

Browse files
author
Peiming Liu
authored
[mlir][sparse] add canonicalization patterns for IterateOp. (llvm#95569)
1 parent 29d857f commit a43d79a

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,6 +1601,13 @@ def IterateOp : SparseTensor_Op<"iterate",
16011601
BlockArgument getIterator() {
16021602
return getRegion().getArguments().front();
16031603
}
1604+
std::optional<BlockArgument> getLvlCrd(Level lvl) {
1605+
if (getCrdUsedLvls()[lvl]) {
1606+
uint64_t mask = (static_cast<uint64_t>(0x01u) << lvl) - 1;
1607+
return getCrds()[llvm::popcount(mask & getCrdUsedLvls())];
1608+
}
1609+
return std::nullopt;
1610+
}
16041611
Block::BlockArgListType getCrds() {
16051612
// The first block argument is iterator, the remaining arguments are
16061613
// referenced coordinates.
@@ -1613,6 +1620,7 @@ def IterateOp : SparseTensor_Op<"iterate",
16131620

16141621
let hasVerifier = 1;
16151622
let hasRegionVerifier = 1;
1623+
let hasCanonicalizer = 1;
16161624
let hasCustomAssemblyFormat = 1;
16171625
}
16181626

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/Matchers.h"
2525
#include "mlir/IR/OpImplementation.h"
2626
#include "mlir/IR/PatternMatch.h"
27+
#include "llvm/ADT/Bitset.h"
2728
#include "llvm/ADT/TypeSwitch.h"
2829
#include "llvm/Support/FormatVariadic.h"
2930

@@ -2266,6 +2267,39 @@ LogicalResult ExtractIterSpaceOp::verify() {
22662267
return success();
22672268
}
22682269

2270+
struct RemoveUnusedLvlCrds : public OpRewritePattern<IterateOp> {
2271+
using OpRewritePattern::OpRewritePattern;
2272+
2273+
LogicalResult matchAndRewrite(IterateOp iterateOp,
2274+
PatternRewriter &rewriter) const override {
2275+
LevelSet newUsedLvls(0);
2276+
llvm::BitVector toRemove(iterateOp.getBody()->getNumArguments());
2277+
for (unsigned i = 0, e = iterateOp.getSpaceDim(); i < e; i++) {
2278+
if (auto crd = iterateOp.getLvlCrd(i)) {
2279+
if (crd->getUsers().empty())
2280+
toRemove.set(crd->getArgNumber());
2281+
else
2282+
newUsedLvls.set(i);
2283+
}
2284+
}
2285+
2286+
// All coordinates are used.
2287+
if (toRemove.none())
2288+
return failure();
2289+
2290+
rewriter.startOpModification(iterateOp);
2291+
iterateOp.setCrdUsedLvls(newUsedLvls);
2292+
iterateOp.getBody()->eraseArguments(toRemove);
2293+
rewriter.finalizeOpModification(iterateOp);
2294+
return success();
2295+
}
2296+
};
2297+
2298+
void IterateOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results,
2299+
mlir::MLIRContext *context) {
2300+
results.add<RemoveUnusedLvlCrds>(context);
2301+
}
2302+
22692303
ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
22702304
OpAsmParser::Argument iterator;
22712305
OpAsmParser::UnresolvedOperand iterSpace;

mlir/test/Dialect/SparseTensor/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,21 @@ func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : i
2121
%0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
2222
return %0 : tensor<?x?x?xf32, #BCOO>
2323
}
24+
25+
// -----
26+
27+
#CSR = #sparse_tensor.encoding<{
28+
map = (i, j) -> (i : dense, j : compressed)
29+
}>
30+
31+
// Make sure that the first unused coordinate is optimized.
32+
// CHECK-LABEL: @sparse_iterate_canonicalize
33+
// CHECK: sparse_tensor.iterate {{.*}} at(_, %{{.*}})
34+
func.func @sparse_iterate_canonicalize(%sp : tensor<?x?xf64, #CSR>) {
35+
%l1 = sparse_tensor.extract_iteration_space %sp lvls = 0 to 2
36+
: tensor<?x?xf64, #CSR> -> !sparse_tensor.iter_space<#CSR, lvls = 0 to 2>
37+
sparse_tensor.iterate %it1 in %l1 at (%coord0, %coord1) : !sparse_tensor.iter_space<#CSR, lvls = 0 to 2> {
38+
"test.op"(%coord1) : (index) -> ()
39+
}
40+
return
41+
}

0 commit comments

Comments
 (0)