Skip to content

Commit 09693d1

Browse files
ziereis, Groverkss
authored and committed
add pattern for arith::UIToFPOp to VectorNarrowTypeRewritePatterns (llvm#115485)
This PR just adds the patterns from llvm#89131 for arith::UIToFPOp. It also slightly renames and reorders the tests for better readability.
1 parent 6240ed2 commit 09693d1

File tree

2 files changed

+113
-81
lines changed

2 files changed

+113
-81
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,8 +1452,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
14521452
RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
14531453
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
14541454
benefit.getBenefit() + 1);
1455-
patterns.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>>(
1456-
patterns.getContext(), benefit.getBenefit() + 1);
1455+
patterns
1456+
.add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
1457+
RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
1458+
patterns.getContext(), benefit.getBenefit() + 1);
14571459
}
14581460

14591461
void vector::populateVectorTransposeNarrowTypeRewritePatterns(

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 109 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -193,36 +193,8 @@ func.func @f3ext(%a: vector<5xi8>) -> vector<8xi17> {
193193
return %1 : vector<8xi17>
194194
}
195195

196-
// CHECK-LABEL: func.func @aligned_extsi(
197-
func.func @aligned_extsi(%a: vector<8xi4>) -> vector<8xi32> {
198-
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
199-
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
200-
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
201-
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
202-
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
203-
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
204-
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
205-
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
206-
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
207-
return %0 : vector<8xi32>
208-
}
209-
210-
// CHECK-LABEL: func.func @aligned_extsi_2d(
211-
func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
212-
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
213-
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
214-
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
215-
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
216-
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
217-
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
218-
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
219-
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
220-
%0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
221-
return %0 : vector<8x32xi32>
222-
}
223-
224-
// CHECK-LABEL: func.func @aligned_extsi_base_case(
225-
func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
196+
// CHECK-LABEL: func.func @aligned_extsi_i4_to_i8(
197+
func.func @aligned_extsi_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
226198
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
227199
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
228200
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
@@ -234,60 +206,61 @@ func.func @aligned_extsi_base_case(%a: vector<8xi4>) -> vector<8xi8> {
234206
return %0 : vector<8xi8>
235207
}
236208

237-
// CHECK-LABEL: func.func @aligned_sitofp(
238-
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
239-
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
209+
// CHECK-LABEL: func.func @aligned_extsi_i4_to_i32(
210+
func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
211+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
240212
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
241213
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
242214
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
243215
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
244216
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
245217
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
246-
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
247-
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
248-
return %0 : vector<8xf32>
218+
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
219+
%0 = arith.extsi %a : vector<8xi4> to vector<8xi32>
220+
return %0 : vector<8xi32>
249221
}
250222

251-
// CHECK-LABEL: func.func @aligned_sitofp_2d(
252-
func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
253-
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
223+
// CHECK-LABEL: func.func @aligned_extsi_2d(
224+
func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
225+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
254226
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
255227
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
256228
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
257229
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
258230
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
259231
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
260-
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
261-
%0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
262-
return %0 : vector<8x32xf32>
232+
// CHECK: %[[I32:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
233+
%0 = arith.extsi %a : vector<8x32xi4> to vector<8x32xi32>
234+
return %0 : vector<8x32xi32>
263235
}
264236

265-
// CHECK-LABEL: func.func @aligned_trunci(
266-
func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
267-
// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
237+
238+
// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
239+
func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
240+
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
268241
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
269242
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
270-
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
271-
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
243+
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
272244
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
273245
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
274246
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
275247
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
276-
%0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
248+
%0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
277249
return %0 : vector<8xi4>
278250
}
279251

280-
// CHECK-LABEL: func.func @aligned_trunci_base_case(
281-
func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
282-
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
252+
// CHECK-LABEL: func.func @aligned_trunci_i32_to_i4(
253+
func.func @aligned_trunci_i32_to_i4(%a: vector<8xi32>) -> vector<8xi4> {
254+
// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
283255
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
284256
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
285-
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[IN]] : vector<8xi8> -> vector<4xi8>
257+
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
258+
// CHECK: %[[LOW:.*]], %[[HIGH:.*]] = vector.deinterleave %[[I8]] : vector<8xi8> -> vector<4xi8>
286259
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
287260
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
288261
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
289262
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
290-
%0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
263+
%0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
291264
return %0 : vector<8xi4>
292265
}
293266

@@ -314,33 +287,26 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
314287
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[I4_MASK]] : vector<3x8x16xi8>
315288
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[LEFT_SHIFT_BITS]] : vector<3x8x16xi8>
316289
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<3x8x16xi8>
317-
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
290+
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<3x8x16xi8> to vector<3x8x32xi4>
318291
%0 = arith.trunci %a : vector<3x8x32xi32> to vector<3x8x32xi4>
319292
return %0 : vector<3x8x32xi4>
320293
}
321294

