@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
 }
 
 // -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-  %1 = linalg.generic {
-    indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+    %arg0: tensor<3x3x3xf32>,
+    %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg2 : tensor<1x1x3xf32>) {
-  ^bb0(%arg4: f32):
-    %2 = linalg.index 0 : index
-    %3 = linalg.index 1 : index
-    %4 = linalg.index 2 : index
-    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
-    linalg.yield %5 : f32
+  } outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = linalg.index 2 : index
+    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+    linalg.yield %4 : f32
   } -> tensor<1x1x3xf32>
-  return %1 : tensor<1x1x3xf32>
+
+  return %res : tensor<1x1x3xf32>
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
+// Same as the example above, but reading into a column tensor. Note that after
+// vectorization, the `TransferOpReduceRank` pattern will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorise when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+    %input: tensor<3x3x3xf32>,
+    %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+  %c0 = arith.constant 0 : index
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%output : tensor<3x1x1xf32>) {
+  ^bb0(%out: f32):
+    %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<3x1x1xf32>
+
+  return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK: return %[[RES]] : tensor<3x1x1xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out : tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
+    transform.yield
+  }
+}