Skip to content

Commit f8b9035

Browse files
committed
[X86] Support amx-int8 intrinsic.
Adding support for intrinsics of TDPBSUD/TDPBUSD/TDPBUUD. Differential Revision: https://reviews.llvm.org/D97259
1 parent dc6a84f commit f8b9035

File tree

11 files changed

+145
-11
lines changed

11 files changed

+145
-11
lines changed

clang/include/clang/Basic/BuiltinsX86_64.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ TARGET_BUILTIN(__builtin_ia32_senduipi, "vUWi", "n", "uintr")
103103
// AMX internal builtin
104104
TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
105105
TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
106+
TARGET_BUILTIN(__builtin_ia32_tdpbsud_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
107+
TARGET_BUILTIN(__builtin_ia32_tdpbusd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
108+
TARGET_BUILTIN(__builtin_ia32_tdpbuud_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
106109
TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
107110
TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
108111
// AMX

clang/lib/Headers/amxintrin.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,24 @@ _tile_dpbssd_internal(unsigned short m, unsigned short n, unsigned short k,
238238
return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
239239
}
240240

241+
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
242+
_tile_dpbsud_internal(unsigned short m, unsigned short n, unsigned short k,
243+
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
244+
return __builtin_ia32_tdpbsud_internal(m, n, k, dst, src1, src2);
245+
}
246+
247+
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
248+
_tile_dpbusd_internal(unsigned short m, unsigned short n, unsigned short k,
249+
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
250+
return __builtin_ia32_tdpbusd_internal(m, n, k, dst, src1, src2);
251+
}
252+
253+
static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
254+
_tile_dpbuud_internal(unsigned short m, unsigned short n, unsigned short k,
255+
_tile1024i dst, _tile1024i src1, _tile1024i src2) {
256+
return __builtin_ia32_tdpbuud_internal(m, n, k, dst, src1, src2);
257+
}
258+
241259
static __inline__ void __DEFAULT_FN_ATTRS_INT8
242260
_tile_stored_internal(unsigned short m, unsigned short n, void *base,
243261
__SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +282,27 @@ static void __tile_dpbssd(__tile1024i *dst, __tile1024i src1,
264282
src1.tile, src2.tile);
265283
}
266284

285+
__DEFAULT_FN_ATTRS_INT8
286+
static void __tile_dpbsud(__tile1024i *dst, __tile1024i src1,
287+
__tile1024i src2) {
288+
dst->tile = _tile_dpbsud_internal(src1.row, src2.col, src1.col, dst->tile,
289+
src1.tile, src2.tile);
290+
}
291+
292+
__DEFAULT_FN_ATTRS_INT8
293+
static void __tile_dpbusd(__tile1024i *dst, __tile1024i src1,
294+
__tile1024i src2) {
295+
dst->tile = _tile_dpbusd_internal(src1.row, src2.col, src1.col, dst->tile,
296+
src1.tile, src2.tile);
297+
}
298+
299+
__DEFAULT_FN_ATTRS_INT8
300+
static void __tile_dpbuud(__tile1024i *dst, __tile1024i src1,
301+
__tile1024i src2) {
302+
dst->tile = _tile_dpbuud_internal(src1.row, src2.col, src1.col, dst->tile,
303+
src1.tile, src2.tile);
304+
}
305+
267306
__DEFAULT_FN_ATTRS_TILE
268307
static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
269308
_tile_stored_internal(src.row, src.col, base, stride, src.tile);

clang/test/CodeGen/X86/amx_api.c

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,27 @@ void test_tile_dpbssd(__tile1024i a, __tile1024i b, __tile1024i c) {
4646
__tile_dpbssd(&c, a, b);
4747
}
4848

49+
void test_tile_dpbsud(__tile1024i a, __tile1024i b, __tile1024i c) {
50+
//CHECK-LABEL: @test_tile_dpbsud
51+
//CHECK: call x86_amx @llvm.x86.tdpbsud.internal
52+
//CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
53+
__tile_dpbsud(&c, a, b);
54+
}
55+
56+
void test_tile_dpbusd(__tile1024i a, __tile1024i b, __tile1024i c) {
57+
//CHECK-LABEL: @test_tile_dpbusd
58+
//CHECK: call x86_amx @llvm.x86.tdpbusd.internal
59+
//CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
60+
__tile_dpbusd(&c, a, b);
61+
}
62+
63+
void test_tile_dpbuud(__tile1024i a, __tile1024i b, __tile1024i c) {
64+
//CHECK-LABEL: @test_tile_dpbuud
65+
//CHECK: call x86_amx @llvm.x86.tdpbuud.internal
66+
//CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32>
67+
__tile_dpbuud(&c, a, b);
68+
}
69+
4970
void test_tile_stored(__tile1024i c) {
5071
//CHECK-LABEL: @test_tile_stored
5172
//CHECK: {{%.*}} = bitcast <256 x i32> {{%.*}} to x86_amx

llvm/include/llvm/IR/IntrinsicsX86.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5053,6 +5053,24 @@ let TargetPrefix = "x86" in {
50535053
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
50545054
llvm_x86amx_ty, llvm_x86amx_ty,
50555055
llvm_x86amx_ty], []>;
5056+
def int_x86_tdpbsud_internal :
5057+
GCCBuiltin<"__builtin_ia32_tdpbsud_internal">,
5058+
Intrinsic<[llvm_x86amx_ty],
5059+
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
5060+
llvm_x86amx_ty, llvm_x86amx_ty,
5061+
llvm_x86amx_ty], []>;
5062+
def int_x86_tdpbusd_internal :
5063+
GCCBuiltin<"__builtin_ia32_tdpbusd_internal">,
5064+
Intrinsic<[llvm_x86amx_ty],
5065+
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
5066+
llvm_x86amx_ty, llvm_x86amx_ty,
5067+
llvm_x86amx_ty], []>;
5068+
def int_x86_tdpbuud_internal :
5069+
GCCBuiltin<"__builtin_ia32_tdpbuud_internal">,
5070+
Intrinsic<[llvm_x86amx_ty],
5071+
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
5072+
llvm_x86amx_ty, llvm_x86amx_ty,
5073+
llvm_x86amx_ty], []>;
50565074
def int_x86_tilestored64_internal :
50575075
GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
50585076
Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,

llvm/lib/Target/X86/X86ExpandPseudo.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,22 @@ bool X86ExpandPseudo::ExpandMI(MachineBasicBlock &MBB,
467467
MI.setDesc(TII->get(X86::TILELOADD));
468468
return true;
469469
}
470-
case X86::PTDPBSSDV: {
470+
case X86::PTDPBSSDV:
471+
case X86::PTDPBSUDV:
472+
case X86::PTDPBUSDV:
473+
case X86::PTDPBUUDV: {
471474
MI.untieRegOperand(4);
472475
for (unsigned i = 3; i > 0; --i)
473476
MI.RemoveOperand(i);
474-
MI.setDesc(TII->get(X86::TDPBSSD));
477+
unsigned Opc;
478+
switch (Opcode) {
479+
case X86::PTDPBSSDV: Opc = X86::TDPBSSD; break;
480+
case X86::PTDPBSUDV: Opc = X86::TDPBSUD; break;
481+
case X86::PTDPBUSDV: Opc = X86::TDPBUSD; break;
482+
case X86::PTDPBUUDV: Opc = X86::TDPBUUD; break;
483+
default: llvm_unreachable("Impossible Opcode!");
484+
}
485+
MI.setDesc(TII->get(Opc));
475486
MI.tieOperands(0, 1);
476487
return true;
477488
}

llvm/lib/Target/X86/X86ISelDAGToDAG.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4621,11 +4621,22 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
46214621
ReplaceNode(Node, CNode);
46224622
return;
46234623
}
4624-
case Intrinsic::x86_tdpbssd_internal: {
4624+
4625+
case Intrinsic::x86_tdpbssd_internal:
4626+
case Intrinsic::x86_tdpbsud_internal:
4627+
case Intrinsic::x86_tdpbusd_internal:
4628+
case Intrinsic::x86_tdpbuud_internal: {
46254629
if (!Subtarget->hasAMXTILE())
46264630
break;
46274631
SDValue Chain = Node->getOperand(0);
4628-
unsigned Opc = X86::PTDPBSSDV;
4632+
unsigned Opc;
4633+
switch (IntNo) {
4634+
case Intrinsic::x86_tdpbssd_internal: Opc = X86::PTDPBSSDV; break;
4635+
case Intrinsic::x86_tdpbsud_internal: Opc = X86::PTDPBSUDV; break;
4636+
case Intrinsic::x86_tdpbusd_internal: Opc = X86::PTDPBUSDV; break;
4637+
case Intrinsic::x86_tdpbuud_internal: Opc = X86::PTDPBUUDV; break;
4638+
default: llvm_unreachable("Impossible intrinsic");
4639+
}
46294640
SDValue Ops[] = {Node->getOperand(2),
46304641
Node->getOperand(3),
46314642
Node->getOperand(4),

llvm/lib/Target/X86/X86InstrAMX.td

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,20 @@ let Predicates = [HasAMXINT8, In64BitMode] in {
9292
}
9393

9494
// Pseudo instruction for RA.
95-
let Constraints = "$src4 = $dst" in
96-
def PTDPBSSDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
97-
GR16:$src2, GR16:$src3, TILE:$src4,
98-
TILE:$src5, TILE:$src6), []>;
95+
let Constraints = "$src4 = $dst" in {
96+
def PTDPBSSDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
97+
GR16:$src2, GR16:$src3, TILE:$src4,
98+
TILE:$src5, TILE:$src6), []>;
99+
def PTDPBSUDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
100+
GR16:$src2, GR16:$src3, TILE:$src4,
101+
TILE:$src5, TILE:$src6), []>;
102+
def PTDPBUSDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
103+
GR16:$src2, GR16:$src3, TILE:$src4,
104+
TILE:$src5, TILE:$src6), []>;
105+
def PTDPBUUDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
106+
GR16:$src2, GR16:$src3, TILE:$src4,
107+
TILE:$src5, TILE:$src6), []>;
108+
}
99109

