
Commit d95238f

[mlir][nfc] Update vectorize-tensor-extract.mlir (1/N)
Tests in "vectorize-tensor-extract.mlir" are inconsistent and would benefit from refactoring to:

* Clearly categorize tests into "contiguous load", "gather load", and "scalar load + broadcast" cases, reflecting the structure of tensor.extract vectorization.
* Unify variable naming (both MLIR and FileCheck).
* Ensure all tests exercise unmasked vectorization (masked vectorization is covered in "vectorize-tensor-extract-masked.mlir").
* Improve and standardize formatting.

These changes will make it easier to identify the test cases being exercised and simplify future maintenance or refactoring.

This is patch 1/N in the series. Below is a summary of the changes in this patch.

----------------------------------------------------------------------

This PR updates the `@vectorize_scalar_broadcast_column_tensor` test in "vectorize-tensor-extract.mlir", which exercises:

* Vectorization of tensor.extract.
* A scalar read followed by a broadcast.
* Reading from a constant column tensor.

Currently, the test uses "masked" vectorization, but the file exclusively tests unmasked vectorization paths. To address this inconsistency, this PR removes masking, aligning the test with the rest of the file. Masked vectorization scenarios remain covered in "vectorize-tensor-extract-masked.mlir".

This update switches from:

* `transform.structured.vectorize`, to
* `transform.structured.vectorize_children_and_apply_patterns`.

The latter approach applies canonicalization patterns, significantly simplifying the generated output.

Additional improvements for readability:

* Renamed the test function for clarity.
* Updated variable names and removed unused variables.
* Added empty lines for better formatting.
Parent: 1885886

1 file changed (+26 -41 lines)

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

@@ -807,56 +807,41 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(%init: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
   %c4 = arith.constant 4 : index
   %c0 = arith.constant 0 : index
-  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-
-  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
-  ^bb0(%out: i32):
-    %8 = linalg.index 0 : index
-    %idx_0 = linalg.index 0 : index
-    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
-    linalg.yield %extracted : i32
+  %src = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %res = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    outs(%init : tensor<1x1x4xi32>) {
+
+  ^bb0(%out: i32):
+    %idx = linalg.index 0 : index
+    %extracted = tensor.extract %src[%idx, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
   } -> tensor<1x1x4xi32>
 
-  return %out:tensor<1x1x4xi32>
+  return %res : tensor<1x1x4xi32>
 }
 
-// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
-// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
-// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
-// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_19:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.extract %[[VAL_19]][0] : index from vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_22:.*]] = vector.constant_mask [1] : vector<1xi1>
-// CHECK: %[[VAL_23:.*]] = vector.mask %[[VAL_22]] { vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_20]], %[[VAL_2]]], %[[VAL_21]] {in_bounds = [true, true, true], permutation_map = #[[$MAP]]} : tensor<15x1xi32>, vector<1x1x4xi32> } : vector<1xi1> -> vector<1x1x4xi32>
-// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
-// CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
+// CHECK-LABEL: func.func @vectorize_scalar_read_with_broadcast_from_column_tensor(
+// CHECK-SAME: %[[INIT:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0 : i32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[SRC:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<0> : vector<1xindex>
+// CHECK: %[[IDX_ELT:.*]] = vector.extract %[[IDX_VEC]][0] : index from vector<1xindex>
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]]{{\[}}%[[IDX_ELT]], %[[C0]]], %[[PAD]] : tensor<15x1xi32>, vector<i32>
+// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<i32> to vector<1x1x4xi32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xi32>, tensor<1x1x4xi32>
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
-    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
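
For context, CHECK lines like the ones above are exercised by running the test file through mlir-opt's transform interpreter and piping the output into FileCheck. The diff does not show this file's RUN line, so the flags below are an assumption based on common practice for transform-dialect tests, not a quote from the file:

// Assumed lit invocation (not part of this diff); the file's actual RUN line may differ.
// The transform interpreter executes @__transform_main, and -split-input-file
// processes each "// -----"-delimited test case independently.
// RUN: mlir-opt %s -transform-interpreter -split-input-file | FileCheck %s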
