Commit 2e12bad
[MLIR][Linalg] Fix insert_slice fusion with rank reduction (#130961)
Insert_slice fusion with a linalg producer does not account for possible rank reduction in the insert_slice return type. When that happens, a tensor.cast is generated due to the type mismatch, which is invalid for tensors of different ranks. This later trips up other passes.
1 parent a4380fe commit 2e12bad
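
A sketch of the failure mode, using the shapes from the test added in this commit: the fused producer result inside the loop keeps its unit dimension (`tensor<1x6x2xf32>`), while the consumer was built against the rank-reduced slice type (`tensor<6x2xf32>`). Before this fix, the type-mismatch fallback emitted roughly:

```mlir
// Sketch only (not from the commit): tensor.cast cannot change rank,
// so this op is invalid and fails verification in later passes.
%bad = tensor.cast %fused_prod : tensor<1x6x2xf32> to tensor<6x2xf32>
```

The fix detects the rank mismatch and collapses the dropped unit dimensions with a `tensor.collapse_shape` instead.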

File tree: 4 files changed, +126 -2 lines changed

mlir/include/mlir/Dialect/Tensor/Utils/Utils.h

Lines changed: 5 additions & 0 deletions

@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
+/// `src`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                  const llvm::SmallBitVector &dropDims);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
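
As an illustration of the new helper's contract (hypothetical shapes, not from the commit): for a `src` of type `tensor<1x6x1x2xf32>` and `dropDims` with bits 0 and 2 set, each dropped unit dimension is grouped with the kept dimension that follows it, so the helper would build:

```mlir
// Sketch of the op dropGivenUnitDims would create: reassociation
// [[0, 1], [2, 3]] merges each unit dim into the next kept dim.
func.func @drop_unit_dims_sketch(%src: tensor<1x6x1x2xf32>) -> tensor<6x2xf32> {
  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3]]
      : tensor<1x6x1x2xf32> into tensor<6x2xf32>
  return %collapsed : tensor<6x2xf32>
}
```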

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 12 additions & 2 deletions

@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -271,12 +273,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Check if rank-reduction occurred as part of the extract_slice. If yes,
+  // collapse the dropped dimensions.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank()) {
+    llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+    def =
+        tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+  }
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
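
With the test below, the new branch fires because the fused producer result has rank 3 while the consumer operand type has rank 2. A minimal sketch of what the loop body then contains (value names are illustrative):

```mlir
// Rank mismatch (3 vs 2): the dropped unit dim is collapsed away, so the
// subsequent `consumerType != def.getType()` check sees matching types and
// no tensor.cast is needed.
func.func @fused_use_sketch(%prod: tensor<1x6x2xf32>) -> tensor<6x2xf32> {
  %collapsed = tensor.collapse_shape %prod [[0, 1], [2]]
      : tensor<1x6x2xf32> into tensor<6x2xf32>
  return %collapsed : tensor<6x2xf32>
}
```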

mlir/lib/Dialect/Tensor/Utils/Utils.cpp

Lines changed: 31 additions & 0 deletions

@@ -94,6 +94,37 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+CollapseShapeOp
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                const llvm::SmallBitVector &dropDims) {
+  auto srcType = cast<ShapedType>(src.getType());
+  int64_t rank = srcType.getRank();
+  assert(rank == static_cast<int64_t>(dropDims.size()) &&
+         "dropDims dimension does not match src tensor rank");
+  assert(llvm::all_of(
+             dropDims.set_bits(),
+             [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
+         "Dropping non unit dimension");
+  // Computed reassociation map for the corresponding tensor.collapse_shape.
+  SmallVector<ReassociationIndices, 2> reassocMaps;
+  // Current reassociation group to add dropped dimension to.
+
+  int64_t nextDimToGroup = 0;
+  llvm::SmallBitVector keptDims(dropDims);
+  keptDims.flip();
+  int64_t lastSetBit = keptDims.find_last();
+  for (int64_t setBit : keptDims.set_bits()) {
+    // Group consecutive dropped dimension with the next non-dropped dimension.
+    // If this is the last set dimension, also group all subsequent dropped
+    // dimension, if any.
+    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+    nextDimToGroup = setBit + 1;
+  }
+  return b.create<tensor::CollapseShapeOp>(loc, src, reassocMaps);
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
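
The `setBit == lastSetBit` case handles trailing dropped dimensions: the last kept dimension's group is extended to `rank - 1` so it absorbs them. A worked example with hypothetical shapes:

```mlir
// dropDims = {2} on tensor<6x2x1xf32>: kept dims are 0 and 1; dim 1 is the
// last kept dim, so its group extends to rank - 1, giving [[0], [1, 2]].
func.func @trailing_unit_dim_sketch(%src: tensor<6x2x1xf32>) -> tensor<6x2xf32> {
  %collapsed = tensor.collapse_shape %src [[0], [1, 2]]
      : tensor<6x2x1xf32> into tensor<6x2xf32>
  return %collapsed : tensor<6x2xf32>
}
```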

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 78 additions & 0 deletions

@@ -318,3 +318,81 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+    %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+    %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+    %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %mmul_prod = linalg.generic
+      {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract slice with rank-reduced result type. When fused in the loop
+    // with sliced operands, the producer linalg must have its now sliced
+    // result be rank-reduced as well to match consumer's use type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %mmul_cons = linalg.generic
+        {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %for : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+
+// For loop right after tensor alloc & fill, no linalg.generic.
+// CHECK-NOT: linalg.generic
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
+
+// Producer linalg.generic now inside the loop, with tiled args sliced before
+// it.
+// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
+//
+// Consumer uses a rank-reduced version of producer result so a collapse_shape
+// is generated.
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
