
[mlir][sparse] refine sparse fusion with empty tensors materialization #66563


Merged
merged 2 commits on Sep 18, 2023
@@ -50,8 +50,8 @@ static bool isSparseTensor(Value v) {
}
static bool isSparseTensor(OpOperand *op) { return isSparseTensor(op->get()); }

// Helper method to find zero/uninitialized allocation.
static bool isAlloc(OpOperand *op, bool isZero) {
// Helper method to find zero/uninitialized tensor materialization.
static bool isMaterializing(OpOperand *op, bool isZero) {
Value val = op->get();
// Check allocation, with zero alloc when required.
if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
@@ -60,6 +60,9 @@ static bool isAlloc(OpOperand *op, bool isZero) {
return copy && isZeroValue(copy);
return !copy;
}
// Check for empty tensor materialization.
if (auto empty = val.getDefiningOp<tensor::EmptyOp>())
return !isZero;
// Last resort for zero alloc: the whole value is zero.
return isZero && isZeroValue(val);
}
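For reference, a minimal sketch of the materializations the renamed helper now recognizes (values and types here are illustrative, not taken from the patch):

// isZero == false: any uninitialized materialization qualifies.
%e = tensor.empty() : tensor<8x8xf64>
%a = bufferization.alloc_tensor() : tensor<8x8xf64>
// isZero == true: an allocation copied from zero, or a value that is itself zero.
%cst = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
%z = bufferization.alloc_tensor() copy(%cst) : tensor<8x8xf64>

Note that tensor.empty never counts as a zero materialization, since its contents are undefined.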
@@ -219,24 +222,22 @@ struct FoldInvariantYield : public OpRewritePattern<GenericOp> {
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override {
if (!op.hasTensorSemantics() || op.getNumResults() != 1 ||
!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isZeroYield(op) || !op.getDpsInitOperand(0)->get().hasOneUse())
return failure();
auto outputType = getRankedTensorType(op.getResult(0));
// Yielding zero on newly allocated (all-zero) sparse tensors can be
// optimized out directly (regardless of dynamic or static size).
// Yielding zero on newly materialized sparse tensor can be
// optimized directly (regardless of dynamic or static size).
if (getSparseTensorEncoding(outputType)) {
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
return success();
}
// Incorporate zero value into allocation copy.
// Use static zero value directly instead of materialization.
if (!outputType.hasStaticShape())
return failure();
Value zero = constantZero(rewriter, op.getLoc(), op.getResult(0).getType());
AllocTensorOp a =
op.getDpsInitOperand(0)->get().getDefiningOp<AllocTensorOp>();
rewriter.updateRootInPlace(a, [&]() { a.getCopyMutable().assign(zero); });
rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
Operation *def = op.getDpsInitOperand(0)->get().getDefiningOp();
rewriter.replaceOp(op, constantZero(rewriter, op.getLoc(), outputType));
rewriter.eraseOp(def);
return success();
}
};
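Concretely, for a statically shaped dense output the pattern now folds the whole computation to a constant and erases the materialization; for a sparse output it simply replaces the op with its init operand. A sketch of the dense case (the ins/outs and body of @fold_yield_arg_zero are collapsed in the diff below, so this reconstruction is approximate):

%cst = arith.constant 0.000000e+00 : f64
%0 = tensor.empty() : tensor<1024x1024xf64>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
                                      affine_map<(d0, d1) -> (d0, d1)>],
                     iterator_types = ["parallel", "parallel"]}
    ins(%cst : f64) outs(%0 : tensor<1024x1024xf64>) {
^bb0(%a: f64, %x: f64):
  linalg.yield %a : f64
} -> tensor<1024x1024xf64>

is rewritten to a single constant, with the tensor.empty erased:

%1 = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>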
@@ -286,8 +287,8 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
if (!isAlloc(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isAlloc(prod.getDpsInitOperand(0), /*isZero=*/true) ||
if (!isMaterializing(op.getDpsInitOperand(0), /*isZero=*/false) ||
!isMaterializing(prod.getDpsInitOperand(0), /*isZero=*/true) ||
!isSampling(op) || !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
@@ -327,6 +328,7 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
last = rewriter.clone(*acc, mapper)->getResult(0);
rewriter.create<linalg::YieldOp>(loc, last);
// Force initial value on merged allocation for dense outputs.
// TODO: deal with non alloc tensor here one day
if (!getSparseTensorEncoding(op.getResult(0).getType())) {
Value init = prod.getDpsInitOperand(0)
->get()
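For context, below is a rough sketch of the producer/consumer pair that FuseSparseMultiplyOverAdd matches, in the spirit of the sparse_sddmm.mlir tests that follow; the trait names and affine maps are illustrative, and the sparse encoding uses the lvlTypes syntax that appears in the pre-existing CHECK lines. The producer is a sum-of-multiplications into a zero materialization, and the consumer samples it by element-wise multiplication with a sparse matrix whose init may now be a tensor.empty.

#SM = #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>

#trait_matmul = {
  indexing_maps = [
    affine_map<(i, j, k) -> (i, k)>,
    affine_map<(i, j, k) -> (k, j)>,
    affine_map<(i, j, k) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel", "reduction"]
}

#trait_sample = {
  indexing_maps = [
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (i, j)>,
    affine_map<(i, j) -> (i, j)>
  ],
  iterator_types = ["parallel", "parallel"]
}

func.func @sampled_dd(%s: tensor<8x8xf64, #SM>,
                      %a: tensor<8x8xf64>,
                      %b: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
  // Producer: sum of multiplications into an all-zero materialization.
  %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
  %zero = bufferization.alloc_tensor() copy(%cst) : tensor<8x8xf64>
  %dd = linalg.generic #trait_matmul
          ins(%a, %b : tensor<8x8xf64>, tensor<8x8xf64>)
          outs(%zero : tensor<8x8xf64>) {
    ^bb0(%x: f64, %y: f64, %acc: f64):
      %m = arith.mulf %x, %y : f64
      %r = arith.addf %acc, %m : f64
      linalg.yield %r : f64
  } -> tensor<8x8xf64>
  // Consumer: sampling by element-wise multiplication with the sparse matrix,
  // with its init now materialized by tensor.empty instead of alloc_tensor.
  %empty = tensor.empty() : tensor<8x8xf64, #SM>
  %out = linalg.generic #trait_sample
           ins(%dd, %s : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
           outs(%empty : tensor<8x8xf64, #SM>) {
    ^bb0(%p: f64, %q: f64, %u: f64):
      %v = arith.mulf %p, %q : f64
      linalg.yield %v : f64
  } -> tensor<8x8xf64, #SM>
  return %out : tensor<8x8xf64, #SM>
}

After fusion, the multiplication chain is evaluated only at the nonzero positions of %s, yielding a single SDDMM-style kernel.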
56 changes: 27 additions & 29 deletions mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -21,13 +21,12 @@
}

// CHECK-LABEL: func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) : tensor<1024x1024xf64>
// CHECK: return %[[VAL_1]] : tensor<1024x1024xf64>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<1024x1024xf64>
// CHECK: return %[[C0]] : tensor<1024x1024xf64>
// CHECK: }
func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
%cst = arith.constant 0.000000e+00 : f64
%0 = bufferization.alloc_tensor() : tensor<1024x1024xf64>
%0 = tensor.empty() : tensor<1024x1024xf64>
%1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> ()>,
affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]}
@@ -40,13 +39,12 @@ func.func @fold_yield_arg_zero() -> tensor<1024x1024xf64> {
}

// CHECK-LABEL: func.func @fold_yield_direct_zero() -> tensor<32xf64> {
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
// CHECK: %[[VAL_1:.*]] = bufferization.alloc_tensor() copy(%[[VAL_0]]) : tensor<32xf64>
// CHECK: return %[[VAL_1]] : tensor<32xf64>
// CHECK: %[[C0:.*]] = arith.constant dense<0.000000e+00> : tensor<32xf64>
// CHECK: return %[[C0]] : tensor<32xf64>
// CHECK: }
func.func @fold_yield_direct_zero() -> tensor<32xf64> {
%cst = arith.constant 0.000000e+00 : f64
%0 = bufferization.alloc_tensor() : tensor<32xf64>
%0 = tensor.empty() : tensor<32xf64>
%1 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = ["parallel"]}
outs(%0 : tensor<32xf64>) {
@@ -92,9 +90,9 @@ func.func @fold_yield_direct_zero() -> tensor<32xf64> {
// CHECK: %[[VAL_32:.*]] = arith.mulf %[[VAL_30]], %[[VAL_31]] : f64
// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_28]], %[[VAL_32]] : f64
// CHECK: memref.store %[[VAL_33]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_27]]] : memref<8x8xf64>
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: } {"Emitted from" = "linalg.generic"}
// CHECK: }
// CHECK: }
// CHECK: }
// CHECK: %[[VAL_34:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
// CHECK: return %[[VAL_34]] : tensor<8x8xf64>
// CHECK: }
@@ -123,29 +121,29 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
}

// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>,
// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> {
// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> {
Review comment (Contributor): I like that! Less things for me to migrate ;)

// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[VAL_6:.*]] = arith.constant false
// CHECK-DAG: %[[VAL_7:.*]] = arith.constant true
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
// CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) : tensor<8x8xf64>
// CHECK-DAG: %[[VAL_10:.*]] = bufferization.alloc_tensor() : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
// CHECK-DAG: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
// CHECK-DAG: %[[VAL_10:.*]] = tensor.empty() : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK-DAG: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
// CHECK-DAG: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xindex>
// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>
// CHECK-DAG: %[[VAL_13:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_14:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 0 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_15:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_16:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
// CHECK-DAG: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
// CHECK: %[[VAL_19:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>) {
// CHECK: %[[VAL_20:.*]] = scf.for %[[VAL_21:.*]] = %[[VAL_18]] to %[[VAL_19]] step %[[VAL_5]] iter_args(%[[VAL_22:.*]] = %[[VAL_10]]) -> (tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>) {
// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_21]]] : memref<?xindex>
// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
// CHECK: %[[VAL_24:.*]], %[[VAL_25:.*]], %[[VAL_26:.*]], %[[VAL_27:.*]] = sparse_tensor.expand %[[VAL_10]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>, memref<?xi1>, memref<?xindex>
// CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_27]]) -> (index) {
// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_29]]] : memref<8x8xf64>
// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_21]]] : memref<?xindex>
@@ -170,15 +168,15 @@ func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
// CHECK: scf.yield %[[VAL_37]] : index
// CHECK: }
// CHECK: memref.store %[[VAL_44]], %[[VAL_24]]{{\[}}%[[VAL_38]]] : memref<?xf64>
// CHECK: scf.yield %[[VAL_49:.*]] : index
// CHECK: scf.yield %[[VAL_47]] : index
// CHECK: }
// CHECK: scf.yield %[[VAL_50:.*]] : index
// CHECK: scf.yield %[[VAL_35]] : index
// CHECK: }
// CHECK: %[[VAL_51:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_52:.*]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
// CHECK: scf.yield %[[VAL_51]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
// CHECK: %[[VAL_49:.*]] = sparse_tensor.compress %[[VAL_24]], %[[VAL_25]], %[[VAL_26]], %[[VAL_28]] into %[[VAL_22]]{{\[}}%[[VAL_23]]] : memref<?xf64>, memref<?xi1>, memref<?xindex>, tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: scf.yield %[[VAL_49]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
// CHECK: %[[VAL_53:.*]] = sparse_tensor.load %[[VAL_54:.*]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
// CHECK: return %[[VAL_53]] : tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "compressed", "compressed" ] }>>
// CHECK: %[[VAL_50:.*]] = sparse_tensor.load %[[VAL_20]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: return %[[VAL_50]] : tensor<8x8xf64, #sparse_tensor.encoding<{{{.*}}}>>
// CHECK: }
func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
%arga: tensor<8x8xf64>,
@@ -194,7 +192,7 @@ func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
linalg.yield %q : f64
} -> tensor<8x8xf64>
// Sample the result with elements-wise multiplication with sparse matrix.
%3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
%3 = tensor.empty() : tensor<8x8xf64, #SM>
%4 = linalg.generic #trait_scale
ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
outs(%3 : tensor<8x8xf64, #SM>) {