322-
// CHECK-LABEL: func.func @i4_transpose(
323-
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
324-
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
325-
// CHECK: %[[EXT:.*]] = vector.interleave
326-
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
327-
// CHECK: vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
328-
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
329-
return %0 : vector<16x8xi4>
330-
}
331-
332-
// CHECK-LABEL: func.func @i7_transpose(
333-
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
334-
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
335-
// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
336-
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
337-
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
338-
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
339-
return %0 : vector<16x8xi7>
295+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
296+
func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
297+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
298+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
299+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
300+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
301+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
302+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
303+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
304+
%0 = arith.extui %a : vector<8xi4> to vector<8xi8>
305+
return %0 : vector<8xi8>
340306
}
341307

342-
// CHECK-LABEL: func.func @aligned_extui(
343-
func.func @aligned_extui(%a: vector<8xi4>) -> vector<8xi32> {
308+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i32(
309+
func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
344310
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi32> {
345311
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
346312
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
@@ -367,19 +333,83 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
367333
return %0 : vector<8x32xi32>
368334
}
369335

370-
// CHECK-LABEL: func.func @aligned_extui_base_case(
371-
func.func @aligned_extui_base_case(%a: vector<8xi4>) -> vector<8xi8> {
372-
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
336+
// CHECK-LABEL: func.func @aligned_sitofp(
337+
func.func @aligned_sitofp(%a: vector<8xi4>) -> vector<8xf32> {
338+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
339+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
340+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
341+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
342+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<4xi8>
343+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
344+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
345+
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
346+
%0 = arith.sitofp %a : vector<8xi4> to vector<8xf32>
347+
return %0 : vector<8xf32>
348+
}
349+
350+
// CHECK-LABEL: func.func @aligned_sitofp_2d(
351+
func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
352+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
353+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
354+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
355+
// CHECK: %[[SHL_LOW:.*]] = arith.shli %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
356+
// CHECK: %[[LOW:.*]] = arith.shrsi %[[SHL_LOW]], %[[I4_BITS]] : vector<8x16xi8>
357+
// CHECK: %[[HIGH:.*]] = arith.shrsi %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
358+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
359+
// CHECK: %[[F32:.*]] = arith.sitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
360+
%0 = arith.sitofp %a : vector<8x32xi4> to vector<8x32xf32>
361+
return %0 : vector<8x32xf32>
362+
}
363+
364+
// CHECK-LABEL: func.func @aligned_uitofp(
365+
func.func @aligned_uitofp(%a: vector<8xi4>) -> vector<8xf32> {
366+
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xf32> {
373367
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
374368
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
375369
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi4> to vector<4xi8>
376370
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<4xi8>
377371
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<4xi8>
378372
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<4xi8>
379-
%0 = arith.extui %a : vector<8xi4> to vector<8xi8>
380-
return %0 : vector<8xi8>
373+
// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8xi8> to vector<8xf32>
374+
%0 = arith.uitofp %a : vector<8xi4> to vector<8xf32>
375+
return %0 : vector<8xf32>
376+
}
377+
378+
// CHECK-LABEL: func.func @aligned_uitofp_2d(
379+
func.func @aligned_uitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
380+
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xf32> {
381+
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
382+
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
383+
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
384+
// CHECK: %[[LOW:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x16xi8>
385+
// CHECK: %[[HIGH:.*]] = arith.shrui %[[BITCAST]], %[[I4_BITS]] : vector<8x16xi8>
386+
// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[LOW]], %[[HIGH]] : vector<8x16xi8>
387+
// CHECK: %[[F32:.*]] = arith.uitofp %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xf32>
388+
%0 = arith.uitofp %a : vector<8x32xi4> to vector<8x32xf32>
389+
return %0 : vector<8x32xf32>
390+
}
391+
392+
// CHECK-LABEL: func.func @i4_transpose(
393+
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
394+
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
395+
// CHECK: %[[EXT:.*]] = vector.interleave
396+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
397+
// CHECK: vector.deinterleave %[[TRANS]] : vector<16x8xi8> -> vector<16x4xi8>
398+
%0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
399+
return %0 : vector<16x8xi4>
381400
}
382401

402+
// CHECK-LABEL: func.func @i7_transpose(
403+
func.func @i7_transpose(%a: vector<8x16xi7>) -> vector<16x8xi7> {
404+
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi7>) -> vector<16x8xi7> {
405+
// CHECK: %[[EXT:.*]] = arith.extsi %[[IN]] : vector<8x16xi7> to vector<8x16xi8>
406+
// CHECK: %[[TRANS:.*]] = vector.transpose %[[EXT]], [1, 0] : vector<8x16xi8> to vector<16x8xi8>
407+
// CHECK: %[[TRUNC:.*]] = arith.trunci %[[TRANS]] : vector<16x8xi8> to vector<16x8xi7>
408+
%0 = vector.transpose %a, [1, 0] : vector<8x16xi7> to vector<16x8xi7>
409+
return %0 : vector<16x8xi7>
410+
}
411+
412+
383413
module attributes {transform.with_named_sequence} {
384414
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
385415
%f = transform.structured.match ops{["func.func"]} in %module_op

0 commit comments

Comments
 (0)