@@ -1017,7 +1017,7 @@ func.func @broadcast_same_shape(%input: tensor<2x3xf32>, %init: tensor<2x3xf32>)
1017
1017
return %0 : tensor <2 x3 xf32 >
1018
1018
}
1019
1019
1020
- // ----
1020
+ // -----
1021
1021
1022
1022
func.func @transpose_1d (%input: tensor <16 xf32 >,
1023
1023
%init: tensor <16 xf32 >) -> tensor <16 xf32 > {
@@ -1096,3 +1096,76 @@ func.func @transpose_transpose_fold(%input: tensor<5x4x3xf32>,
1096
1096
func.return %transpose2 : tensor <3 x4 x5 xf32 >
1097
1097
}
1098
1098
1099
+ // -----
1100
+
1101
// Test case: a linalg.transpose that consumes a linalg.broadcast, static shapes.
// Input IR: %input (2x4x5) is broadcast into %init1 (1x2x3x4x5x6) with
// dimensions = [0, 2, 5], then transposed into %init2 with
// permutation = [0, 5, 1, 2, 4, 3].
// Expected output (per the FileCheck directives below): the transpose is
// applied to the original input first (permutation = [0, 2, 1], into a fresh
// tensor.empty of shape 2x5x4), and the broadcast then writes directly into
// %init2. The expected broadcast dimensions [0, 3, 1] are the positions at
// which the original broadcast dimensions 0, 2 and 5 occur in the transpose
// permutation.
// NOTE(review): the RUN line is outside this view; presumably this file is
// run under canonicalization - confirm.
+ func.func @broadcast_transpose_fold (%input: tensor <2 x4 x5 xf32 >,
1102
+ %init1: tensor <1 x2 x3 x4 x5 x6 xf32 >,
1103
+ %init2: tensor <1 x6 x2 x3 x5 x4 xf32 >) -> tensor <1 x6 x2 x3 x5 x4 xf32 > {
1104
+ // CHECK-LABEL: @broadcast_transpose_fold
1105
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2x4x5xf32>
1106
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x2x3x4x5x6xf32>
1107
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x6x2x3x5x4xf32>
1108
+ // CHECK: %[[TMP_INIT:.+]] = tensor.empty() : tensor<2x5x4xf32>
1109
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<2x4x5xf32>) outs(%[[TMP_INIT]] : tensor<2x5x4xf32>) permutation = [0, 2, 1]
1110
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<2x5x4xf32>) outs(%[[INIT2]] : tensor<1x6x2x3x5x4xf32>) dimensions = [0, 3, 1]
1111
+ // CHECK: return %[[BROADCAST]] : tensor<1x6x2x3x5x4xf32>
1112
// Input IR under test: broadcast first, then transpose of the broadcast result.
+ %broadcast = linalg.broadcast
1113
+ ins (%input : tensor <2 x4 x5 xf32 >)
1114
+ outs (%init1 : tensor <1 x2 x3 x4 x5 x6 xf32 >)
1115
+ dimensions = [0 , 2 , 5 ]
1116
+ %transpose = linalg.transpose
1117
+ ins (%broadcast : tensor <1 x2 x3 x4 x5 x6 xf32 >)
1118
+ outs (%init2 : tensor <1 x6 x2 x3 x5 x4 xf32 >)
1119
+ permutation = [0 , 5 , 1 , 2 , 4 , 3 ]
1120
+ func.return %transpose : tensor <1 x6 x2 x3 x5 x4 xf32 >
1121
+ }
1122
+
1123
+ // -----
1124
+
1125
// Test case: a linalg.transpose that consumes a linalg.broadcast, with
// dynamic dimensions.
// Input IR: %input (?x?x5) is broadcast into %init1 (1x?x3x?x5x6) with
// dimensions = [0, 2, 5], then transposed into %init2 with
// permutation = [0, 2, 3, 5, 4, 1].
// Expected output (per the FileCheck directives below): tensor.dim ops read
// the two dynamic sizes of the input (dims 0 and 1); they size a temporary
// tensor.empty of type ?x5x? (dim 1 first, then dim 0). The input is
// transposed into that temporary with permutation = [1, 2, 0], and the
// broadcast then writes directly into %init2. The expected broadcast
// dimensions [0, 1, 3] are the positions at which the original broadcast
// dimensions 0, 2 and 5 occur in the transpose permutation.
// The two index constants are matched with the -DAG directive variant, so
// their relative order in the output is not constrained.
+ func.func @broadcast_transpose_fold_dynamic (%input: tensor <?x?x5 xf32 >,
1126
+ %init1: tensor <1 x?x3 x?x5 x6 xf32 >,
1127
+ %init2: tensor <1 x3 x?x6 x5 x?xf32 >) -> tensor <1 x3 x?x6 x5 x?xf32 > {
1128
+ // CHECK-LABEL: @broadcast_transpose_fold_dynamic
1129
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<?x?x5xf32>
1130
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<1x?x3x?x5x6xf32>
1131
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<1x3x?x6x5x?xf32>
1132
+ // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
1133
+ // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
1134
+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x5xf32>
1135
+ // CHECK: %[[DIM1:.+]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<?x?x5xf32>
1136
+ // CHECK: %[[TMP_INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM0]]) : tensor<?x5x?xf32>
1137
+ // CHECK: %[[TRANSPOSE:.+]] = linalg.transpose ins(%[[INPUT]] : tensor<?x?x5xf32>) outs(%[[TMP_INIT]] : tensor<?x5x?xf32>) permutation = [1, 2, 0]
1138
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[TRANSPOSE]] : tensor<?x5x?xf32>) outs(%[[INIT2]] : tensor<1x3x?x6x5x?xf32>) dimensions = [0, 1, 3]
1139
+ // CHECK: return %[[BROADCAST]] : tensor<1x3x?x6x5x?xf32>
1140
// Input IR under test: broadcast first, then transpose of the broadcast result.
+ %broadcast = linalg.broadcast
1141
+ ins (%input : tensor <?x?x5 xf32 >)
1142
+ outs (%init1 : tensor <1 x?x3 x?x5 x6 xf32 >)
1143
+ dimensions = [0 , 2 , 5 ]
1144
+ %transpose = linalg.transpose
1145
+ ins (%broadcast : tensor <1 x?x3 x?x5 x6 xf32 >)
1146
+ outs (%init2 : tensor <1 x3 x?x6 x5 x?xf32 >)
1147
+ permutation = [0 , 2 , 3 , 5 , 4 , 1 ]
1148
+ func.return %transpose : tensor <1 x3 x?x6 x5 x?xf32 >
1149
+ }
1150
+
1151
+ // -----
1152
+
1153
// Test case: a linalg.transpose that consumes a linalg.broadcast whose input
// is rank 1.
// Input IR: %input (2) is broadcast into %init1 (2x4) with dimensions = [1],
// then transposed into %init2 (4x2) with permutation = [1, 0].
// Expected output (per the FileCheck directives below): a single
// linalg.broadcast of the original input directly into %init2 with
// dimensions = [0], whose result is returned; the directives match no
// linalg.transpose. The expected dimension 0 is the position at which the
// original broadcast dimension 1 occurs in the transpose permutation.
+ func.func @broadcast_transpose_fold_2dim (%input: tensor <2 xf32 >,
1154
+ %init1: tensor <2 x4 xf32 >,
1155
+ %init2: tensor <4 x2 xf32 >) -> tensor <4 x2 xf32 > {
1156
+ // CHECK-LABEL: @broadcast_transpose_fold_2dim
1157
+ // CHECK-SAME: %[[INPUT:[a-zA-Z0-9]+]]: tensor<2xf32>
1158
+ // CHECK-SAME: %[[INIT1:[a-zA-Z0-9]+]]: tensor<2x4xf32>
1159
+ // CHECK-SAME: %[[INIT2:[a-zA-Z0-9]+]]: tensor<4x2xf32>
1160
+ // CHECK: %[[BROADCAST:.+]] = linalg.broadcast ins(%[[INPUT]] : tensor<2xf32>) outs(%[[INIT2]] : tensor<4x2xf32>) dimensions = [0]
1161
+ // CHECK: return %[[BROADCAST]] : tensor<4x2xf32>
1162
// Input IR under test: broadcast first, then transpose of the broadcast result.
+ %broadcast = linalg.broadcast
1163
+ ins (%input : tensor <2 xf32 >)
1164
+ outs (%init1 : tensor <2 x4 xf32 >)
1165
+ dimensions = [1 ]
1166
+ %transpose = linalg.transpose
1167
+ ins (%broadcast : tensor <2 x4 xf32 >)
1168
+ outs (%init2 : tensor <4 x2 xf32 >)
1169
+ permutation = [1 , 0 ]
1170
+ func.return %transpose : tensor <4 x2 xf32 >
1171
+ }
0 commit comments