Commit c6d85ba

Author: Peiming Liu
[mlir][sparse] implement sparse space collapse pass. (#89003)
1 parent 3af3525 commit c6d85ba

5 files changed, +259 -0 lines changed


mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h (+6)

@@ -248,6 +248,12 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
     bool enableBufferInitialization, unsigned vectorLength,
     bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen);
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+std::unique_ptr<Pass> createSparseSpaceCollapsePass();
+
 //===----------------------------------------------------------------------===//
 // Registration.
 //===----------------------------------------------------------------------===//
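As a usage sketch (not part of this commit): since the pass is anchored on func::FuncOp (see the Passes.td change below), the new entry point would typically be scheduled as a nested pass. The helper function name here is hypothetical; only createSparseSpaceCollapsePass() and the standard PassManager API come from the commit and MLIR.

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Pass/PassManager.h"

// Hypothetical pipeline hook: run the collapse on every function in the
// module by nesting the pass under func::FuncOp.
static void addSparseIterationPasses(mlir::PassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::createSparseSpaceCollapsePass());
}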

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td (+16)

@@ -464,4 +464,20 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// Sparse Iteration Transform Passes
+//===----------------------------------------------------------------------===//
+
+def SparseSpaceCollapse : Pass<"sparse-space-collapse", "func::FuncOp"> {
+  let summary = "sparse space collapsing pass";
+  let description = [{
+    This pass collapses consecutive sparse spaces (extracted from the same
+    tensor) into one multi-dimensional space. The pass is not yet stabilized.
+  }];
+  let constructor = "mlir::createSparseSpaceCollapsePass()";
+  let dependentDialects = [
+    "sparse_tensor::SparseTensorDialect",
+  ];
+}
+
 #endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES
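The TableGen argument "sparse-space-collapse" doubles as the textual pass name; the test at the end of this commit invokes it as mlir-opt --sparse-space-collapse. A minimal sketch of reaching the pass through a pipeline string instead; the wrapper function is illustrative, while parsePassPipeline is the standard MLIR API:

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/LogicalResult.h"

// Illustrative only: build `func.func(sparse-space-collapse)` from its
// textual form, as a `--pass-pipeline` string would on the command line.
static mlir::LogicalResult buildCollapsePipeline(mlir::PassManager &pm) {
  return mlir::parsePassPipeline("func.func(sparse-space-collapse)", pm);
}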

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt (+1)

@@ -6,6 +6,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
   SparseGPUCodegen.cpp
   SparseReinterpretMap.cpp
   SparseStorageSpecifierToLLVM.cpp
+  SparseSpaceCollapse.cpp
   SparseTensorCodegen.cpp
   SparseTensorConversion.cpp
   SparseTensorPasses.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseSpaceCollapse.cpp (+199, new file)

@@ -0,0 +1,199 @@
+//===--------- SparseSpaceCollapse.cpp - Collapse Sparse Space Pass -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_SPARSESPACECOLLAPSE
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "sparse-space-collapse"
+
+using namespace mlir;
+using namespace sparse_tensor;
+
+namespace {
+
+struct CollapseSpaceInfo {
+  ExtractIterSpaceOp space;
+  IterateOp loop;
+};
+
+bool isCollapsableLoops(LoopLikeOpInterface parent, LoopLikeOpInterface node) {
+  auto pIterArgs = parent.getRegionIterArgs();
+  auto nInitArgs = node.getInits();
+  if (pIterArgs.size() != nInitArgs.size())
+    return false;
+
+  // Two loops are collapsable if they are perfectly nested.
+  auto pYields = parent.getYieldedValues();
+  auto nResult = node.getLoopResults().value();
+
+  bool yieldEq =
+      llvm::all_of(llvm::zip_equal(pYields, nResult), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  // Parent iter_args should be passed directly to the node's init_args.
+  bool iterArgEq =
+      llvm::all_of(llvm::zip_equal(pIterArgs, nInitArgs), [](auto zipped) {
+        return std::get<0>(zipped) == std::get<1>(zipped);
+      });
+
+  return yieldEq && iterArgEq;
+}
+
+bool legalToCollapse(SmallVectorImpl<CollapseSpaceInfo> &toCollapse,
+                     ExtractIterSpaceOp curSpace) {
+
+  auto getIterateOpOverSpace = [](ExtractIterSpaceOp space) -> IterateOp {
+    Value spaceVal = space.getExtractedSpace();
+    if (spaceVal.hasOneUse())
+      return llvm::dyn_cast<IterateOp>(*spaceVal.getUsers().begin());
+    return nullptr;
+  };
+
+  if (toCollapse.empty()) {
+    // Collapse root.
+    if (auto itOp = getIterateOpOverSpace(curSpace)) {
+      CollapseSpaceInfo &info = toCollapse.emplace_back();
+      info.space = curSpace;
+      info.loop = itOp;
+      return true;
+    }
+    return false;
+  }
+
+  auto parent = toCollapse.back().space;
+  auto pItOp = toCollapse.back().loop;
+  auto nItOp = getIterateOpOverSpace(curSpace);
+
+  // Can only collapse spaces extracted from the same tensor.
+  if (parent.getTensor() != curSpace.getTensor()) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse spaces extracted from different tensors.";
+    });
+    return false;
+  }
+
+  // Can only collapse consecutive simple iteration on one tensor (i.e., no
+  // coiteration).
+  if (!nItOp || nItOp->getBlock() != curSpace->getBlock() ||
+      pItOp.getIterator() != curSpace.getParentIter() ||
+      curSpace->getParentOp() != pItOp.getOperation()) {
+    LLVM_DEBUG(
+        { llvm::dbgs() << "failed to collapse non-consecutive IterateOps."; });
+    return false;
+  }
+
+  if (pItOp && !isCollapsableLoops(pItOp, nItOp)) {
+    LLVM_DEBUG({
+      llvm::dbgs()
+          << "failed to collapse IterateOps that are not perfectly nested.";
+    });
+    return false;
+  }
+
+  CollapseSpaceInfo &info = toCollapse.emplace_back();
+  info.space = curSpace;
+  info.loop = nItOp;
+  return true;
+}
+
+void collapseSparseSpace(MutableArrayRef<CollapseSpaceInfo> toCollapse) {
+  if (toCollapse.size() < 2)
+    return;
+
+  ExtractIterSpaceOp root = toCollapse.front().space;
+  ExtractIterSpaceOp leaf = toCollapse.back().space;
+  Location loc = root.getLoc();
+
+  assert(root->hasOneUse() && leaf->hasOneUse());
+
+  // Insert collapsed operation at the same scope as root operation.
+  OpBuilder builder(root);
+
+  // Construct the collapsed iteration space.
+  auto collapsedSpace = builder.create<ExtractIterSpaceOp>(
+      loc, root.getTensor(), root.getParentIter(), root.getLoLvl(),
+      leaf.getHiLvl());
+
+  auto rItOp = llvm::cast<IterateOp>(*root->getUsers().begin());
+  auto innermost = toCollapse.back().loop;
+
+  IRMapping mapper;
+  mapper.map(leaf, collapsedSpace.getExtractedSpace());
+  for (auto z : llvm::zip_equal(innermost.getInitArgs(), rItOp.getInitArgs()))
+    mapper.map(std::get<0>(z), std::get<1>(z));
+
+  auto cloned = llvm::cast<IterateOp>(builder.clone(*innermost, mapper));
+  builder.setInsertionPointToStart(cloned.getBody());
+
+  LevelSet crdUsedLvls;
+  unsigned shift = 0, argIdx = 1;
+  for (auto info : toCollapse.drop_back()) {
+    LevelSet set = info.loop.getCrdUsedLvls();
+    crdUsedLvls |= set.lshift(shift);
+    shift += info.loop.getSpaceDim();
+    for (BlockArgument crd : info.loop.getCrds()) {
+      BlockArgument collapsedCrd = cloned.getBody()->insertArgument(
+          argIdx++, builder.getIndexType(), crd.getLoc());
+      crd.replaceAllUsesWith(collapsedCrd);
+    }
+  }
+  crdUsedLvls |= innermost.getCrdUsedLvls().lshift(shift);
+  cloned.getIterator().setType(collapsedSpace.getType().getIteratorType());
+  cloned.setCrdUsedLvls(crdUsedLvls);
+
+  rItOp.replaceAllUsesWith(cloned.getResults());
+  // Erase collapsed loops.
+  rItOp.erase();
+  root.erase();
+}
+
+struct SparseSpaceCollapsePass
+    : public impl::SparseSpaceCollapseBase<SparseSpaceCollapsePass> {
+  SparseSpaceCollapsePass() = default;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+
+    // A naive (experimental) implementation to collapse consecutive sparse
+    // spaces. It does NOT handle complex cases where multiple spaces are
+    // extracted in the same basic block. E.g.,
+    //
+    // %space1 = extract_space %t1 ...
+    // %space2 = extract_space %t2 ...
+    // sparse_tensor.iterate(%space1) ...
+    //
+    SmallVector<CollapseSpaceInfo> toCollapse;
+    func->walk([&](ExtractIterSpaceOp op) {
+      if (!legalToCollapse(toCollapse, op)) {
+        // If it is not legal to collapse one more space, collapse the
+        // existing ones and clear.
+        collapseSparseSpace(toCollapse);
+        toCollapse.clear();
+      }
+    });
+
+    collapseSparseSpace(toCollapse);
+  }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createSparseSpaceCollapsePass() {
+  return std::make_unique<SparseSpaceCollapsePass>();
+}
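To make the crdUsedLvls bookkeeping in collapseSparseSpace concrete: each loop carries a bitmask of the levels whose coordinates its body uses, and the pass shifts each mask by the number of space dimensions already consumed before OR-ing the masks together. A self-contained sketch, with a plain uint64_t standing in for sparse_tensor::LevelSet and made-up 1-D masks:

#include <cstdint>
#include <cstdio>

int main() {
  // Stand-in for sparse_tensor::LevelSet: bit i set means the loop body
  // uses the coordinate of level i within that loop's iteration space.
  uint64_t outerUsed = 0b1; // outer loop: 1-D space, coordinate used
  uint64_t innerUsed = 0b1; // inner loop: 1-D space, coordinate used

  // Mirror the pass: shift each mask past the dimensions of the spaces
  // collapsed so far, then merge.
  unsigned shift = 0;
  uint64_t collapsed = 0;
  collapsed |= outerUsed << shift;
  shift += 1;                      // outer space was 1-D
  collapsed |= innerUsed << shift; // inner mask lands at bit 1

  std::printf("collapsed crdUsedLvls = 0x%llx\n",
              (unsigned long long)collapsed); // prints 0x3
  return 0;
}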
mlir/test/Dialect/SparseTensor/sparse_space_collapse.mlir (+37, new file)

@@ -0,0 +1,37 @@
+// RUN: mlir-opt %s --sparse-space-collapse | FileCheck %s
+
+#COO = #sparse_tensor.encoding<{
+  map = (i, j) -> (
+    i : compressed(nonunique),
+    j : singleton(soa)
+  )
+}>
+
+// CHECK-LABEL:   func.func @sparse_sparse_collapse(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<4x8xf32, #sparse>,
+// CHECK-SAME:      %[[VAL_1:.*]]: index) {
+// CHECK:           %[[VAL_3:.*]] = sparse_tensor.extract_iteration_space %[[VAL_0]] lvls = 0 to 2 : tensor<4x8xf32, #sparse>
+// CHECK:           %[[VAL_4:.*]] = sparse_tensor.iterate %[[VAL_5:.*]] in %[[VAL_3]] at(%[[VAL_6:.*]], _) iter_args(%[[VAL_7:.*]] = %[[VAL_1]])
+// CHECK:             %[[VAL_8:.*]] = "test.op"(%[[VAL_7]]) : (index) -> index
+// CHECK:             sparse_tensor.yield %[[VAL_8]] : index
+// CHECK:           }
+// CHECK:           "test.sink"(%[[VAL_4]]) : (index) -> ()
+// CHECK:           return
+// CHECK:         }
+func.func @sparse_sparse_collapse(%sp : tensor<4x8xf32, #COO>, %i : index) {
+  %l1 = sparse_tensor.extract_iteration_space %sp lvls = 0
+      : tensor<4x8xf32, #COO>
+     -> !sparse_tensor.iter_space<#COO, lvls = 0>
+  %r1 = sparse_tensor.iterate %it1 in %l1 at(%crd0) iter_args(%outer = %i): !sparse_tensor.iter_space<#COO, lvls = 0 to 1> -> index {
+    %l2 = sparse_tensor.extract_iteration_space %sp at %it1 lvls = 1
+        : tensor<4x8xf32, #COO>, !sparse_tensor.iterator<#COO, lvls = 0 to 1>
+       -> !sparse_tensor.iter_space<#COO, lvls = 1>
+    %r2 = sparse_tensor.iterate %it2 in %l2 iter_args(%inner = %outer): !sparse_tensor.iter_space<#COO, lvls = 1 to 2> -> index {
+      %k = "test.op"(%inner) : (index) -> index
+      sparse_tensor.yield %k : index
+    }
+    sparse_tensor.yield %r2 : index
+  }
+  "test.sink"(%r1) : (index) -> ()
+  return
+}
