Skip to content

Commit 48a73bc

Browse files
[mlir][sparse] Extract StorageSpecifierToLLVMPass from bufferization pipeline (#68635)
`StorageSpecifierToLLVMPass` does not have to be part of the bufferization mini pipeline. It can run after the bufferization pipeline. This is desirable because it keeps the bufferization pipeline smaller. Also fix incorrect bufferization API usage: `bufferizeOp` instead of `bufferizeModuleOp` was used, even though function boundaries were bufferized.
1 parent 9d34c05 commit 48a73bc

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
4242
/*enableSIMDIndex32=*/options.force32BitVectorIndices));
4343
if (options.testBufferizationAnalysisOnly)
4444
return;
45+
46+
pm.addPass(createStorageSpecifierToLLVMPass());
4547
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
4648
pm.addNestedPass<func::FuncOp>(
4749
mlir::bufferization::createFinalizingBufferizePass());

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace sparse_tensor {
3939
/// Return `true` if one of the given types is a sparse tensor type.
4040
static bool containsSparseTensor(TypeRange types) {
4141
for (Type t : types)
42-
if (getSparseTensorEncoding(t))
42+
if (isa<TensorType>(t) && getSparseTensorEncoding(t))
4343
return true;
4444
return false;
4545
}
@@ -97,7 +97,8 @@ class SparsificationAndBufferizationPass
9797
return false;
9898
});
9999

100-
if (failed(bufferization::bufferizeOp(getOperation(), updatedOptions)))
100+
if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
101+
updatedOptions)))
101102
return failure();
102103

103104
bufferization::removeBufferizationAttributesInModule(getOperation());
@@ -154,7 +155,6 @@ class SparsificationAndBufferizationPass
154155
pm.addPass(createSparseTensorCodegenPass(createSparseDeallocs,
155156
enableBufferInitialization));
156157
pm.addPass(createSparseBufferRewritePass(enableBufferInitialization));
157-
pm.addPass(createStorageSpecifierToLLVMPass());
158158
}
159159
if (failed(runPipeline(pm, getOperation())))
160160
return signalPassFailure();

0 commit comments

Comments
 (0)