100110
let usesCustomInserter = 1 in {
101111
// Pseudo instructions, using immediates instead of tile registers.

llvm/lib/Target/X86/X86LowerAMXType.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
6767
}
6868
// a * b + c
6969
// The shape depends on which operand.
70-
case Intrinsic::x86_tdpbssd_internal: {
70+
case Intrinsic::x86_tdpbssd_internal:
71+
case Intrinsic::x86_tdpbsud_internal:
72+
case Intrinsic::x86_tdpbusd_internal:
73+
case Intrinsic::x86_tdpbuud_internal: {
7174
switch (OpNo) {
7275
case 3:
7376
Row = II->getArgOperand(0);

llvm/lib/Target/X86/X86PreTileConfig.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
155155
llvm_unreachable("Unexpected machine instruction on tile");
156156
case X86::PTILELOADDV:
157157
case X86::PTDPBSSDV:
158+
case X86::PTDPBSUDV:
159+
case X86::PTDPBUSDV:
160+
case X86::PTDPBUUDV:
158161
case X86::PTILEZEROV:
159162
MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
160163
MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -249,6 +252,9 @@ static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
249252
case X86::PTILELOADDV:
250253
case X86::PTILESTOREDV:
251254
case X86::PTDPBSSDV:
255+
case X86::PTDPBSUDV:
256+
case X86::PTDPBUSDV:
257+
case X86::PTDPBUUDV:
252258
case X86::PTILEZEROV:
253259
return true;
254260
}

llvm/lib/Target/X86/X86RegisterInfo.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,9 @@ static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
884884
// We only collect the tile shape that is defined.
885885
case X86::PTILELOADDV:
886886
case X86::PTDPBSSDV:
887+
case X86::PTDPBSUDV:
888+
case X86::PTDPBUSDV:
889+
case X86::PTDPBUUDV:
887890
case X86::PTILEZEROV:
888891
MachineOperand &MO1 = MI->getOperand(1);
889892
MachineOperand &MO2 = MI->getOperand(2);

llvm/test/CodeGen/X86/AMX/amx-tile-basic.ll

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,29 @@ define void @test_amx(i8* %pointer, i8* %base, i64 %stride) {
1919
; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm1
2020
; CHECK-NEXT: tileloadd (%rsi,%rdx), %tmm2
2121
; CHECK-NEXT: tdpbssd %tmm2, %tmm1, %tmm0
22+
; CHECK-NEXT: tdpbsud %tmm2, %tmm1, %tmm0
23+
; CHECK-NEXT: tdpbusd %tmm2, %tmm1, %tmm0
24+
; CHECK-NEXT: tdpbuud %tmm2, %tmm1, %tmm0
2225
; CHECK-NEXT: tilestored %tmm0, (%rdi,%rdx)
2326
; CHECK-NEXT: tilerelease
2427
; CHECK-NEXT: vzeroupper
2528
; CHECK-NEXT: retq
2629
%c = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 8)
2730
%a = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* %base, i64 %stride)
2831
%b = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 8, i8* %base, i64 %stride)
29-
%d = call x86_amx @llvm.x86.tdpbssd.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b)
30-
call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d)
32+
%d0 = call x86_amx @llvm.x86.tdpbssd.internal(i16 8, i16 8, i16 8, x86_amx %c, x86_amx %a, x86_amx %b)
33+
%d1 = call x86_amx @llvm.x86.tdpbsud.internal(i16 8, i16 8, i16 8, x86_amx %d0, x86_amx %a, x86_amx %b)
34+
%d2 = call x86_amx @llvm.x86.tdpbusd.internal(i16 8, i16 8, i16 8, x86_amx %d1, x86_amx %a, x86_amx %b)
35+
%d3 = call x86_amx @llvm.x86.tdpbuud.internal(i16 8, i16 8, i16 8, x86_amx %d2, x86_amx %a, x86_amx %b)
36+
call void @llvm.x86.tilestored64.internal(i16 8, i16 8, i8* %pointer, i64 %stride, x86_amx %d3)
3137

3238
ret void
3339
}
3440

3541
declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
3642
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
3743
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
44+
declare x86_amx @llvm.x86.tdpbsud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
45+
declare x86_amx @llvm.x86.tdpbusd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
46+
declare x86_amx @llvm.x86.tdpbuud.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
3847
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)

0 commit comments

Comments
 (0)