Skip to content

Commit 5620d3b

Browse files
banach-space authored and Groverkss committed
Revert "[mlir][Vector] Support 0-d vectors natively in TransferOpReduceRank (llvm#112907)"
This reverts commit 1004865. Failing CI as discussed here: * iree-org/iree#19135
1 parent 1466711 commit 5620d3b

File tree

4 files changed

+33
-11
lines changed

4 files changed

+33
-11
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,10 +358,31 @@ struct TransferOpReduceRank
358358
op, "map is not a minor identity with broadcasting");
359359
}
360360

361+
// TODO: support zero-dimension vectors natively. See:
362+
// https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097.
363+
// In the meantime, lower these to a scalar load when they pop up.
364+
if (reducedShapeRank == 0) {
365+
Value newRead;
366+
if (isa<TensorType>(op.getShapedType())) {
367+
newRead = rewriter.create<tensor::ExtractOp>(
368+
op.getLoc(), op.getSource(), op.getIndices());
369+
} else {
370+
newRead = rewriter.create<memref::LoadOp>(
371+
op.getLoc(), originalVecType.getElementType(), op.getSource(),
372+
op.getIndices());
373+
}
374+
return rewriter
375+
.create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead)
376+
.getVector();
377+
}
378+
361379
SmallVector<int64_t> newShape(
362380
originalVecType.getShape().take_back(reducedShapeRank));
363381
SmallVector<bool> newScalableDims(
364382
originalVecType.getScalableDims().take_back(reducedShapeRank));
383+
// Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering.
384+
if (newShape.empty())
385+
return rewriter.notifyMatchFailure(op, "rank-reduced vector is 0-d");
365386

366387
VectorType newReadType = VectorType::get(
367388
newShape, originalVecType.getElementType(), newScalableDims);

mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ func.func @transfer_read_within_async_execute(%A : memref<2x2xf32>) -> !async.to
503503

504504
// CHECK-LABEL: transfer_read_with_tensor
505505
func.func @transfer_read_with_tensor(%arg: tensor<f32>) -> vector<1xf32> {
506-
// CHECK: %[[EXTRACTED:.*]] = vector.transfer_read %{{.*}}[], %{{.*}} : tensor<f32>, vector<f32>
507-
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : vector<f32> to vector<1xf32>
506+
// CHECK: %[[EXTRACTED:.*]] = tensor.extract %{{.*}}[] : tensor<f32>
507+
// CHECK-NEXT: %[[RESULT:.*]] = vector.broadcast %[[EXTRACTED]] : f32 to vector<1xf32>
508508
// CHECK-NEXT: return %[[RESULT]] : vector<1xf32>
509509
%f0 = arith.constant 0.0 : f32
510510
%0 = vector.transfer_read %arg[], %f0 {permutation_map = affine_map<()->(0)>} :

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(
138138
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
139139
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
140140

141-
// Same as example above, but reading into a column tensor.
141+
// Same as example above, but reading into a column tensor. Note that after the
142+
// vectorization, the `TransferOpReduceRank` will replace
143+
// `vector.transfer_read` with `tensor.extract -> scalar`.
142144

143145
// TODO: Currently this fails to vectorise when the indices are non-constant.
144146

@@ -162,10 +164,9 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
162164
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
163165
// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
164166
// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
165-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
166-
// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
167-
// CHECK: %[[READ:.*]] = vector.transfer_read %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]], %[[CST_0]] : tensor<3x3x3xf32>, vector<f32>
168-
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<3x1x1xf32>
167+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
168+
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
169+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
169170
// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
170171
// CHECK: return %[[RES]] : tensor<3x1x1xf32>
171172

@@ -747,8 +748,8 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf
747748

748749
// CHECK-LABEL: func.func @vectorize_0d_tensor_extract(
749750
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>
750-
// CHECK: %[[EXTRACT:.*]] = vector.transfer_read %[[ARG_0]][], %{{.+}} : tensor<f32>
751-
// CHECK: vector.broadcast %[[EXTRACT]] : vector<f32> to vector<1x1x3xf32>
751+
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
752+
// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
752753

753754
module attributes {transform.with_named_sequence} {
754755
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {

mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf
2626
func.func @vector_transfer_ops_0d_tensor(%src: tensor<f32>) -> vector<1xf32> {
2727
%f0 = arith.constant 0.0 : f32
2828

29-
// CHECK: %[[S:.*]] = vector.transfer_read %[[SRC]][]
30-
// CHECK: %[[V:.*]] = vector.broadcast %[[S]] : vector<f32> to vector<1xf32>
29+
// CHECK-NEXT: %[[S:.*]] = tensor.extract %[[SRC]][] : tensor<f32>
30+
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<1xf32>
3131
%res = vector.transfer_read %src[], %f0 {in_bounds = [true], permutation_map = affine_map<()->(0)>} :
3232
tensor<f32>, vector<1xf32>
3333

0 commit comments

Comments (0)