Skip to content

Commit c91d3b0

Browse files
committed
[mlir][vector] Constrain patterns: vector.contract -> vector.outerproduct
This patch constrains the patterns for converting `vector.contract` to `vector.outerproduct` so that * the reduction dimension is _not unrolled_ if the corresponding dimension is scalable. This is necessary as the current lowering is incorrect for scalable dims. Indeed, the following unrolling for `vector.contract` would be invalid if the corresponding dimension was scalable (K is the size of the reduction dimension): ``` // K times. This is valid if K _is not_ scalable. %lhs = vector.extract %LHS[0] %rhs = vector.extract %RHS[0] vector.outerproduct %lhs, %rhs %lhs = vector.extract %LHS[1] %rhs = vector.extract %RHS[1] vector.outerproduct %lhs, %rhs // ... ``` Instead, a `for` loop should be generated: ``` // This would be valid regardless of whether K is scalable or not scf.for %k = 0 to K step 1 %lhs = vector.extract LHS[%k] %rhs = vector.extract RHS[%k] vector.outerproduct %lhs, %rhs ``` However, the lowering of: * `vector.extract` of vector slices with dynamic indices is incomplete and hence the implementation proposed above (with `scf.for`) wouldn't work just yet, i.e. it wouldn't be possible to lower it further. Instead, this patch disables unrolling in cases when the reduction dimension is scalable, i.e. where the generated code would be functionally incorrect. In order to document unsupported cases, a dedicated test file is added: * "vector-contract-to-outerproduct-transforms-unsupported.mlir" This is the first patch in a series of patches that strives to update these patterns (and to test them) for scalable vectors. Resolves #68400
1 parent 092ef55 commit c91d3b0

File tree

3 files changed

+107
-25
lines changed

3 files changed

+107
-25
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -424,9 +424,14 @@ struct UnrolledOuterProductGenerator
424424
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
425425
}
426426

427-
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
427+
FailureOr<Value> outerProd(Value lhs, Value rhs, Value res,
428+
VectorType lhsType, int reductionDim,
428429
std::optional<Value> maybeMask = std::nullopt) {
429-
assert(reductionSize > 0);
430+
// Unrolling a scalable dimension would be incorrect - bail out.
431+
if (lhsType.getScalableDims()[reductionDim])
432+
return failure();
433+
434+
int reductionSize = lhsType.getDimSize(reductionDim);
430435
// Incremental support for masking.
431436
if (mask && !maybeMask.has_value())
432437
return failure();
@@ -459,33 +464,39 @@ struct UnrolledOuterProductGenerator
459464
Value transposedMask = t(mask, {2, 0, 1});
460465
// Classical row-major matmul: Just permute the lhs.
461466
if (layout({{m, k}, {k, n}, {m, n}}))
462-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
467+
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
468+
transposedMask);
463469
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
464470
if (layout({{m, k}, {n, k}, {m, n}})) {
465471
Value tlhs = t(lhs);
466-
return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
472+
return outerProd(tlhs, t(rhs), res, lhsType, /*reductionDim=*/1,
467473
transposedMask);
468474
}
469475
// No need to permute anything.
470476
if (layout({{k, m}, {k, n}, {m, n}}))
471-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
477+
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
478+
transposedMask);
472479
// Just permute the rhs.
473480
if (layout({{k, m}, {n, k}, {m, n}}))
474-
return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
481+
return outerProd(lhs, t(rhs), res, lhsType, /*reductionDim=*/0,
482+
transposedMask);
475483
// Transposed output: swap RHS and LHS.
476484
// Classical row-major matmul: permute the lhs.
477485
if (layout({{m, k}, {k, n}, {n, m}}))
478-
return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
486+
return outerProd(rhs, t(lhs), res, lhsType, /*reductionDim=*/1,
487+
transposedMask);
479488
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
480489
if (layout({{m, k}, {n, k}, {n, m}})) {
481490
Value trhs = t(rhs);
482-
return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
491+
return outerProd(trhs, t(lhs), res, lhsType, /*reductionDim=*/1,
483492
transposedMask);
484493
}
485494
if (layout({{k, m}, {k, n}, {n, m}}))
486-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
495+
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
496+
transposedMask);
487497
if (layout({{k, m}, {n, k}, {n, m}}))
488-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
498+
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
499+
transposedMask);
489500
return failure();
490501
}
491502

@@ -503,16 +514,20 @@ struct UnrolledOuterProductGenerator
503514

504515
// Case mat-vec: transpose.
505516
if (layout({{m, k}, {k}, {m}}))
506-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
517+
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1,
518+
transposedMask);
507519
// Case mat-trans-vec: ready to go.
508520
if (layout({{k, m}, {k}, {m}}))
509-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
521+
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0,
522+
transposedMask);
510523
// Case vec-mat: swap and transpose.
511524
if (layout({{k}, {m, k}, {m}}))
512-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
525+
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0,
526+
transposedMask);
513527
// Case vec-mat-trans: swap and ready to go.
514528
if (layout({{k}, {k, m}, {m}}))
515-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
529+
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0,
530+
transposedMask);
516531
return failure();
517532
}
518533

@@ -528,16 +543,16 @@ struct UnrolledOuterProductGenerator
528543

