Skip to content

Commit ab70251

Browse files
authored
[mlir][VectorOps] Add vector.interleave operation (#80965)
The interleave operation constructs a new vector by interleaving the elements from the trailing (or final) dimension of two input vectors, returning a new vector where the trailing dimension is twice the size. Note that for the n-D case this differs from the interleaving possible with `vector.shuffle`, which would only operate on the leading dimension. Another key difference is this operation supports scalable vectors, though currently a general LLVM lowering is limited to the case where only the trailing dimension is scalable. Example: ```mlir %0 = vector.interleave %a, %b : vector<[4]xi32> ; yields vector<[8]xi32> %1 = vector.interleave %c, %d : vector<8xi8> ; yields vector<16xi8> %2 = vector.interleave %e, %f : vector<f16> ; yields vector<2xf16> %3 = vector.interleave %g, %h : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64> %4 = vector.interleave %i, %j : vector<6x3xf32> ; yields vector<6x6xf32> ``` Note: This change alone does not add any lowerings.
1 parent 673e5e3 commit ab70251

File tree

2 files changed

+98
-0
lines changed

2 files changed

+98
-0
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,69 @@ def Vector_ShuffleOp :
478478
let hasCanonicalizer = 1;
479479
}
480480

481+
def Vector_InterleaveOp :
482+
Vector_Op<"interleave", [Pure,
483+
AllTypesMatch<["lhs", "rhs"]>,
484+
TypesMatchWith<
485+
"type of 'result' is double the width of the inputs",
486+
"lhs", "result",
487+
[{
488+
[&]() -> ::mlir::VectorType {
489+
auto vectorType = ::llvm::cast<mlir::VectorType>($_self);
490+
::mlir::VectorType::Builder builder(vectorType);
491+
if (vectorType.getRank() == 0) {
492+
static constexpr int64_t v2xty_shape[] = { 2 };
493+
return builder.setShape(v2xty_shape);
494+
}
495+
auto lastDim = vectorType.getRank() - 1;
496+
return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
497+
}()
498+
}]>]> {
499+
let summary = "constructs a vector by interleaving two input vectors";
500+
let description = [{
501+
The interleave operation constructs a new vector by interleaving the
502+
elements from the trailing (or final) dimension of two input vectors,
503+
returning a new vector where the trailing dimension is twice the size.
504+
505+
Note that for the n-D case this differs from the interleaving possible with
506+
`vector.shuffle`, which would only operate on the leading dimension.
507+
508+
Another key difference is this operation supports scalable vectors, though
509+
currently a general LLVM lowering is limited to the case where only the
510+
trailing dimension is scalable.
511+
512+
Example:
513+
```mlir
514+
%0 = vector.interleave %a, %b
515+
: vector<[4]xi32> ; yields vector<[8]xi32>
516+
%1 = vector.interleave %c, %d
517+
: vector<8xi8> ; yields vector<16xi8>
518+
%2 = vector.interleave %e, %f
519+
: vector<f16> ; yields vector<2xf16>
520+
%3 = vector.interleave %g, %h
521+
: vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
522+
%4 = vector.interleave %i, %j
523+
: vector<6x3xf32> ; yields vector<6x6xf32>
524+
```
525+
}];
526+
527+
let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
528+
let results = (outs AnyVector:$result);
529+
530+
let assemblyFormat = [{
531+
$lhs `,` $rhs attr-dict `:` type($lhs)
532+
}];
533+
534+
let extraClassDeclaration = [{
535+
VectorType getSourceVectorType() {
536+
return ::llvm::cast<VectorType>(getLhs().getType());
537+
}
538+
VectorType getResultVectorType() {
539+
return ::llvm::cast<VectorType>(getResult().getType());
540+
}
541+
}];
542+
}
543+
481544
def Vector_ExtractElementOp :
482545
Vector_Op<"extractelement", [Pure,
483546
TypesMatchWith<"result type matches element type of vector operand",

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
10811081
%min = vector.reduction <minnumf>, %x fastmath<reassoc,nnan,ninf> : vector<42xf32> into f32
10821082
return %min: f32
10831083
}
1084+
1085+
// CHECK-LABEL: @interleave_0d
1086+
func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
1087+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
1088+
%0 = vector.interleave %a, %b : vector<f32>
1089+
return %0 : vector<2xf32>
1090+
}
1091+
1092+
// CHECK-LABEL: @interleave_1d
1093+
func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
1094+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
1095+
%0 = vector.interleave %a, %b : vector<4xf32>
1096+
return %0 : vector<8xf32>
1097+
}
1098+
1099+
// CHECK-LABEL: @interleave_1d_scalable
1100+
func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
1101+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
1102+
%0 = vector.interleave %a, %b : vector<[8]xi16>
1103+
return %0 : vector<[16]xi16>
1104+
}
1105+
1106+
// CHECK-LABEL: @interleave_2d
1107+
func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
1108+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
1109+
%0 = vector.interleave %a, %b : vector<2x8xf32>
1110+
return %0 : vector<2x16xf32>
1111+
}
1112+
1113+
// CHECK-LABEL: @interleave_2d_scalable
1114+
func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
1115+
// CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
1116+
%0 = vector.interleave %a, %b : vector<2x[2]xf64>
1117+
return %0 : vector<2x[4]xf64>
1118+
}

0 commit comments

Comments
 (0)