Skip to content

[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case #107922

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,28 @@ static Value calculateGatherOffset(RewriterBase &rewriter,

enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

/// Find the non-unit dim in a linalgOp.
/// When executing this hook, it is expected that only one dim will be non-unit.
/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
/// loads before calling this method. This is used for finding contiguous loads
/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
/// this condition is expected to hold for statically shaped Linalg Ops only.
///
/// \param linalgOp the op whose static loop ranges are scanned.
/// \returns the index of the last loop dim whose trip count is != 1, or 0 if
///          every dim is unit (e.g. an all-unit iteration space).
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
    if (tripCount.value() != 1) {
      nonUnitDim = tripCount.index();
      countNonUnitDim++;
    }
  }

  // NOTE: parenthesize the disjunction before `&&`-ing the message. Without
  // the parentheses the expression parses as
  //   hasDynamicShape() || (countNonUnitDim == 1 && "...")
  // which happens to evaluate identically (the literal is truthy) but trips
  // -Wparentheses and hides the intended grouping.
  assert((linalgOp.hasDynamicShape() || countNonUnitDim == 1) &&
         "For statically shaped Linalg Ops, only one "
         "non-unit loop dim is expected");
  (void)countNonUnitDim; // Only used by the assert; avoid NDEBUG warnings.
  return nonUnitDim;
}

/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
VectorType resType) {
Expand Down Expand Up @@ -889,11 +911,12 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");

// Given the assumption on the loop ranges above, only the trailing loop
// index is not constant.
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
// Given the assumption on the loop ranges above, we expect only 1 non-unit
// loop dim.
auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);

if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
foundIndexOp = (indexOp.getDim() == trailingLoopDim);
foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
return true;
}

Expand Down
52 changes: 52 additions & 0 deletions mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}

// -----

// Loads a 1-D "column" vector out of a 3-D tensor. Only loop dim 0 (trip
// count 8) is non-unit; the index into the innermost tensor dim mixes a
// loop index (linalg.index 1) with a runtime value (%arg1), so per the test
// name and the CHECK lines that follow, the vectorizer is expected to fall
// back to a vector.gather rather than a contiguous transfer_read.
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<8x1xf32>
%1 = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel", "parallel"]
} outs(%0 : tensor<8x1xf32>) {
^bb0(%arg5: f32):
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
// Innermost-dim index depends on the dynamic %arg1 (used twice) plus a
// loop index, i.e. it is not a pure loop-invariant or contiguous index.
%4 = affine.apply #map1(%arg1, %3, %arg1)
%extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
linalg.yield %extracted : f32
} -> tensor<8x1xf32>
return %1 : tensor<8x1xf32>
}

// Transform script: vectorize the linalg.generic above with support for
// n-D tensor.extract enabled (vectorize_nd_extract).
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}

// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>

// -----

#map = affine_map<(d0) -> (d0)>
Expand Down
Loading