Skip to content

Commit a509a18

Browse files
authored
[mlir][vector] proper masking support for contract lowering (#67145)
Support all known permutations when lowering masked vector.contract to vector.outerproduct, and not just the canonical permutation.
1 parent a433592 commit a509a18

File tree

2 files changed

+195
-19
lines changed

2 files changed

+195
-19
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 27 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -456,61 +456,69 @@ struct UnrolledOuterProductGenerator
456456
// Set up the parallel/reduction structure in the right form.
457457
AffineExpr m, n, k;
458458
bindDims(rewriter.getContext(), m, n, k);
459+
Value transposedMask = t(mask, {2, 0, 1});
459460
// Classical row-major matmul: Just permute the lhs.
460461
if (layout({{m, k}, {k, n}, {m, n}}))
461-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
462-
t(mask, {2, 0, 1}));
462+
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
463463
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
464464
if (layout({{m, k}, {n, k}, {m, n}})) {
465465
Value tlhs = t(lhs);
466-
return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1));
466+
return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1),
467+
transposedMask);
467468
}
468469
// No need to permute anything.
469470
if (layout({{k, m}, {k, n}, {m, n}}))
470-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
471+
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
471472
// Just permute the rhs.
472473
if (layout({{k, m}, {n, k}, {m, n}}))
473-
return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0));
474+
return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0), transposedMask);
474475
// Transposed output: swap RHS and LHS.
475476
// Classical row-major matmul: permute the lhs.
476477
if (layout({{m, k}, {k, n}, {n, m}}))
477-
return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1));
478+
return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1), transposedMask);
478479
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
479480
if (layout({{m, k}, {n, k}, {n, m}})) {
480481
Value trhs = t(rhs);
481-
return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1));
482+
return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1),
483+
transposedMask);
482484
}
483485
if (layout({{k, m}, {k, n}, {n, m}}))
484-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
486+
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
485487
if (layout({{k, m}, {n, k}, {n, m}}))
486-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
488+
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
487489
return failure();
488490
}
489491

490-
/// One outer parallel, one inner reduction (matvec flavor)
492+
//
493+
// One outer parallel, one inner reduction (matvec flavor).
494+
// Mask needs to be transposed everywhere to turn the reduction dimension
495+
// outermost as required by outerproduct.
496+
//
491497
FailureOr<Value> matvec() {
492498
if (!iters({Par(), Red()}))
493499
return failure();
494500
AffineExpr m, k;
495501
bindDims(rewriter.getContext(), m, k);
502+
Value transposedMask = t(mask);
496503

497504
// Case mat-vec: transpose.
498505
if (layout({{m, k}, {k}, {m}}))
499-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
506+
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), transposedMask);
500507
// Case mat-trans-vec: ready to go.
501508
if (layout({{k, m}, {k}, {m}}))
502-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
509+
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), transposedMask);
503510
// Case vec-mat: swap and transpose.
504511
if (layout({{k}, {m, k}, {m}}))
505-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
512+
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), transposedMask);
506513
// Case vec-mat-trans: swap and ready to go.
507514
if (layout({{k}, {k, m}, {m}}))
508-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
515+
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), transposedMask);
509516
return failure();
510517
}
511518

512519
//
513-
// One outer reduction, one inner parallel (tmatvec flavor)
520+
// One outer reduction, one inner parallel (tmatvec flavor).
521+
// Mask already has the shape of the outer product.
514522
//
515523
FailureOr<Value> tmatvec() {
516524
if (!iters({Red(), Par()}))
@@ -520,16 +528,16 @@ struct UnrolledOuterProductGenerator
520528

521529
// Case mat-vec: transpose.
522530
if (layout({{m, k}, {k}, {m}}))
523-
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
531+
return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), mask);
524532
// Case mat-trans-vec: ready to go.
525533
if (layout({{k, m}, {k}, {m}}))
526-
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
534+
return outerProd(lhs, rhs, res, lhsType.getDimSize(0), mask);
527535
// Case vec-mat: swap and transpose.
528536
if (layout({{k}, {m, k}, {m}}))
529-
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0));
537+
return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0), mask);
530538
// Case vec-mat-trans: swap and ready to go.
531539
if (layout({{k}, {k, m}, {m}}))
532-
return outerProd(rhs, lhs, res, lhsType.getDimSize(0));
540+
return outerProd(rhs, lhs, res, lhsType.getDimSize(0), mask);
533541
return failure();
534542
}
535543

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Lines changed: 168 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -341,6 +341,174 @@ func.func @matmul_7(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
341341
return %0 : vector<3x2xf32>
342342
}
343343

