
Commit 6116ca6 (parent 62a2fef)

[mlir][sparse] Add sparse rewriting rules for tensor::ReshapeOp

Reviewed By: aartbik
Differential Revision: https://reviews.llvm.org/D149564

File tree: 3 files changed, +234 -1 lines
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp (102 additions, 1 deletion)
@@ -385,6 +385,106 @@ struct FuseTensorCast : public OpRewritePattern<tensor::CastOp> {
   }
 };
 
+/// Sparse rewriting rule for the generic tensor::ReshapeOp operator.
+struct TensorReshapeRewriter : public OpRewritePattern<tensor::ReshapeOp> {
+public:
+  using OpRewritePattern<tensor::ReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ReshapeOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value srcTensor = op.getSource();
+    const auto srcTp = getSparseTensorType(srcTensor);
+    const auto dstTp = getSparseTensorType(op.getResult());
+
+    if (!srcTp.hasEncoding() || !dstTp.hasEncoding() ||
+        !dstTp.hasStaticDimShape())
+      return failure();
+
+    SmallVector<Value> srcSizes;
+    sizesForTensor(rewriter, srcSizes, loc, srcTp, srcTensor);
+    SmallVector<Value> dstSizes;
+    for (Dimension d : dstTp.getDimShape())
+      dstSizes.push_back(constantIndex(rewriter, loc, d));
+
+    Value nnz = rewriter.create<NumberOfEntriesOp>(loc, srcTensor);
+    // Only need an unordered COO buffer if input and output are not sorted
+    // in the same way.
+    Type bufferTp =
+        srcTp.isAllOrdered() && srcTp.isIdentity() && dstTp.isIdentity()
+            ? dstTp.getRankedTensorType()
+            : getUnorderedCOOFromType(dstTp);
+    SmallVector<Value> dynSizes;
+    Value buffer = rewriter
+                       .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(),
+                                              nnz, Attribute())
+                       .getResult();
+
+    // Convert src coordinates to dst coordinates by first collapsing them
+    // to 1-D and then expanding them to match the rank of the destination
+    // tensor. Implemented as follows:
+    //   foreach srcCoords %srcTensor
+    //     collapsedCoords = reshapeCvs(srcCoords, [1, ..., srcRank])
+    //     expandedCoords = reshapeCvs(collapsedCoords, [1, ..., dstRank])
+    //     insert expandedCoords, %buffer
+    //
+    // followed by an optional
+    //   %t = sparse_tensor.cast %tmp
+    // depending on whether the input/output are sorted in the same way.
+    const auto encSrc = srcTp.getEncoding();
+    ForeachOp foreachOp = rewriter.create<ForeachOp>(
+        loc, srcTensor, buffer,
+        [&](OpBuilder &builder, Location loc, ValueRange srcLcvs, Value v,
+            ValueRange reduc) {
+          const Dimension srcRank = srcTp.getDimRank();
+          SmallVector<Value> srcDcvs;
+          srcDcvs.reserve(srcRank);
+          for (Dimension d = 0; d < srcRank; d++) {
+            // FIXME: `toStoredDim` is deprecated
+            Level lvl = toStoredDim(encSrc, d);
+            srcDcvs.push_back(srcLcvs[lvl]);
+          }
+
+          Value collapsed_size = constantIndex(builder, loc, 1);
+          for (Dimension d = 0; d < srcRank; d++)
+            collapsed_size =
+                builder.create<arith::MulIOp>(loc, collapsed_size, srcSizes[d]);
+          SmallVector<Value, 1> collapsedSizes = {collapsed_size};
+
+          ReassociationIndices collapse_indices;
+          for (Dimension i = 0; i < srcRank; i++)
+            collapse_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> collapse_reassociation = {
+              collapse_indices};
+          SmallVector<Value, 1> collapsedDcvs;
+          reshapeCvs(builder, loc, collapse_reassociation, srcSizes, srcDcvs,
+                     collapsedSizes, collapsedDcvs);
+
+          ReassociationIndices expand_indices;
+          for (Dimension i = 0; i < dstTp.getDimRank(); i++)
+            expand_indices.push_back(i);
+          SmallVector<ReassociationIndices, 1> expand_reassociation = {
+              expand_indices};
+          SmallVector<Value> dstDcvs;
+          reshapeCvs(builder, loc, expand_reassociation, collapsedSizes,
+                     collapsedDcvs, dstSizes, dstDcvs);
+
+          auto t = builder.create<InsertOp>(loc, v, reduc.front(), dstDcvs);
+          builder.create<sparse_tensor::YieldOp>(loc, t);
+        });
+
+    Value t = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
+    if (bufferTp != dstTp) {
+      auto dstRTT = dstTp.getRankedTensorType();
+      Value converted = rewriter.create<ConvertOp>(loc, dstRTT, t).getResult();
+      rewriter.create<DeallocTensorOp>(loc, t);
+      t = converted;
+    }
+    rewriter.replaceOp(op, t);
+    return success();
+  }
+};
+
 /// Sparse rewriting rule for sparse-to-sparse reshape operator.
 template <typename ReshapeOp>
 struct Sparse2SparseReshapeRewriter : public OpRewritePattern<ReshapeOp> {
@@ -1169,7 +1269,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableForeach,
                                                bool enableConvert) {
   patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+               ReshapeRewriter<tensor::CollapseShapeOp>, TensorReshapeRewriter>(
+      patterns.getContext());
   if (enableForeach)
     patterns.add<ForeachRewriter>(patterns.getContext());
   // TODO: If RT not enabled, rewrite concatenate ops, etc here.
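
The core of the new pattern is the coordinate remapping that the two reshapeCvs calls materialize as IR: collapse each source coordinate tuple into a single row-major linear offset, then expand that offset into destination coordinates with division and remainder. Below is a minimal scalar sketch of that arithmetic, assuming plain C++ containers; remapCoords is an illustrative name, not an MLIR helper.

#include <cstdint>
#include <vector>

// Scalar analogue of the collapse/expand arithmetic the rewrite emits
// (illustrative sketch only; remapCoords is not part of the MLIR API).
std::vector<int64_t> remapCoords(const std::vector<int64_t> &srcCoords,
                                 const std::vector<int64_t> &srcSizes,
                                 const std::vector<int64_t> &dstSizes) {
  // Collapse: fold the n-D coordinate into a row-major linear offset,
  // e.g. (i, j) in a 4x25 tensor becomes i * 25 + j.
  int64_t linear = 0;
  for (size_t d = 0; d < srcCoords.size(); ++d)
    linear = linear * srcSizes[d] + srcCoords[d];
  // Expand: recover destination coordinates with div/rem, innermost
  // dimension last, mirroring the arith.divui/arith.remui pair in the
  // generated IR.
  std::vector<int64_t> dstCoords(dstSizes.size());
  for (size_t d = dstSizes.size(); d-- > 0;) {
    dstCoords[d] = linear % dstSizes[d];
    linear /= dstSizes[d];
  }
  return dstCoords;
}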
New test file (45 additions, 0 deletions)

// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s

#SparseMatrix = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>

// CHECK: func.func @sparse_reshape(
// CHECK-SAME: %[[S:.*]]:
// CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
// CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[B:.*]] = bufferization.alloc_tensor()
// CHECK: %[[P0:.*]] = sparse_tensor.positions %[[S]] {level = 0 : index}
// CHECK: %[[I0:.*]] = sparse_tensor.coordinates %[[S]] {level = 0 : index}
// CHECK: %[[P1:.*]] = sparse_tensor.positions %[[S]] {level = 1 : index}
// CHECK: %[[I1:.*]] = sparse_tensor.coordinates %[[S]] {level = 1 : index}
// CHECK: %[[V:.*]] = sparse_tensor.values %[[S]]
// CHECK: %[[S0:.*]] = memref.load %[[P0]]{{\[}}%[[C0]]] : memref<?xindex>
// CHECK: %[[E0:.*]] = memref.load %[[P0]]{{\[}}%[[C1]]] : memref<?xindex>
// CHECK: %[[RET:.*]] = scf.for %[[I:.*]] = %[[S0]] to %[[E0]] step %[[C1]] iter_args(%[[A0:.*]] = %[[B]])
// CHECK: %[[SI0:.*]] = memref.load %[[I0]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG: %[[S1:.*]] = memref.load %[[P1]]{{\[}}%[[I]]] : memref<?xindex>
// CHECK-DAG: %[[PE1:.*]] = arith.addi %[[I]], %[[C1]] : index
// CHECK: %[[E1:.*]] = memref.load %[[P1]]{{\[}}%[[PE1]]] : memref<?xindex>
// CHECK: %[[RET_1:.*]] = scf.for %[[J:.*]] = %[[S1]] to %[[E1]] step %[[C1]] iter_args(%[[A1:.*]] = %[[A0]])
// CHECK: %[[SI1:.*]] = memref.load %[[I1]]{{\[}}%[[J]]] : memref<?xindex>
// CHECK: %[[SV:.*]] = memref.load %[[V]]{{\[}}%[[J]]] : memref<?xf64>
// CHECK: %[[T:.*]] = arith.muli %[[SI0]], %[[C25]] : index
// CHECK: %[[DI:.*]] = arith.addi %[[T]], %[[SI1]] : index
// CHECK: %[[D:.*]] = arith.divui %[[DI]], %[[C10]] : index
// CHECK: %[[R:.*]] = arith.remui %[[DI]], %[[C10]] : index
// CHECK: %[[R1:.*]] = sparse_tensor.insert %[[SV]] into %[[A1]]{{\[}}%[[D]], %[[R]]]
// CHECK: scf.yield %[[R1]]
// CHECK: }
// CHECK: scf.yield %[[RET_1]]
// CHECK: }
// CHECK: %[[NT1:.*]] = sparse_tensor.load %[[RET]] hasInserts
// CHECK: return %[[NT1]] : tensor<10x10xf64, #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>>
//
func.func @sparse_reshape(%arg0: tensor<4x25xf64, #SparseMatrix>) -> tensor<10x10xf64, #SparseMatrix> {
  %shape = arith.constant dense <[ 10, 10 ]> : tensor<2xi32>
  %0 = tensor.reshape %arg0(%shape) :
    (tensor<4x25xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<10x10xf64, #SparseMatrix>
  return %0 : tensor<10x10xf64, #SparseMatrix>
}
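
Worked example for the CHECK lines above: the source entry at coordinates (1, 3) of the 4x25 tensor has row-major linear offset 1 * 25 + 3 = 28 (the muli/addi pair), and its destination coordinates in the 10x10 result are (28 / 10, 28 % 10) = (2, 8) (the divui/remui pair).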
New test file (87 additions, 0 deletions)

// DEFINE: %{option} = enable-runtime-library=true
// DEFINE: %{compile} = mlir-opt %s --sparse-compiler=%{option}
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE:  -e entry -entry-point-result=void \
// DEFINE:  -shared-libs=%mlir_c_runner_utils | \
// DEFINE: FileCheck %s
//
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation.
// REDEFINE: %{option} = enable-runtime-library=false
// RUN: %{compile} | %{run}
//
// Do the same run, but now with direct IR generation and vectorization.
// REDEFINE: %{option} = "enable-runtime-library=false vl=2 reassociate-fp-reductions=true enable-index-optimizations=true"
// RUN: %{compile} | %{run}

#SparseVector = #sparse_tensor.encoding<{
  dimLevelType = ["compressed"]
}>

#SparseMatrix = #sparse_tensor.encoding<{
  dimLevelType = ["compressed", "compressed"]
}>

#Sparse3dTensor = #sparse_tensor.encoding<{
  dimLevelType = ["compressed", "compressed", "compressed"]
}>

module {

  func.func @reshape0(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix> {
    %shape = arith.constant dense <[ 2, 6 ]> : tensor<2xi32>
    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<2xi32>) -> tensor<2x6xf64, #SparseMatrix>
    return %0 : tensor<2x6xf64, #SparseMatrix>
  }

  func.func @reshape1(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector> {
    %shape = arith.constant dense <[ 12 ]> : tensor<1xi32>
    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<1xi32>) -> tensor<12xf64, #SparseVector>
    return %0 : tensor<12xf64, #SparseVector>
  }

  func.func @reshape2(%arg0: tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor> {
    %shape = arith.constant dense <[ 2, 3, 2 ]> : tensor<3xi32>
    %0 = tensor.reshape %arg0(%shape) : (tensor<3x4xf64, #SparseMatrix>, tensor<3xi32>) -> tensor<2x3x2xf64, #Sparse3dTensor>
    return %0 : tensor<2x3x2xf64, #Sparse3dTensor>
  }

  func.func @entry() {
    %m = arith.constant dense <[ [ 1.1, 0.0, 1.3, 0.0 ],
                                 [ 2.1, 0.0, 2.3, 0.0 ],
                                 [ 3.1, 0.0, 3.3, 0.0 ]]> : tensor<3x4xf64>
    %sm = sparse_tensor.convert %m : tensor<3x4xf64> to tensor<3x4xf64, #SparseMatrix>

    %reshaped0 = call @reshape0(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x6xf64, #SparseMatrix>
    %reshaped1 = call @reshape1(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<12xf64, #SparseVector>
    %reshaped2 = call @reshape2(%sm) : (tensor<3x4xf64, #SparseMatrix>) -> tensor<2x3x2xf64, #Sparse3dTensor>

    %c0 = arith.constant 0 : index
    %df = arith.constant -1.0 : f64

    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
    %b0 = sparse_tensor.values %reshaped0 : tensor<2x6xf64, #SparseMatrix> to memref<?xf64>
    %v0 = vector.transfer_read %b0[%c0], %df : memref<?xf64>, vector<12xf64>
    vector.print %v0 : vector<12xf64>

    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
    %b1 = sparse_tensor.values %reshaped1 : tensor<12xf64, #SparseVector> to memref<?xf64>
    %v1 = vector.transfer_read %b1[%c0], %df : memref<?xf64>, vector<12xf64>
    vector.print %v1 : vector<12xf64>

    // CHECK: ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3
    %b2 = sparse_tensor.values %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor> to memref<?xf64>
    %v2 = vector.transfer_read %b2[%c0], %df : memref<?xf64>, vector<12xf64>
    vector.print %v2 : vector<12xf64>

    bufferization.dealloc_tensor %sm : tensor<3x4xf64, #SparseMatrix>
    bufferization.dealloc_tensor %reshaped0 : tensor<2x6xf64, #SparseMatrix>
    bufferization.dealloc_tensor %reshaped1 : tensor<12xf64, #SparseVector>
    bufferization.dealloc_tensor %reshaped2 : tensor<2x3x2xf64, #Sparse3dTensor>

    return
  }
}
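
All three CHECK lines expect the same values array ( 1.1, 1.3, 2.1, 2.3, 3.1, 3.3 ): every encoding here uses compressed levels with the identity dimension order, so tensor.reshape preserves the row-major order of the stored entries and only their coordinates change.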
