Commit 6a45339

[mlir][sparse] refine sparse fusion with empty tensors materialization (#66563)
This is a minor step towards deprecating bufferization.alloc_tensor(). It replaces the examples with tensor.empty() and adjusts the underlying rewriting logic to prepare for this upcoming change.
1 parent f71a9e8 commit 6a45339
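
At the IR level, the swap described above is a one-op change in the examples and tests; a minimal sketch, using the 8x8 sparse type #SM from the updated sparse_sddmm.mlir test below:

  // Before: destination materialized through the bufferization dialect.
  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>

  // After: destination materialized with the tensor dialect.
  %3 = tensor.empty() : tensor<8x8xf64, #SM>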

2 files changed: 42 additions & 42 deletions

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 15 additions & 13 deletions
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
 }
 static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }
 
-// Helper method to find zero/uninitialized allocation.
-static bool isAlloc(OpOperand *op, bool isZero) {
+// Helper method to find zero/uninitialized tensor materialization.
+static bool isMaterializing(OpOperand *op, bool isZero) {
   Value val = op->get();
   // Check allocation, with zero alloc when required.
   if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
       return copy && isZeroValue(copy);
     return !copy;
   }
+  // Check for empty tensor materialization.
+  if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
+    return !isZero;
   // Last resort for zero alloc: the whole value is zero.
   return isZero && isZeroValue(val);
 }
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
   LogicalResult matchAndRewrite(GenericOp op,
                                 PatternRewriter &rewriter) const override {
     if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
-        !isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
         !isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
       return failure();
     auto outputType = getRankedTensorType(op.getResult(0));
-    // Yielding zero on newly allocated (all-zero) sparse tensors can be
-    // optimized out directly (regardless of dynamic or static size).
+    // Yielding zero on newly materialized sparse tensor can be
+    // optimized directly (regardless of dynamic or static size).
     if (getSparseTensorEncoding(outputType)) {
       rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
       return success();
     }
-    // Incorporate zero value into allocation copy.
+    // Use static zero value directly instead of materialization.
     if (!outputType.hasStaticShape())
       return failure();
-    Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
-    AllocTensorOp a =
-        op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
-    rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
-    rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+    Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
+    rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
+    rewriter.eraseOp(def);
     return success();
   }
 };
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
-        !isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
+    if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
+        !isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
         !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
       last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create<linalg::YieldOp>(loc, last);
     // Force initial value on merged allocation for dense outputs.
+    // TODO: deal with non alloc tensor here one day
     if (!getSparseTensorEncoding(op.getResult(0).getType())) {
       Value init = prod.getDpsInitOperand(0)
                        ->get()
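
With the rewriting changes above, FoldInvariantYield handles a zero-yielding linalg.generic whose init operand is a fresh tensor.empty(): for sparse outputs the materialized tensor is returned directly, and for dense, statically shaped outputs the generic is replaced by a dense zero constant while the tensor.empty() is erased. A minimal sketch of the dense case (the function name is illustrative only; it mirrors the fold_yield_direct_zero test updated below):

  func.func @zero_fill_sketch() -> tensor<32xf64> {
    %cst = arith.constant 0.000000e+00 : f64
    %0 = tensor.empty() : tensor<32xf64>
    // Zero-yielding generic over a freshly materialized destination.
    %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                         iterator_types = ["parallel"]}
        outs(%0 : tensor<32xf64>) {
    ^bb0(%out: f64):
      linalg.yield %cst : f64
    } -> tensor<32xf64>
    // After FoldInvariantYield: %1 becomes
    //   arith.constant dense<0.000000e+00> : tensor<32xf64>
    // and the tensor.empty() above is erased.
    return %1 : tensor<32xf64>
  }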

mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir

Lines changed: 27 additions & 29 deletions
@@ -21,13 +21,12 @@
 }
 
 // CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) : tensor<1024x1024xf64>
-// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
+// CHECK: return %[[C0]] : tensor<1024x1024xf64>
 // CHECK: }
 func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
+  %0 = tensor.empty() : tensor<1024x1024xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                         affine_map<(d0, d1) -> (d0, d1)>],
                                         iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
 }
 
 // CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
-// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
-// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) : tensor<32xf64>
-// CHECK: return %[[VAL_1]] : tensor<32xf64>
+// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
+// CHECK: return %[[C0]] : tensor<32xf64>
 // CHECK: }
 func.func @fold_yield_direct_zero() -> tensor<32xf64> {
   %cst = arith.constant 0.000000e+00 : f64
-  %0 = bufferization.alloc_tensor() : tensor<32xf64>
+  %0 = tensor.empty() : tensor<32xf64>
   %1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
                        iterator_types = ["parallel"]}
       outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
 // CHECK: %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
 // CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
 // CHECK: memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
-// CHECK: } {"Emitted from" = "linalg.generic"}
+// CHECK: }
+// CHECK: }
+// CHECK: }
 // CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
 // CHECK: return %[[VAL_34]] : tensor<8x8xf64>
 // CHECK: }
@@ -123,29 +121,29 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 }
 
 // CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
 // CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
 // CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
 // CHECK-DAG: %[[VAL_6:.*]] = arith.constant false
 // CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
 // CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
-// CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) : tensor<8x8xf64>
-// CHECK-DAG: %[[VAL_10:.*]] = bufferization.alloc_tensor() : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK-DAG: %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
 // CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
-// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
-// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
+// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
 // CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
+// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
 // CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
-// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
+// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
 // CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
 // CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
 // CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
 // CHECK: scf.yield %[[VAL_37]] : index
 // CHECK: }
 // CHECK: memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
-// CHECK: scf.yield %[[VAL_49:.*]] : index
+// CHECK: scf.yield %[[VAL_47]] : index
 // CHECK: }
-// CHECK: scf.yield %[[VAL_50:.*]] : index
+// CHECK: scf.yield %[[VAL_35]] : index
 // CHECK: }
-// CHECK: %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
-// CHECK: %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
-// CHECK: return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
+// CHECK: %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK: return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
 // CHECK: }
 func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
                                      %arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
     linalg.yield %q : f64
   } -> tensor<8x8xf64>
   // Sample the result with elements-wise multiplication with sparse matrix.
-  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %3 = tensor.empty() : tensor<8x8xf64, #SM>
   %4 = linalg.generic #trait_scale
        ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
        outs(%3 : tensor<8x8xf64, #SM>) {
