Skip to content

Commit 459cf6e

Browse files
committed
[mlir] [VectorOps] Lowering of vector.extract/insert_slices to LLVM IR
Summary: Uses progressive lowering to convert vector.extract_slices and vector_insert_slices to equivalent vector operations that can be subsequently lowered into LLVM. Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: nicolasvasilache, rriddle Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72808
1 parent 53eb0f8 commit 459cf6e

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -972,9 +972,17 @@ struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
972972
} // namespace
973973

974974
void LowerVectorToLLVMPass::runOnModule() {
975-
// Convert to the LLVM IR dialect using the converter defined above.
976-
OwningRewritePatternList patterns;
975+
// Perform progressive lowering of operations on "slices".
976+
// Folding and DCE get rid of all non-leaking tuple ops.
977+
{
978+
OwningRewritePatternList patterns;
979+
populateVectorSlicesLoweringPatterns(patterns, &getContext());
980+
applyPatternsGreedily(getModule(), patterns);
981+
}
982+
983+
// Convert to the LLVM IR dialect.
977984
LLVMTypeConverter converter(&getContext());
985+
OwningRewritePatternList patterns;
978986
populateVectorToLLVMConversionPatterns(converter, patterns);
979987
populateStdToLLVMConversionPatterns(converter, patterns);
980988

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

+38-14
Original file line numberDiff line numberDiff line change
@@ -424,10 +424,11 @@ func @vector_print_vector(%arg0: vector<2x2xf32>) {
424424
// CHECK: llvm.call @print_close() : () -> ()
425425
// CHECK: llvm.call @print_newline() : () -> ()
426426

427-
428-
func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<4x8x16xf32>) {
429-
// CHECK-LABEL: llvm.func @strided_slice(
427+
func @strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> {
430428
%0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
429+
return %0 : vector<2xf32>
430+
}
431+
// CHECK-LABEL: llvm.func @strided_slice1
431432
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
432433
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2xf32>) : !llvm<"<2 x float>">
433434
// CHECK: llvm.mlir.constant(2 : index) : !llvm.i64
@@ -439,15 +440,23 @@ func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<
439440
// CHECK: llvm.mlir.constant(1 : index) : !llvm.i64
440441
// CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
441442

442-
%1 = vector.strided_slice %arg1 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
443+
func @strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> {
444+
%0 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32>
445+
return %0 : vector<2x8xf32>
446+
}
447+
// CHECK-LABEL: llvm.func @strided_slice2
443448
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
444449
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x8xf32>) : !llvm<"[2 x <8 x float>]">
445450
// CHECK: llvm.extractvalue %{{.*}}[2] : !llvm<"[4 x <8 x float>]">
446451
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"[2 x <8 x float>]">
447452
// CHECK: llvm.extractvalue %{{.*}}[3] : !llvm<"[4 x <8 x float>]">
448453
// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"[2 x <8 x float>]">
449454

450-
%2 = vector.strided_slice %arg1 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
455+
func @strided_slice3(%arg0: vector<4x8xf32>) -> vector<2x2xf32> {
456+
%0 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8xf32> to vector<2x2xf32>
457+
return %0 : vector<2x2xf32>
458+
}
459+
// CHECK-LABEL: llvm.func @strided_slice3
451460
// CHECK: llvm.mlir.constant(0.000000e+00 : f32) : !llvm.float
452461
// CHECK: llvm.mlir.constant(dense<0.000000e+00> : vector<2x2xf32>) : !llvm<"[2 x <2 x float>]">
453462
//
@@ -479,17 +488,19 @@ func @strided_slice(%arg0: vector<4xf32>, %arg1: vector<4x8xf32>, %arg2: vector<
479488
// CHECK: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<2 x float>">
480489
// CHECK: llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"[2 x <2 x float>]">
481490

482-
return
483-
}
484-
485-
func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<4x4x4xf32>) {
486-
// CHECK-LABEL: @insert_strided_slice
487-
491+
func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vector<4x4x4xf32> {
488492
%0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xf32> into vector<4x4x4xf32>
493+
return %0 : vector<4x4x4xf32>
494+
}
495+
// CHECK-LABEL: @insert_strided_slice1
489496
// CHECK: llvm.extractvalue {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
490497
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"[4 x [4 x <4 x float>]]">
491498

492-
%1 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
499+
func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> {
500+
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
501+
return %0 : vector<4x4xf32>
502+
}
503+
// CHECK-LABEL: @insert_strided_slice2
493504
//
494505
// Subvector vector<2xf32> @0 into vector<4xf32> @2
495506
// CHECK: llvm.extractvalue {{.*}}[0] : !llvm<"[2 x <2 x float>]">
@@ -521,6 +532,19 @@ func @insert_strided_slice(%a: vector<2x2xf32>, %b: vector<4x4xf32>, %c: vector<
521532
// CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : !llvm.i64] : !llvm<"<4 x float>">
522533
// CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[3] : !llvm<"[4 x <4 x float>]">
523534

524-
return
535+
func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
536+
%0 = vector.extract_slices %arg0, [2, 2], [1, 1]
537+
: vector<3x3xf32> into tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
538+
%1 = vector.tuple_get %0, 3 : tuple<vector<2x2xf32>, vector<2x1xf32>, vector<1x2xf32>, vector<1x1xf32>>
539+
return %1 : vector<1x1xf32>
525540
}
526-
541+
// CHECK-LABEL: extract_strides(%arg0: !llvm<"[3 x <3 x float>]">)
542+
// CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
543+
// CHECK: %[[s1:.*]] = llvm.extractvalue %arg0[2] : !llvm<"[3 x <3 x float>]">
544+
// CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
545+
// CHECK: %[[s4:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
546+
// CHECK: %[[s5:.*]] = llvm.extractelement %[[s1]][%[[s4]] : !llvm.i64] : !llvm<"<3 x float>">
547+
// CHECK: %[[s6:.*]] = llvm.mlir.constant(0 : index) : !llvm.i64
548+
// CHECK: %[[s7:.*]] = llvm.insertelement %[[s5]], %[[s3]][%[[s6]] : !llvm.i64] : !llvm<"<1 x float>">
549+
// CHECK: %[[s8:.*]] = llvm.insertvalue %[[s7]], %[[s0]][0] : !llvm<"[1 x <1 x float>]">
550+
// CHECK: llvm.return %[[s8]] : !llvm<"[1 x <1 x float>]">

0 commit comments

Comments
 (0)