Skip to content

Commit cf402a1

Browse files
committed
[mlir][vector] Add unit test for vector distribute by block
When distributing a vector larger than the given multiplicity, we can distribute it by block, where each id gets a chunk of consecutive elements along the distributed dimension. This adds a test for this case and adds extra checks to make sure we don't distribute in cases where the vector size is not a multiple of the multiplicity. Differential Revision: https://reviews.llvm.org/D89061
1 parent afff74e commit cf402a1

File tree

3 files changed

+61
-9
lines changed

3 files changed

+61
-9
lines changed

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2444,7 +2444,14 @@ mlir::vector::distributPointwiseVectorOp(OpBuilder &builder, Operation *op,
24442444
OpBuilder::InsertionGuard guard(builder);
24452445
builder.setInsertionPointAfter(op);
24462446
Location loc = op->getLoc();
2447+
if (op->getNumResults() != 1)
2448+
return {};
24472449
Value result = op->getResult(0);
2450+
VectorType type = op->getResult(0).getType().dyn_cast<VectorType>();
2451+
// Currently only support distributing 1-D vectors of size multiple of the
2452+
// given multiplicity. To handle more sizes we would need to support masking.
2453+
if (!type || type.getRank() != 1 || type.getNumElements() % multiplicity != 0)
2454+
return {};
24482455
DistributeOps ops;
24492456
ops.extract =
24502457
builder.create<vector::ExtractMapOp>(loc, result, id, multiplicity);

mlir/test/Dialect/Vector/vector-distribution.mlir

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -test-vector-distribute-patterns | FileCheck %s
1+
// RUN: mlir-opt %s -test-vector-distribute-patterns=distribution-multiplicity=32 | FileCheck %s
22

33
// CHECK-LABEL: func @distribute_vector_add
44
// CHECK-SAME: (%[[ID:.*]]: index
@@ -14,12 +14,12 @@ func @distribute_vector_add(%id : index, %A: vector<32xf32>, %B: vector<32xf32>)
1414

1515
// CHECK-LABEL: func @vector_add_read_write
1616
// CHECK-SAME: (%[[ID:.*]]: index
17-
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
18-
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
17+
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
18+
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
1919
// CHECK-NEXT: %[[ADD1:.*]] = addf %[[EXA]], %[[EXB]] : vector<1xf32>
20-
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : memref<32xf32>, vector<1xf32>
20+
// CHECK-NEXT: %[[EXC:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<32xf32>, vector<1xf32>
2121
// CHECK-NEXT: %[[ADD2:.*]] = addf %[[ADD1]], %[[EXC]] : vector<1xf32>
22-
// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%{{.*}}] : vector<1xf32>, memref<32xf32>
22+
// CHECK-NEXT: vector.transfer_write %[[ADD2]], %{{.*}}[%[[ID]]] : vector<1xf32>, memref<32xf32>
2323
// CHECK-NEXT: return
2424
func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>, %C: memref<32xf32>, %D: memref<32xf32>) {
2525
%c0 = constant 0 : index
@@ -32,3 +32,41 @@ func @vector_add_read_write(%id : index, %A: memref<32xf32>, %B: memref<32xf32>,
3232
vector.transfer_write %d, %D[%c0]: vector<32xf32>, memref<32xf32>
3333
return
3434
}
35+
36+
// CHECK-LABEL: func @vector_add_cycle
37+
// CHECK-SAME: (%[[ID:.*]]: index
38+
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
39+
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[ID]]], %{{.*}} : memref<64xf32>, vector<2xf32>
40+
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<2xf32>
41+
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[ID]]] : vector<2xf32>, memref<64xf32>
42+
// CHECK-NEXT: return
43+
func @vector_add_cycle(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
44+
%c0 = constant 0 : index
45+
%cf0 = constant 0.0 : f32
46+
%a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<64xf32>
47+
%b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<64xf32>
48+
%acc = addf %a, %b: vector<64xf32>
49+
vector.transfer_write %acc, %C[%c0]: vector<64xf32>, memref<64xf32>
50+
return
51+
}
52+
53+
// Negative test to make sure nothing is done in case the vector size is not a
54+
// multiple of multiplicity.
55+
// CHECK-LABEL: func @vector_negative_test
56+
// CHECK: %[[C0:.*]] = constant 0 : index
57+
// CHECK: %[[EXA:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
58+
// CHECK-NEXT: %[[EXB:.*]] = vector.transfer_read %{{.*}}[%[[C0]]], %{{.*}} : memref<64xf32>, vector<16xf32>
59+
// CHECK-NEXT: %[[ADD:.*]] = addf %[[EXA]], %[[EXB]] : vector<16xf32>
60+
// CHECK-NEXT: vector.transfer_write %[[ADD]], %{{.*}}[%[[C0]]] {{.*}} : vector<16xf32>, memref<64xf32>
61+
// CHECK-NEXT: return
62+
func @vector_negative_test(%id : index, %A: memref<64xf32>, %B: memref<64xf32>, %C: memref<64xf32>) {
63+
%c0 = constant 0 : index
64+
%cf0 = constant 0.0 : f32
65+
%a = vector.transfer_read %A[%c0], %cf0: memref<64xf32>, vector<16xf32>
66+
%b = vector.transfer_read %B[%c0], %cf0: memref<64xf32>, vector<16xf32>
67+
%acc = addf %a, %b: vector<16xf32>
68+
vector.transfer_write %acc, %C[%c0]: vector<16xf32>, memref<64xf32>
69+
return
70+
}
71+
72+

mlir/test/lib/Transforms/TestVectorTransforms.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,28 @@ struct TestVectorUnrollingPatterns
127127

128128
struct TestVectorDistributePatterns
129129
: public PassWrapper<TestVectorDistributePatterns, FunctionPass> {
130+
TestVectorDistributePatterns() = default;
131+
TestVectorDistributePatterns(const TestVectorDistributePatterns &pass) {}
130132
void getDependentDialects(DialectRegistry &registry) const override {
131133
registry.insert<VectorDialect>();
132134
registry.insert<AffineDialect>();
133135
}
136+
Option<int32_t> multiplicity{
137+
*this, "distribution-multiplicity",
138+
llvm::cl::desc("Set the multiplicity used for distributing vector"),
139+
llvm::cl::init(32)};
134140
void runOnFunction() override {
135141
MLIRContext *ctx = &getContext();
136142
OwningRewritePatternList patterns;
137143
FuncOp func = getFunction();
138144
func.walk([&](AddFOp op) {
139145
OpBuilder builder(op);
140146
Optional<mlir::vector::DistributeOps> ops = distributPointwiseVectorOp(
141-
builder, op.getOperation(), func.getArgument(0), 32);
142-
assert(ops.hasValue());
143-
SmallPtrSet<Operation *, 1> extractOp({ops->extract});
144-
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
147+
builder, op.getOperation(), func.getArgument(0), multiplicity);
148+
if (ops.hasValue()) {
149+
SmallPtrSet<Operation *, 1> extractOp({ops->extract});
150+
op.getResult().replaceAllUsesExcept(ops->insert.getResult(), extractOp);
151+
}
145152
});
146153
patterns.insert<PointwiseExtractPattern>(ctx);
147154
populateVectorToVectorTransformationPatterns(patterns, ctx);

0 commit comments

Comments
 (0)