Skip to content

Commit d5148f0

Browse files
[X86] Fix arithmetic error in extractVector (#128052)
The computation of the element count for the result VT in extractVector is incorrect when vector width does not divide VT.getSizeInBits(), which can occur when the source vector element count is not a power of two, e.g. extracting a vectorWidth 256b vector from a 384b source. This rewrites the expression so the division is exact given that vectorWidth is a multiple of the source element size.
1 parent c83bdc7 commit d5148f0

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4076,9 +4076,12 @@ static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
40764076
const SDLoc &dl, unsigned vectorWidth) {
40774077
EVT VT = Vec.getValueType();
40784078
EVT ElVT = VT.getVectorElementType();
4079-
unsigned Factor = VT.getSizeInBits() / vectorWidth;
4080-
EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
4081-
VT.getVectorNumElements() / Factor);
4079+
unsigned ResultNumElts =
4080+
(VT.getVectorNumElements() * vectorWidth) / VT.getSizeInBits();
4081+
EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT, ResultNumElts);
4082+
4083+
assert(ResultVT.getSizeInBits() == vectorWidth &&
4084+
"Illegal subvector extraction");
40824085

40834086
// Extract the relevant vectorWidth bits. Generate an EXTRACT_SUBVECTOR
40844087
unsigned ElemsPerChunk = vectorWidth / ElVT.getSizeInBits();
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; Ensure assertion is not hit when folding concat of two contiguous extract_subvector operations
3+
; from a source with a non-power-of-two vector length.
4+
; RUN: llc -mtriple=x86_64 -mattr=+avx2 < %s | FileCheck %s
5+
6+
define void @foo(ptr %pDst) {
7+
; CHECK-LABEL: foo:
8+
; CHECK: # %bb.0: # %entry
9+
; CHECK-NEXT: vxorps %xmm0, %xmm0, %xmm0
10+
; CHECK-NEXT: vmovups %ymm0, 16(%rdi)
11+
; CHECK-NEXT: vzeroupper
12+
; CHECK-NEXT: retq
13+
entry:
14+
%0 = shufflevector <12 x float> zeroinitializer, <12 x float> zeroinitializer, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
15+
%1 = shufflevector <12 x float> zeroinitializer, <12 x float> zeroinitializer, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
16+
%2 = getelementptr i8, ptr %pDst, i64 16
17+
%3 = getelementptr i8, ptr %pDst, i64 32
18+
store <4 x float> %0, ptr %2, align 1
19+
store <4 x float> %1, ptr %3, align 1
20+
ret void
21+
}

0 commit comments

Comments
 (0)