529544
// Case mat-vec: transpose.
530545
if (layout({{m, k}, {k}, {m}}))
531-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
546+
return outerProd(t(lhs), rhs, res, lhsType, /*reductionDim=*/1, mask);
532547
// Case mat-trans-vec: ready to go.
533548
if (layout({{k, m}, {k}, {m}}))
534-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
549+
return outerProd(lhs, rhs, res, lhsType, /*reductionDim=*/0, mask);
535550
// Case vec-mat: swap and transpose.
536551
if (layout({{k}, {m, k}, {m}}))
537-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
552+
return outerProd(t(rhs), lhs, res, lhsType, /*reductionDim=*/0, mask);
538553
// Case vec-mat-trans: swap and ready to go.
539554
if (layout({{k}, {k, m}, {m}}))
540-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
555+
return outerProd(rhs, lhs, res, lhsType, /*reductionDim=*/0, mask);
541556
return failure();
542557
}
543558

@@ -980,9 +995,19 @@ FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
980995
diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex
981996
<< " to map to the same dimension";
982997
});
998+
if (lhsType.getScalableDims()[lhsIndex])
999+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1000+
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
1001+
<< ") is not supported yet";
1002+
});
9831003
dimSize = lhsType.getDimSize(lhsIndex);
9841004
} else if (rhsIndex >= 0) {
9851005
iterIndex = iMap[1].getDimPosition(rhsIndex);
1006+
if (rhsType.getScalableDims()[rhsIndex])
1007+
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
1008+
diag << "Unrolloing scalable dimension (lhsIndex=" << lhsIndex
1009+
<< ") is not supported yet";
1010+
});
9861011
dimSize = rhsType.getDimSize(rhsIndex);
9871012
}
9881013
if (iterIndex < 0)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics
2+
3+
#matvec_accesses = [
4+
affine_map<(i, j) -> (i, j)>,
5+
affine_map<(i, j) -> (j)>,
6+
affine_map<(i, j) -> (i)>
7+
]
8+
#matvec_trait = {
9+
indexing_maps = #matvec_accesses,
10+
iterator_types = ["parallel", "reduction"]
11+
}
12+
13+
// Unrolling scalable reduction dim is not supported - bail out
14+
15+
// expected-error@below {{greedy pattern application failed}}
16+
func.func @masked_extract_contract2_scalable_reduction_dim(%arg0: vector<[2]x[3]xf32>,
17+
%arg1: vector<[3]xf32>,
18+
%arg2: vector<[2]xf32>,
19+
%m: vector<[2]x[3]xi1>) -> vector<[2]xf32> {
20+
%0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
21+
: vector<[2]x[3]xf32>, vector<[3]xf32> into vector<[2]xf32> } : vector<[2]x[3]xi1> -> vector<[2]xf32>
22+
return %0 : vector<[2]xf32>
23+
}
24+
25+
transform.sequence failures(propagate) {
26+
^bb1(%module_op: !transform.any_op):
27+
%f = transform.structured.match ops{["func.func"]} in %module_op
28+
: (!transform.any_op) -> !transform.any_op
29+
30+
transform.apply_patterns to %f {
31+
transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
32+
} : !transform.any_op
33+
}

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@
3131
}
3232

3333
// CHECK-LABEL: func.func @masked_extract_contract2(
34-
// CHECK-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
35-
// CHECK-SAME: %[[VAL_1:.*]]: vector<3xf32>,
36-
// CHECK-SAME: %[[VAL_2:.*]]: vector<2xf32>,
37-
// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
34+
// CHECK-SAME: %{{.*}}: vector<2x3xf32>,
35+
// CHECK-SAME: %{{.*}}: vector<3xf32>,
36+
// CHECK-SAME: %{{.*}}: vector<2xf32>,
37+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
3838
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
3939
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<2xi1> from vector<3x2xi1>
40-
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct
40+
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
4141

4242
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<2xi1> from vector<3x2xi1>
43-
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct
43+
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
4444

4545
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<2xi1> from vector<3x2xi1>
46-
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct
46+
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
4747

4848
func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
4949
%arg1: vector<3xf32>,
@@ -54,6 +54,30 @@ func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
5454
return %0 : vector<2xf32>
5555
}
5656

57+
58+
// CHECK-LABEL: func.func @masked_extract_contract2_scalable_parallel_dim(
59+
// CHECK-SAME: %{{.*}}: vector<[2]x3xf32>,
60+
// CHECK-SAME: %{{.*}}: vector<3xf32>,
61+
// CHECK-SAME: %{{.*}}: vector<[2]xf32>,
62+
// CHECK-SAME: %[[IN_MASK:.*]]: vector<[2]x3xi1>) -> vector<[2]xf32>
63+
// CHECK: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<[2]x3xi1> to vector<3x[2]xi1>
64+
// CHECK: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<[2]xi1> from vector<3x[2]xi1>
65+
// CHECK: vector.mask %[[MASK0]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
66+
67+
// CHECK: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<[2]xi1> from vector<3x[2]xi1>
68+
// CHECK: vector.mask %[[MASK1]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
69+
70+
// CHECK: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<[2]xi1> from vector<3x[2]xi1>
71+
// CHECK: vector.mask %[[MASK2]] { vector.outerproduct {{.*}} {kind = #vector.kind<add>} : vector<[2]xf32>, f32 } : vector<[2]xi1> -> vector<[2]xf32>
72+
func.func @masked_extract_contract2_scalable_parallel_dim(%arg0: vector<[2]x3xf32>,
73+
%arg1: vector<3xf32>,
74+
%arg2: vector<[2]xf32>,
75+
%m: vector<[2]x3xi1>) -> vector<[2]xf32> {
76+
%0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
77+
: vector<[2]x3xf32>, vector<3xf32> into vector<[2]xf32> } : vector<[2]x3xi1> -> vector<[2]xf32>
78+
return %0 : vector<[2]xf32>
79+
}
80+
5781
// CHECK-LABEL: func.func @masked_extract_contract4(
5882
// CHECK-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
5983
// CHECK-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,

0 commit comments

Comments
 (0)