Skip to content

Commit 9d07ee7

Browse files
committed
[mlir][Vector] Make vector.contract work with scalable vectors
This is just a small fix that makes sure that `vector.contract` works with scalable vectors. Rather than duplicating all the roundtrip tests for vector.contract, I'm treating scalable vectors as an edge case and just adding a couple of test cases to verify that this works.
1 parent 64366d4 commit 9d07ee7

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,8 @@ static LogicalResult verifyOutputShape(
820820
return e.cast<AffineConstantExpr>().getValue();
821821
}));
822822
auto expected =
823-
VectorType::get(expectedShape, resVectorType.getElementType());
823+
VectorType::get(expectedShape, resVectorType.getElementType(),
824+
resVectorType.getScalableDims());
824825
if (resVectorType != expected || accVectorType != expected)
825826
return op.emitOpError(
826827
"invalid accumulator/result vector shape, expected: ")

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,17 @@ func.func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -
307307
return %0 : f32
308308
}
309309

310+
// CHECK-LABEL: @contraction_to_scalar_scalable
311+
func.func @contraction_to_scalar_scalable(%arg0: vector<[10]xf32>, %arg1: vector<[10]xf32>) -> f32 {
312+
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
313+
%f0 = arith.constant 0.0: f32
314+
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<[10]xf32>, vector<[10]xf32> into f32
315+
%0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
316+
: vector<[10]xf32>, vector<[10]xf32> into f32
317+
// CHECK: return %[[X]] : f32
318+
return %0 : f32
319+
}
320+
310321
// CHECK-LABEL: @contraction_extra_attrs
311322
func.func @contraction_extra_attrs(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
312323
// CHECK: %[[C0:.*]] = arith.constant 0.000000e+00 : f32
@@ -392,6 +403,24 @@ func.func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf3
392403
return
393404
}
394405

406+
#contraction_matmul_accesses = [
407+
affine_map<(d0, d1, d2) -> (d0, d2)>,
408+
affine_map<(d0, d1, d2) -> (d2, d1)>,
409+
affine_map<(d0, d1, d2) -> (d0, d1)>
410+
]
411+
#contraction_matmul_trait = {
412+
indexing_maps = #contraction_matmul_accesses,
413+
iterator_types = ["parallel", "parallel", "reduction"]
414+
}
415+
// CHECK-LABEL: @contraction_matmul_scalable
416+
func.func @contraction_matmul_scalable(%A: vector<8x1xf32>, %B: vector<1x[32]xf32>, %C: vector<8x[32]xf32>) -> vector<8x[32]xf32> {
417+
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
418+
%res = vector.contract #contraction_matmul_trait %A, %B, %C
419+
: vector<8x1xf32>, vector<1x[32]xf32> into vector<8x[32]xf32>
420+
// CHECK: return %[[X]] : vector<8x[32]xf32>
421+
return %res : vector<8x[32]xf32>
422+
}
423+
395424
// CHECK-LABEL: @create_vector_mask
396425
func.func @create_vector_mask() {
397426
// CHECK: %[[C2:.*]] = arith.constant 2 : index

0 commit comments

Comments
 (0)