Skip to content

Commit 28fa83f

Browse files
authored
Revert "[mlir][linalg] Relax tensor.extract vectorization" (#102232)
Reverts #99299 because it breaks the lowering. To repro: `mlir-opt -transform-interpreter ~/repro.mlir`

```mlir
#map = affine_map<(d0, d1) -> (d0)>
#map1 = affine_map<(d0, d1) -> (d1)>
#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map3 = affine_map<(d0, d1) -> (d0 + d1)>
module {
  func.func @foo(%arg0: index, %arg1: tensor<2xf32>, %arg2: tensor<4xf32>, %arg3: tensor<1xf32>) -> tensor<4x1xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 1.000000e+00 : f32
    %cst_0 = arith.constant 0.000000e+00 : f32
    %0 = tensor.empty() : tensor<4x1xf32>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel"]} ins(%arg2, %arg3 : tensor<4xf32>, tensor<1xf32>) outs(%0 : tensor<4x1xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %2 = linalg.index 0 : index
      %3 = linalg.index 1 : index
      %4 = affine.apply #map3(%3, %arg0)
      %extracted = tensor.extract %arg1[%c0] : tensor<2xf32>
      %5 = arith.cmpi eq, %2, %c0 : index
      %6 = arith.cmpi ult, %2, %c0 : index
      %7 = arith.select %5, %cst, %in : f32
      %8 = arith.select %6, %cst_0, %7 : f32
      %9 = arith.cmpi eq, %4, %c0 : index
      %10 = arith.cmpi ult, %4, %c0 : index
      %11 = arith.select %9, %cst, %in_1 : f32
      %12 = arith.select %10, %cst_0, %11 : f32
      %13 = arith.mulf %8, %12 : f32
      %14 = arith.mulf %13, %extracted : f32
      %15 = arith.cmpi eq, %2, %4 : index
      %16 = arith.select %15, %cst, %cst_0 : f32
      %17 = arith.subf %16, %14 : f32
      linalg.yield %17 : f32
    } -> tensor<4x1xf32>
    return %1 : tensor<4x1xf32>
  }
}
module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.vectorize %0 : !transform.any_op
    transform.yield
  }
}
```
1 parent 388b632 commit 28fa83f

File tree

2 files changed

+20
-71
lines changed

2 files changed

+20
-71
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -946,22 +946,27 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
946946
if (linalgOp.hasDynamicShape())
947947
return VectorMemoryAccessKind::Gather;
948948

949-
// True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
950-
// otherwise.
951-
bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
952-
return dimSize > 1;
953-
}) == 1);
954-
955-
// 1. Assume that it's a gather load when reading non-1D vector.
956-
if (!isOutput1DVector)
949+
// 1. Assume that it's a gather load when reading _into_:
950+
// * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
951+
// * a 1-D "vector" with the trailing dim equal to 1, e.g. `tensor<1x4x1xi32>`.
952+
// TODO: Relax these conditions.
953+
// FIXME: This condition assumes non-dynamic sizes.
954+
if ((llvm::count_if(targetShape,
955+
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
956+
targetShape.back() == 1)
957+
return VectorMemoryAccessKind::Gather;
958+
959+
// 2. Assume that it's a gather load when reading _from_ a tensor for which
960+
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
961+
// TODO: Relax this condition.
962+
if (inputShape.getShape().back() == 1)
957963
return VectorMemoryAccessKind::Gather;
958964

959965
bool leadingIdxsLoopInvariant = true;
960966

961-
// 2. Analyze the leading indices of `extractOp`.
967+
// 3. Analyze the leading indices of `extractOp`.
962968
// Look at the way each index is calculated and decide whether it is suitable
963-
// for a contiguous load, i.e. whether it's loop invariant. If not, it's a
964-
// gather load.
969+
// for a contiguous load, i.e. whether it's loop invariant.
965970
auto indices = extractOp.getIndices();
966971
auto leadIndices = indices.drop_back(1);
967972

@@ -977,13 +982,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
977982
return VectorMemoryAccessKind::Gather;
978983
}
979984

980-
// 3. Analyze the trailing index for `extractOp`.
985+
// 4. Analyze the trailing index for `extractOp`.
981986
// At this point we know that the leading indices are loop invariant. This
982987
// means that it is potentially a scalar or a contiguous load. We can decide
983988
// based on the trailing idx.
984989
auto extractOpTrailingIdx = indices.back();
985990

986-
// 3a. Scalar broadcast load
991+
// 4a. Scalar broadcast load
987992
// If the trailing index is loop invariant then this is a scalar load.
988993
if (leadingIdxsLoopInvariant &&
989994
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -992,7 +997,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
992997
return VectorMemoryAccessKind::ScalarBroadcast;
993998
}
994999

995-
// 3b. Contiguous loads
1000+
// 4b. Contiguous loads
9961001
// The trailing `extractOp` index should increment with every loop iteration.
9971002
// This effectively means that it must be based on the trailing loop index.
9981003
// This is what the following bool captures.
@@ -1006,7 +1011,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
10061011
return VectorMemoryAccessKind::Contiguous;
10071012
}
10081013

1009-
// 4. Fallback case - gather load.
1014+
// 5. Fallback case - gather load.
10101015
LDBG("Found gather load: " << extractOp);
10111016
return VectorMemoryAccessKind::Gather;
10121017
}

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -595,59 +595,3 @@ module attributes {transform.with_named_sequence} {
595595
transform.yield
596596
}
597597
}
598-
599-
600-
// -----
601-
602-
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
603-
%c4 = arith.constant 4 : index
604-
%c0 = arith.constant 0 : index
605-
%cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
606-
607-
%out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
608-
^bb0(%out: i32):
609-
%8 = linalg.index 0 : index
610-
%idx_0 = linalg.index 0 : index
611-
%extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
612-
linalg.yield %extracted : i32
613-
} -> tensor<1x1x4xi32>
614-
615-
return %out:tensor<1x1x4xi32>
616-
}
617-
618-
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
619-
// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
620-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
621-
// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
622-
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
623-
// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
624-
// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
625-
// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
626-
// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
627-
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
628-
// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
629-
// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
630-
// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
631-
// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
632-
// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
633-
// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
634-
// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
635-
// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
636-
// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
637-
// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
638-
// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
639-
// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
640-
// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
641-
// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
642-
// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
643-
// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
644-
// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
645-
// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
646-
647-
module attributes {transform.with_named_sequence} {
648-
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
649-
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
650-
transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
651-
transform.yield
652-
}
653-
}

0 commit comments

Comments
 (0)