Skip to content

Commit 69ece45

Browse files
address reviewer comments
1 parent b88f8e3 commit 69ece45

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -810,10 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
810810

811811
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812812

813-
814-
/// Find the non constant dim in a linalgOp. This is used for finding contiguous
815-
/// loads and it is expected that only one dim will be non constant, if thats
816-
/// not the case this function will assert.
813+
/// Find the non-unit dim in a linalgOp.
814+
/// When executing this hook, it is expected that only one dim will be non-unit.
815+
/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
816+
/// loads before calling this method. This is used for finding contiguous loads
817+
/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
818+
/// this condition is expected to hold for statically shaped Linalg Ops only.
817819
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
818820
uint64_t nonUnitDim = 0;
819821
uint64_t countNonUnitDim = 0;
@@ -823,8 +825,10 @@ static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
823825
countNonUnitDim++;
824826
}
825827
}
826-
assert(countNonUnitDim == 1 &&
827-
"Expected only one non unit loop dim in this linalg op");
828+
829+
assert(linalgOp.hasDynamicShape() ||
830+
countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
831+
"non-unit loop dim is expected");
828832
return nonUnitDim;
829833
}
830834

@@ -908,6 +912,8 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
908912
Operation *defOp = val.getDefiningOp();
909913
assert(defOp && "This is neither a block argument nor an operation result");
910914

915+
// Given the assumption on the loop ranges above, we expect only 1 non-unit
916+
// loop dim.
911917
auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
912918

913919
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ module attributes {transform.with_named_sequence} {
258258

259259
#map = affine_map<(d0, d1) -> (d0, d1)>
260260
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
261-
func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
261+
func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
262262
%c0 = arith.constant 0 : index
263263
%0 = tensor.empty() : tensor<8x1xf32>
264264
%1 = linalg.generic {
@@ -276,15 +276,15 @@ func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x12
276276
}
277277

278278
module attributes {transform.with_named_sequence} {
279-
transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
280-
%0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
279+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
280+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
281281
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
282282
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
283283
transform.yield
284284
}
285285
}
286286

287-
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
287+
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
288288
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
289289
// CHECK-SAME: %[[ARG1:.*]]: index
290290
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index

0 commit comments

Comments
 (0)