344+
// CHECK-LABEL: @masked_matvec_mk_k_m
345+
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
346+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
347+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
348+
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
349+
func.func @masked_matvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
350+
// CHECK: vector.transpose %[[MASK]]
351+
// CHECK: vector.transpose %[[MAT]]
352+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
353+
%res = vector.mask %mask {
354+
vector.contract {
355+
indexing_maps = [affine_map<(m, k) -> (m, k)>,
356+
affine_map<(m, k) -> (k)>,
357+
affine_map<(m, k) -> (m)>],
358+
iterator_types = ["parallel", "reduction"],
359+
kind = #vector.kind<add>
360+
} %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
361+
} : vector<4x2xi1> -> vector<4xf32>
362+
return %res : vector<4xf32>
363+
}
364+
365+
// CHECK-LABEL: @masked_matvec_km_k_m
366+
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
367+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
368+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
369+
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
370+
func.func @masked_matvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
371+
// CHECK: vector.transpose %[[MASK]]
372+
// CHECK-NOT: vector.transpose %[[MAT]]
373+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
374+
%res = vector.mask %mask {
375+
vector.contract {
376+
indexing_maps = [affine_map<(m, k) -> (k, m)>,
377+
affine_map<(m, k) -> (k)>,
378+
affine_map<(m, k) -> (m)>],
379+
iterator_types = ["parallel", "reduction"],
380+
kind = #vector.kind<add>
381+
} %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
382+
} : vector<4x2xi1> -> vector<4xf32>
383+
return %res : vector<4xf32>
384+
}
385+
386+
// CHECK-LABEL: @masked_matvec_k_mk_m
387+
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
388+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
389+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
390+
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
391+
func.func @masked_matvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
392+
// CHECK: vector.transpose %[[MASK]]
393+
// CHECK: vector.transpose %[[MAT]]
394+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
395+
%res = vector.mask %mask {
396+
vector.contract {
397+
indexing_maps = [affine_map<(m, k) -> (k)>,
398+
affine_map<(m, k) -> (m, k)>,
399+
affine_map<(m, k) -> (m)>],
400+
iterator_types = ["parallel", "reduction"],
401+
kind = #vector.kind<add>
402+
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
403+
} : vector<4x2xi1> -> vector<4xf32>
404+
return %res : vector<4xf32>
405+
}
406+
407+
// CHECK-LABEL: @masked_matvec_k_km_m
408+
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
409+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
410+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
411+
// CHECK-SAME: %[[MASK:.+]]: vector<4x2xi1>
412+
func.func @masked_matvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<4x2xi1>) -> vector<4xf32> {
413+
// CHECK: vector.transpose %[[MASK]]
414+
// CHECK-NOT: vector.transpose %[[MAT]]
415+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
416+
%res = vector.mask %mask {
417+
vector.contract {
418+
indexing_maps = [affine_map<(m, k) -> (k)>,
419+
affine_map<(m, k) -> (k, m)>,
420+
affine_map<(m, k) -> (m)>],
421+
iterator_types = ["parallel", "reduction"],
422+
kind = #vector.kind<add>
423+
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
424+
} : vector<4x2xi1> -> vector<4xf32>
425+
return %res : vector<4xf32>
426+
}
427+
428+
// CHECK-LABEL: @masked_tmatvec_mk_k_m
429+
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
430+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
431+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
432+
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
433+
func.func @masked_tmatvec_mk_k_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
434+
// CHECK: vector.transpose %[[MAT]]
435+
// CHECK-NOT: vector.transpose %[[MASK]]
436+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
437+
%res = vector.mask %mask {
438+
vector.contract {
439+
indexing_maps = [affine_map<(k, m) -> (m, k)>,
440+
affine_map<(k, m) -> (k)>,
441+
affine_map<(k, m) -> (m)>],
442+
iterator_types = ["reduction", "parallel"],
443+
kind = #vector.kind<add>
444+
} %arg0, %arg1, %arg2 : vector<4x2xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
445+
} : vector<2x4xi1> -> vector<4xf32>
446+
return %res : vector<4xf32>
447+
}
448+
449+
// CHECK-LABEL: @masked_tmatvec_km_k_m
450+
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
451+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
452+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
453+
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
454+
func.func @masked_tmatvec_km_k_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
455+
// CHECK-NOT: vector.transpose %[[MAT]]
456+
// CHECK-NOT: vector.transpose %[[MASK]]
457+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
458+
%res = vector.mask %mask {
459+
vector.contract {
460+
indexing_maps = [affine_map<(k, m) -> (k, m)>,
461+
affine_map<(k, m) -> (k)>,
462+
affine_map<(k, m) -> (m)>],
463+
iterator_types = ["reduction", "parallel"],
464+
kind = #vector.kind<add>
465+
} %arg0, %arg1, %arg2 : vector<2x4xf32>, vector<2xf32>, vector<4xf32> into vector<4xf32>
466+
} : vector<2x4xi1> -> vector<4xf32>
467+
return %res : vector<4xf32>
468+
}
469+
470+
// CHECK-LABEL: @masked_tmatvec_k_mk_m
471+
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
472+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
473+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
474+
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
475+
func.func @masked_tmatvec_k_mk_m(%arg0: vector<4x2xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
476+
// CHECK: vector.transpose %[[MAT]]
477+
// CHECK-NOT: vector.transpose %[[MASK]]
478+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
479+
%res = vector.mask %mask {
480+
vector.contract {
481+
indexing_maps = [affine_map<(k, m) -> (k)>,
482+
affine_map<(k, m) -> (m, k)>,
483+
affine_map<(k, m) -> (m)>],
484+
iterator_types = ["reduction", "parallel"],
485+
kind = #vector.kind<add>
486+
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<4x2xf32>, vector<4xf32> into vector<4xf32>
487+
} : vector<2x4xi1> -> vector<4xf32>
488+
return %res : vector<4xf32>
489+
}
490+
491+
// CHECK-LABEL: @masked_tmatvec_k_km_m
492+
// CHECK-SAME: %[[MAT:.+]]: vector<2x4xf32>
493+
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
494+
// CHECK-SAME: %[[INIT:.+]]: vector<4xf32>
495+
// CHECK-SAME: %[[MASK:.+]]: vector<2x4xi1>
496+
func.func @masked_tmatvec_k_km_m(%arg0: vector<2x4xf32>, %arg1: vector<2xf32>, %arg2: vector<4xf32>, %mask: vector<2x4xi1>) -> vector<4xf32> {
497+
// CHECK-NOT: vector.transpose %[[MAT]]
498+
// CHECK-NOT: vector.transpose %[[MASK]]
499+
// CHECK-COUNT-2: vector.mask %{{.*}} { vector.outerproduct %{{.*}}, %{{.*}}, %{{.*}} {kind = #vector.kind<add>} : vector<4xf32>, f32 }
500+
%res = vector.mask %mask {
501+
vector.contract {
502+
indexing_maps = [affine_map<(k, m) -> (k)>,
503+
affine_map<(k, m) -> (k, m)>,
504+
affine_map<(k, m) -> (m)>],
505+
iterator_types = ["reduction", "parallel"],
506+
kind = #vector.kind<add>
507+
} %arg1, %arg0, %arg2 : vector<2xf32>, vector<2x4xf32>, vector<4xf32> into vector<4xf32>
508+
} : vector<2x4xi1> -> vector<4xf32>
509+
return %res : vector<4xf32>
510+
}
511+
344512

345513
transform.sequence failures(propagate) {
346514
^bb1(%module_op: !transform.any_op):

0 commit comments

Comments
 (0)