Skip to content

Commit 597cde1

Browse files
angelz913kuhar
andauthored
[mlir][spirv] Implement SPIR-V lowering for vector.deinterleave (#95313)
1. Added a conversion for `vector.deinterleave` to the `VectorToSPIRV` pass. 2. Added LIT tests for the new conversion. --------- Co-authored-by: Jakub Kuderski <[email protected]>
1 parent b6688a0 commit 597cde1

File tree

2 files changed

+89
-3
lines changed

2 files changed

+89
-3
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,66 @@ struct VectorInterleaveOpConvert final
618618
}
619619
};
620620

621+
struct VectorDeinterleaveOpConvert final
622+
: public OpConversionPattern<vector::DeinterleaveOp> {
623+
using OpConversionPattern::OpConversionPattern;
624+
625+
LogicalResult
626+
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
627+
ConversionPatternRewriter &rewriter) const override {
628+
629+
// Check the result vector type.
630+
VectorType oldResultType = deinterleaveOp.getResultVectorType();
631+
Type newResultType = getTypeConverter()->convertType(oldResultType);
632+
if (!newResultType)
633+
return rewriter.notifyMatchFailure(deinterleaveOp,
634+
"unsupported result vector type");
635+
636+
Location loc = deinterleaveOp->getLoc();
637+
638+
// Deinterleave the indices.
639+
Value sourceVector = adaptor.getSource();
640+
VectorType sourceType = deinterleaveOp.getSourceVectorType();
641+
int n = sourceType.getNumElements();
642+
643+
// Output vectors of size 1 are converted to scalars by the type converter.
644+
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
645+
// use `spirv::CompositeExtractOp`.
646+
if (n == 2) {
647+
auto elem0 = rewriter.create<spirv::CompositeExtractOp>(
648+
loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({0}));
649+
650+
auto elem1 = rewriter.create<spirv::CompositeExtractOp>(
651+
loc, newResultType, sourceVector, rewriter.getI32ArrayAttr({1}));
652+
653+
rewriter.replaceOp(deinterleaveOp, {elem0, elem1});
654+
return success();
655+
}
656+
657+
// Indices for `shuffleEven` (result 0).
658+
auto seqEven = llvm::seq<int64_t>(n / 2);
659+
auto indicesEven =
660+
llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
661+
662+
// Indices for `shuffleOdd` (result 1).
663+
auto seqOdd = llvm::seq<int64_t>(n / 2);
664+
auto indicesOdd =
665+
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
666+
667+
// Create two SPIR-V shuffles.
668+
auto shuffleEven = rewriter.create<spirv::VectorShuffleOp>(
669+
loc, newResultType, sourceVector, sourceVector,
670+
rewriter.getI32ArrayAttr(indicesEven));
671+
672+
auto shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
673+
loc, newResultType, sourceVector, sourceVector,
674+
rewriter.getI32ArrayAttr(indicesOdd));
675+
676+
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
677+
return success();
678+
}
679+
};
680+
621681
struct VectorLoadOpConverter final
622682
: public OpConversionPattern<vector::LoadOp> {
623683
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +922,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
862922
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
863923
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
864924
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
865-
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
866-
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
867-
PatternBenefit(1));
925+
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
926+
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
927+
typeConverter, patterns.getContext(), PatternBenefit(1));
868928

869929
// Make sure that the more specialized dot product pattern has higher benefit
870930
// than the generic one that extracts all elements.

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,32 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
507507

508508
// -----
509509

510+
// CHECK-LABEL: func @deinterleave
511+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
512+
// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
513+
// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
514+
// CHECK: return %[[SHUFFLE0]], %[[SHUFFLE1]]
515+
func.func @deinterleave(%a: vector<4xf32>) -> (vector<2xf32>, vector<2xf32>) {
516+
%0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
517+
return %0, %1 : vector<2xf32>, vector<2xf32>
518+
}
519+
520+
// -----
521+
522+
// CHECK-LABEL: func @deinterleave_scalar
523+
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
524+
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
525+
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
526+
// CHECK: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
527+
// CHECK: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
528+
// CHECK: return %[[CAST0]], %[[CAST1]]
529+
func.func @deinterleave_scalar(%a: vector<2xf32>) -> (vector<1xf32>, vector<1xf32>) {
530+
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
531+
return %0, %1 : vector<1xf32>, vector<1xf32>
532+
}
533+
534+
// -----
535+
510536
// CHECK-LABEL: func @reduction_add
511537
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
512538
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

0 commit comments

Comments
 (0)