Commit 6942f1d

[MLIR][Linalg] Scalable Vectorization of Reduction on the Trailing Dimension (#97788)
Allow scalable vectorization of linalg.reduce and of linalg.generic ops that have reduction iterator(s), with two restrictions:

1. The reduction dim is the last (innermost) dim of the op; and
2. Only the reduction dim is requested for scalable vectorization.

One exception: scalable vectorization of the reduction dim in Matmul-like ops is not supported, even when the above restrictions are met.

Allowed combinations of iterator types and scalable flags:

Matmul:
  Iterators:      ["parallel", "parallel", "reduction"]
  Scalable Flags: ["true", "true", "false"]
                  ["false", "true", "false"]

Matvec:
  Iterators:      ["parallel", "reduction"]
  Scalable Flags: ["false", "true"]
                  ["true", "false"]
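For illustration, the simplest newly-supported case is a 1-D dynamic reduction, whose only (and therefore trailing) dim is a reduction dim. A minimal sketch, mirroring the tests added in vectorization-scalable.mlir below (function and value names here are illustrative only):

func.func @reduce_1d(%in: tensor<?xf32>, %acc: tensor<f32>) -> tensor<f32> {
  // The single dim is a reduction dim, so it may be vectorized scalably.
  %0 = linalg.reduce ins(%in : tensor<?xf32>) outs(%acc : tensor<f32>) dimensions = [0]
    (%a: f32, %b: f32) {
      %s = arith.addf %a, %b : f32
      linalg.yield %s : f32
    }
  return %0 : tensor<f32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %r = transform.structured.match ops{["linalg.reduce"]} in %root : (!transform.any_op) -> !transform.any_op
    // [[4]] requests a scalable (vscale x 4) size for the reduction dim.
    transform.structured.vectorize %r vector_sizes [[4]] : !transform.any_op
    transform.yield
  }
}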
1 parent 2ca300f commit 6942f1d

File tree

5 files changed: +640 -19 lines changed


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 47 additions & 8 deletions
@@ -586,6 +586,14 @@ static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
      llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}

+/// Check if `op` is a linalg.reduce or a linalg.generic that has at least one
+/// reduction iterator.
+static bool hasReductionIterator(LinalgOp &op) {
+  return isa<linalg::ReduceOp>(op) ||
+         (isa<linalg::GenericOp>(op) &&
+          llvm::any_of(op.getIteratorTypesArray(), isReductionIterator));
+}
+
/// Build a vector.transfer_write of `value` into `outputOperand` at indices set
/// to all `0`; where `outputOperand` is an output operand of the LinalgOp
/// currently being vectorized. If `dest` has null rank, build an memref.store.
@@ -1787,6 +1795,9 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
  if (isa<ConvolutionOpInterface>(op.getOperation()))
    return vectorizeDynamicConvOpPrecondition(op, flatten1DDepthwiseConv);

+  if (hasReductionIterator(op))
+    return reductionPreconditions(op);
+
  // TODO: Masking only supports dynamic element-wise ops, linalg.generic ops,
  // linalg.copy ops and ops that implement ContractionOpInterface for now.
  if (!isElementwise(op) &&
@@ -1976,6 +1987,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
  // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
  // 2. exactly 2 dims are scalable and those are the _last two adjacent_
  //    parallel dims
+  // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
  // The 2nd restriction above means that only Matmul-like Ops are supported
  // when 2 dims are scalable, e.g. :
  //  * iterators = [parallel, parallel, reduction]
@@ -1992,19 +2004,45 @@ vectorizeScalableVectorPrecondition(Operation *op,
    scalableFlags.pop_back();
  }

-  // TODO: Support scalable vectorisation for reduction dims
-  if (iterators.back() == utils::IteratorType::reduction)
-    return failure();
-
-  // If this is not the _last_ parallel dim, 1. above is not met
-  if (seenParalell)
-    return failure();
+  switch (iterators.back()) {
+  case utils::IteratorType::reduction: {
+    // Check 3. above is met.
+    if (iterators.size() != inputVectorSizes.size()) {
+      LDBG("Non-trailing reduction dim requested for scalable "
+           "vectorization\n");
+      return failure();
+    }
+    if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+      LDBG("Scalable vectorization of the reduction dim in Matmul-like ops "
+           "is not supported\n");
+      return failure();
+    }
+    break;
+  }
+  case utils::IteratorType::parallel: {
+    // Check 1. and 2. above are met.
+    if (seenParalell) {
+      LDBG("Inner parallel dim not requested for scalable "
+           "vectorization\n");
+      return failure();
+    }
+    break;
+  }
+  }

  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
  // supported for which expect the folowing config:
  //   * iterators = [parallel, parallel, reduction]
  //   * scalable flags = [true, true, false]
  if (numOfScalableDims == 2) {
+    // Disallow below case which breaks 3. above:
+    //   * iterators = [..., parallel, reduction]
+    //   * scalable flags = [..., true, true]
+    if (iterators.back() == utils::IteratorType::reduction) {
+      LDBG("Higher dim than the trailing reduction dim requested for scalable "
+           "vectorization\n");
+      return failure();
+    }
    scalableFlags.pop_back();
    iterators.pop_back();

@@ -2017,7 +2055,8 @@ vectorizeScalableVectorPrecondition(Operation *op,
  //    presence of scalable vectors
  return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
                 isa<linalg::MatmulTransposeAOp>(op) ||
-                 isa<linalg::DepthwiseConv1DNwcWcOp>(op));
+                 isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
+                 isa<linalg::MatvecOp>(op) || hasReductionIterator(linalgOp));
}

LogicalResult mlir::linalg::vectorizeOpPrecondition(
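To make the new checks concrete: for a 2-D op with iterators = ["parallel", "reduction"], such as the linalg.generic in the tests below, only the trailing reduction dim may be scalable. A sketch of one accepted and one rejected request, assuming %0 matches that op:

// Accepted: scalable flags = [false, true]; only the trailing reduction dim
// is scalable.
transform.structured.vectorize %0 vector_sizes [4, [8]] : !transform.any_op

// Rejected by the numOfScalableDims == 2 check above ("Higher dim than the
// trailing reduction dim requested for scalable vectorization"):
// scalable flags = [true, true].
transform.structured.vectorize %0 vector_sizes [[4], [8]] : !transform.any_op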

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 165 additions & 0 deletions
@@ -189,3 +189,168 @@ module attributes {transform.with_named_sequence} {
    transform.yield
  }
}
+
+// -----
+
+func.func @vectorize_dynamic_reduction_scalable_1d(%arg0: tensor<?xf32>,
+                                                   %arg1: tensor<f32>) -> tensor<f32> {
+
+  %0 = linalg.reduce ins(%arg0 : tensor<?xf32>) outs(%arg1 : tensor<f32>) dimensions = [0]
+    (%in: f32, %init: f32) {
+      %0 = arith.addf %in, %init : f32
+      linalg.yield %0 : f32
+    }
+  return %0 : tensor<f32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_scalable_1d(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<[4]xi1>
+// CHECK:   %[[VEC_RD_0:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK:   %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[VEC_RD_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[C0_F32]] : tensor<f32>, vector<f32>
+// CHECK:   %[[ACC_f32:.*]] = vector.extractelement %[[VEC_RD_1]][] : vector<f32>
+// CHECK:   %[[REDUCE:.*]] = vector.mask %[[MASK]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[ACC_f32]] [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+// CHECK:   %[[VEC_f32:.*]] = vector.broadcast %[[REDUCE]] : f32 to vector<f32>
+// CHECK:   %{{.*}} = vector.transfer_write %[[VEC_f32]], %[[ARG_1]][] : vector<f32>, tensor<f32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.reduce"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// Note: scalable version of `vectorize_dynamic_reduction` in test/Dialect/Linalg/vectorization.mlir.
+func.func @vectorize_dynamic_reduction_scalable_2d(%arg0: tensor<?x?xf32>,
+                                                   %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction_scalable_2d(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK:   %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<4x[8]xi1>
+// CHECK:   %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x[8]xf32> } : vector<4x[8]xi1> -> vector<4x[8]xf32>
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_1d:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<4xi1>
+// CHECK:   %[[VEC_RD_1:.*]] = vector.mask %[[MASK_1d]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:   %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[VEC_RD_0]], %[[VEC_RD_1]] [1] : vector<4x[8]xf32> to vector<4xf32> } : vector<4x[8]xi1> -> vector<4xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %{{.*}} = vector.mask %[[MASK_1d]] { vector.transfer_write %[[REDUCE]], %[[ARG_1]][%[[C0_idx]]] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, [8]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @vectorize_dynamic_matvec_trailing_reduction_dim(%arg0: tensor<?x?xf32>,
+                                                           %arg1: tensor<?xf32>,
+                                                           %arg2: tensor<?xf32>) {
+  linalg.matvec ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+                outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+  return
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_matvec_trailing_reduction_dim(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>) {
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK:   %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<4x[4]xi1>
+// CHECK:   %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x[4]xf32> } : vector<4x[4]xi1> -> vector<4x[4]xf32>
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_d1:.*]] = vector.create_mask %[[DIM_A0_1]] : vector<[4]xi1>
+// CHECK:   %[[VEC_RD_1:.*]] = vector.mask %[[MASK_d1]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true], permutation_map = #map} : tensor<?xf32>, vector<4x[4]xf32> } : vector<[4]xi1> -> vector<4x[4]xf32>
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_d2:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<4xi1>
+// CHECK:   %[[VEC_RD_2:.*]] = vector.mask %[[MASK_d2]] { vector.transfer_read %[[ARG_2]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:   %[[MUL:.*]] = arith.mulf %[[VEC_RD_0:.*]], %[[VEC_RD_1:.*]] : vector<4x[4]xf32>
+// CHECK:   %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_RD_2]] [1] : vector<4x[4]xf32> to vector<4xf32> } : vector<4x[4]xi1> -> vector<4xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %{{.*}} = vector.mask %[[MASK_d2]] { vector.transfer_write %[[REDUCE]], %[[ARG_2]][%[[C0_idx]]] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [4, [4]] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @vectorize_dynamic_generic_matvec_leading_parallel_dim(%arg0: tensor<?x?xf32>,
+                                                                 %arg1: tensor<?xf32>,
+                                                                 %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%mat: f32, %vec: f32, %res: f32) :
+      %0 = arith.mulf %mat, %vec : f32
+      %1 = arith.addf %res, %0 : f32
+      linalg.yield %1 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_generic_matvec_leading_parallel_dim(
+// CHECK-SAME:    %[[ARG_0:.*]]: tensor<?x?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[DIM_A0_0:.*]] = tensor.dim %[[ARG_0]], %[[C0_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C1_idx:.*]] = arith.constant 1 : index
+// CHECK:   %[[DIM_A0_1:.*]] = tensor.dim %[[ARG_0]], %[[C1_idx]] : tensor<?x?xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_2d:.*]] = vector.create_mask %[[DIM_A0_0]], %[[DIM_A0_1]] : vector<[4]x4xi1>
+// CHECK:   %[[VEC_RD_0:.*]] = vector.mask %[[MASK_2d]] { vector.transfer_read %[[ARG_0]][%[[C0_idx]], %[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<[4]x4xf32> } : vector<[4]x4xi1> -> vector<[4]x4xf32>
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_d1:.*]] = vector.create_mask %[[DIM_A0_1]] : vector<4xi1>
+// CHECK:   %[[VEC_RD_1:.*]] = vector.mask %[[MASK_d1]] { vector.transfer_read %[[ARG_1]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true, true], permutation_map = #map} : tensor<?xf32>, vector<[4]x4xf32> } : vector<4xi1> -> vector<[4]x4xf32>
+// CHECK:   %[[C0_f32:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[MASK_d2:.*]] = vector.create_mask %[[DIM_A0_0]] : vector<[4]xi1>
+// CHECK:   %[[VEC_RD_2:.*]] = vector.mask %[[MASK_d2]] { vector.transfer_read %[[ARG_2]][%[[C0_idx]]], %[[C0_f32]] {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
+// CHECK:   %[[MUL:.*]] = arith.mulf %[[VEC_RD_0:.*]], %[[VEC_RD_1:.*]] : vector<[4]x4xf32>
+// CHECK:   %[[REDUCE:.*]] = vector.mask %[[MASK_2d]] { vector.multi_reduction <add>, %[[MUL]], %[[VEC_RD_2]] [1] : vector<[4]x4xf32> to vector<[4]xf32> } : vector<[4]x4xi1> -> vector<[4]xf32>
+// CHECK:   %[[C0_idx:.*]] = arith.constant 0 : index
+// CHECK:   %{{.*}} = vector.mask %[[MASK_d2]] { vector.transfer_write %[[REDUCE]], %[[ARG_2]][%[[C0_idx]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [[4], 4] : !transform.any_op
+    transform.yield
+  }
+}
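For contrast with the accepted Matvec configurations above, a sketch of the exception path: requesting a scalable trailing reduction dim on a Matmul-like op is rejected, since scalable flags ["false", "false", "true"] are not among the allowed combinations (the payload is assumed to contain a linalg.matmul):

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %mm = transform.structured.match ops{["linalg.matmul"]} in %root : (!transform.any_op) -> !transform.any_op
    // Fails the precondition: "Scalable vectorization of the reduction dim
    // in Matmul-like ops is not supported".
    transform.structured.vectorize %mm vector_sizes [4, 4, [4]] : !transform.any_op
    transform.yield
  }
}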
