Skip to content

Commit e7ad07f

Browse files
authored
[libclc] Move fma to the CLC library (#126052)
This builtin is a little more involved than others as targets deal with fma in various different ways. Fundamentally, the CLC __clc_fma builtin compiles to __builtin_elementwise_fma, which compiles to the @llvm.fma intrinsic. However, in the case of fp32 fma some targets call the __clc_sw_fma function, which provides a software implementation of the builtin. This in principle is controlled by the __CLC_HAVE_HW_FMA32 macro and may be a runtime decision, depending on how the target defines that macro. All targets build the CLC fma functions for all types. This is so that the CLC library can have a reliable internal implementation for its own purposes. For AMD/NVPTX targets there are no meaningful changes to the generated LLVM bytecode. Some blocks of code have moved around, which confounds llvm-diff. For the clspv and SPIR-V/Mesa targets, only fp32 fma is of interest. Its use in libclc is tightly controlled by checking __CLC_HAVE_HW_FMA32 first. This can either be a compile-time constant (1, for clspv) or a runtime function for SPIR-V/Mesa. The SPIR-V/Mesa target only provided fp32 fma in the OpenCL layer. It unconditionally mapped that to the __clc_sw_fma builtin, even though the generic version in theory had a runtime toggle through __CLC_HAVE_HW_FMA32 specifically for that target. Callers of fma, though, would end up using the ExtInst fma, *not* calling the _Z3fmafff function provided by libclc. This commit keeps this system in place in the OpenCL layer, by mapping fma to __clc_sw_fma. Where other builtins would previously call fma (i.e., result in the ExtInst), they now call __clc_fma. This function checks the __CLC_HAVE_HW_FMA32 runtime toggle, which selects between the slow version and the quick version. The quick version is the LLVM fma intrinsic which llvm-spirv translates to the ExtInst. The clspv target had its own software implementation of fp32 fma, which it called unconditionally - even though __CLC_HAVE_HW_FMA32 is 1 for that target. 
This is potentially just so its library ships a software version which it can fall back on. In the OpenCL layer, the target doesn't provide fp64 fma, and maps fp16 fma to fp32 mad. This commit keeps this system roughly in place: in the OpenCL layer it maps fp32 fma to __clc_sw_fma, and fp16 fma to mad. Where builtins would previously call into fma, they now call __clc_fma, which compiles to the LLVM intrinsic. If this goes through a translation to SPIR-V it will become the fma ExtInst, or the intrinsic could be replaced by the _Z3fmafff software implementation. The clspv and SPIR-V/Mesa targets could potentially be cleaned up later, depending on their needs.
1 parent 3532651 commit e7ad07f

29 files changed

+542
-447
lines changed

libclc/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS
2828
spirv/lib/SOURCES;
2929
# CLC internal libraries
3030
clc/lib/generic/SOURCES;
31+
clc/lib/clspv/SOURCES;
32+
clc/lib/spirv/SOURCES;
3133
)
3234

3335
set( LIBCLC_MIN_LLVM 3.9.0 )
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1-
#define __CLC_FUNCTION __clc_fma
2-
#define __CLC_INTRINSIC "llvm.fma"
3-
#include "math/ternary_intrin.inc"
1+
#ifndef __CLC_INTERNAL_MATH_CLC_SW_FMA_H__
2+
#define __CLC_INTERNAL_MATH_CLC_SW_FMA_H__
43

5-
#define __FLOAT_ONLY
64
#define __CLC_FUNCTION __clc_sw_fma
75
#define __CLC_BODY <clc/shared/ternary_decl.inc>
6+
87
#include <clc/math/gentype.inc>
8+
99
#undef __CLC_BODY
1010
#undef __CLC_FUNCTION
11-
#undef __FLOAT_ONLY
11+
12+
#endif // __CLC_INTERNAL_MATH_CLC_SW_FMA_H__

libclc/clc/include/clc/math/clc_fma.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Declaration header for the CLC-internal __clc_fma builtin.
// It names the function via __CLC_FUNCTION and selects the ternary
// declaration template as __CLC_BODY before including gentype.inc.
// NOTE(review): gentype.inc presumably expands __CLC_BODY once per supported
// floating-point gentype; its contents are not visible here - confirm.
#ifndef __CLC_MATH_CLC_FMA_H__
2+
#define __CLC_MATH_CLC_FMA_H__
3+
4+
#define __CLC_FUNCTION __clc_fma
5+
#define __CLC_BODY <clc/shared/ternary_decl.inc>
6+
7+
#include <clc/math/gentype.inc>
8+
9+
// Clean up the template parameters so later includes start from a clean slate.
#undef __CLC_BODY
10+
#undef __CLC_FUNCTION
11+
12+
#endif // __CLC_MATH_CLC_FMA_H__

libclc/clc/include/clc/math/math.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,12 @@
3939
#define PINF 0x200
4040

4141
#if (defined __AMDGCN__ || defined __R600__) && !defined __HAS_FMAF__
42-
#define HAVE_HW_FMA32() (0)
42+
#define __CLC_HAVE_HW_FMA32() (0)
4343
#elif defined(CLC_SPIRV)
4444
bool __attribute__((noinline)) __clc_runtime_has_hw_fma32(void);
45-
#define HAVE_HW_FMA32() __clc_runtime_has_hw_fma32()
45+
#define __CLC_HAVE_HW_FMA32() __clc_runtime_has_hw_fma32()
4646
#else
47-
#define HAVE_HW_FMA32() (1)
47+
#define __CLC_HAVE_HW_FMA32() (1)
4848
#endif
4949

5050
#define HAVE_BITALIGN() (0)

libclc/clc/lib/clspv/SOURCES

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
math/clc_sw_fma.cl
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
/*
2+
* Copyright (c) 2014 Advanced Micro Devices, Inc.
3+
*
4+
* Permission is hereby granted, free of charge, to any person obtaining a copy
5+
* of this software and associated documentation files (the "Software"), to deal
6+
* in the Software without restriction, including without limitation the rights
7+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8+
* copies of the Software, and to permit persons to whom the Software is
9+
* furnished to do so, subject to the following conditions:
10+
*
11+
* The above copyright notice and this permission notice shall be included in
12+
* all copies or substantial portions of the Software.
13+
*
14+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20+
* THE SOFTWARE.
21+
*/
22+
23+
// This version is derived from the generic fma software implementation
24+
// (__clc_sw_fma), but avoids the use of ulong in favor of uint2. The logic has
25+
// been updated as appropriate.
26+
27+
#include <clc/clc_as_type.h>
28+
#include <clc/clcmacro.h>
29+
#include <clc/float/definitions.h>
30+
#include <clc/integer/clc_abs.h>
31+
#include <clc/integer/clc_clz.h>
32+
#include <clc/integer/clc_hadd.h>
33+
#include <clc/integer/clc_mul_hi.h>
34+
#include <clc/integer/definitions.h>
35+
#include <clc/math/clc_mad.h>
36+
#include <clc/math/math.h>
37+
#include <clc/relational/clc_isinf.h>
38+
#include <clc/relational/clc_isnan.h>
39+
#include <clc/shared/clc_max.h>
40+
41+
// Unpacked floating-point value used by the software fma in this file.
struct fp {
42+
// Significand stored as a uint2; the u2_* helpers treat .hi/.lo as the
// upper/lower 32-bit halves of one 64-bit unsigned quantity.
uint2 mantissa;
43+
// Unbiased exponent (the biased field minus 127 where it is filled in).
int exponent;
44+
// Sign kept in place as bit 31 (value & 0x80000000), all other bits zero.
uint sign;
45+
};
46+
47+
// Builds a uint2 from an explicit high word `hi` and low word `lo`.
static uint2 u2_set(uint hi, uint lo) {
48+
uint2 res;
49+
res.lo = lo;
50+
res.hi = hi;
51+
return res;
52+
}
53+
54+
// Zero-extends a 32-bit value into a uint2 (high word 0, low word `val`).
static uint2 u2_set_u(uint val) { return u2_set(0, val); }
55+
56+
// Multiplies two 32-bit values into a uint2 result.
// The low word is the wrapping 32-bit product; the high word comes from
// __clc_mul_hi (presumably the upper 32 bits of the full product - confirm).
static uint2 u2_mul(uint a, uint b) {
57+
uint2 res;
58+
res.hi = __clc_mul_hi(a, b);
59+
res.lo = a * b;
60+
return res;
61+
}
62+
63+
// Logical left shift of the hi/lo pair `val` by `shift` bits.
// NOTE(review): shift amounts of 64 or more are not special-cased here; the
// visible caller checks for exp_diff >= 64 before using such results.
static uint2 u2_sll(uint2 val, uint shift) {
64+
// Early out also keeps `32 - shift` below from becoming a shift by 32.
if (shift == 0)
65+
return val;
66+
if (shift < 32) {
67+
val.hi <<= shift;
68+
// Carry the bits shifted out of the low word into the high word.
val.hi |= val.lo >> (32 - shift);
69+
val.lo <<= shift;
70+
} else {
71+
// Whole low word moves into the high word; the low word becomes zero.
val.hi = val.lo << (shift - 32);
72+
val.lo = 0;
73+
}
74+
return val;
75+
}
76+
77+
// Logical right shift of the hi/lo pair `val` by `shift` bits.
// NOTE(review): shift amounts of 64 or more are not special-cased here; the
// visible caller checks for exp_diff >= 64 before calling.
static uint2 u2_srl(uint2 val, uint shift) {
78+
// Early out also keeps `32 - shift` below from becoming a shift by 32.
if (shift == 0)
79+
return val;
80+
if (shift < 32) {
81+
val.lo >>= shift;
82+
// Bring the bits shifted out of the high word into the low word.
val.lo |= val.hi << (32 - shift);
83+
val.hi >>= shift;
84+
} else {
85+
// Whole high word moves into the low word; the high word becomes zero.
val.lo = val.hi >> (shift - 32);
86+
val.hi = 0;
87+
}
88+
return val;
89+
}
90+
91+
// ORs the 32-bit value `b` into the low word of `a`; the high word is untouched.
static uint2 u2_or(uint2 a, uint b) {
92+
a.lo |= b;
93+
return a;
94+
}
95+
96+
// Bitwise AND of two hi/lo pairs, applied word by word.
static uint2 u2_and(uint2 a, uint2 b) {
97+
a.lo &= b.lo;
98+
a.hi &= b.hi;
99+
return a;
100+
}
101+
102+
// Adds two hi/lo pairs, propagating a carry from the low word to the high
// word. Overflow out of the high word is discarded (wraps).
static uint2 u2_add(uint2 a, uint2 b) {
103+
// Presumably __clc_hadd yields (a.lo + b.lo) >> 1 without intermediate
// overflow, so its top bit is the carry out of the low-word addition
// - TODO confirm against the __clc_hadd definition.
uint carry = (__clc_hadd(a.lo, b.lo) >> 31) & 0x1;
104+
a.lo += b.lo;
105+
a.hi += b.hi + carry;
106+
return a;
107+
}
108+
109+
// Adds the zero-extended 32-bit value `b` to the hi/lo pair `a`.
static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); }
110+
111+
// Two's-complement negation of the hi/lo pair: complement both words, add 1.
static uint2 u2_inv(uint2 a) {
112+
a.lo = ~a.lo;
113+
a.hi = ~a.hi;
114+
return u2_add_u(a, 1);
115+
}
116+
117+
// Counts leading zero bits across the hi/lo pair (high word first).
static uint u2_clz(uint2 a) {
118+
uint leading_zeroes = __clc_clz(a.hi);
119+
// A count of 32 means the high word is all zeros (this relies on __clc_clz
// returning 32 for a zero input), so continue counting in the low word.
if (leading_zeroes == 32) {
120+
leading_zeroes += __clc_clz(a.lo);
121+
}
122+
return leading_zeroes;
123+
}
124+
125+
// True when both words of `a` and `b` are equal.
static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; }
126+
127+
// True when both words of `a` are zero.
static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); }
128+
129+
// Unsigned "a > b" on hi/lo pairs: compare high words, then low words on a tie.
static bool u2_gt(uint2 a, uint2 b) {
130+
return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo);
131+
}
132+
133+
// Software implementation of single-precision fma(a, b, c) = a * b + c built
// on integer arithmetic over uint2 (hi/lo) pairs instead of ulong.
//
// Outline of what the body does:
//  - NaN in any operand, or Inf in a or b, is delegated to __clc_mad.
//  - If c is Inf (and a, b are finite), c is returned.
//  - A zero a or b returns c; a zero c returns a * b.
//  - Otherwise the operands are unpacked into struct fp, the product and the
//    addend are aligned, added or subtracted, rounded to nearest even and
//    repacked.
//  - Results with exponent > 127 become signed infinity; results with
//    exponent <= -127 are flushed to signed zero.
_CLC_DEF _CLC_OVERLOAD float __clc_sw_fma(float a, float b, float c) {
134+
/* special cases */
135+
if (__clc_isnan(a) || __clc_isnan(b) || __clc_isnan(c) || __clc_isinf(a) ||
136+
__clc_isinf(b)) {
137+
return __clc_mad(a, b, c);
138+
}
139+
140+
/* If only c is inf, and both a,b are regular numbers, the result is c*/
141+
if (__clc_isinf(c)) {
142+
return c;
143+
}
144+
145+
a = __clc_flush_denormal_if_not_supported(a);
146+
b = __clc_flush_denormal_if_not_supported(b);
147+
c = __clc_flush_denormal_if_not_supported(c);
148+
149+
if (a == 0.0f || b == 0.0f) {
150+
return c;
151+
}
152+
153+
if (c == 0) {
154+
return a * b;
155+
}
156+
157+
struct fp st_a, st_b, st_c;
158+
159+
// Unpack: exponent is the biased field minus 127, mantissa is the 23-bit
// fraction with the implicit leading one (0x800000) made explicit.
st_a.exponent = a == .0f ? 0 : ((__clc_as_uint(a) & 0x7f800000) >> 23) - 127;
160+
st_b.exponent = b == .0f ? 0 : ((__clc_as_uint(b) & 0x7f800000) >> 23) - 127;
161+
st_c.exponent = c == .0f ? 0 : ((__clc_as_uint(c) & 0x7f800000) >> 23) - 127;
162+
163+
st_a.mantissa =
164+
u2_set_u(a == .0f ? 0 : (__clc_as_uint(a) & 0x7fffff) | 0x800000);
165+
st_b.mantissa =
166+
u2_set_u(b == .0f ? 0 : (__clc_as_uint(b) & 0x7fffff) | 0x800000);
167+
st_c.mantissa =
168+
u2_set_u(c == .0f ? 0 : (__clc_as_uint(c) & 0x7fffff) | 0x800000);
169+
170+
st_a.sign = __clc_as_uint(a) & 0x80000000;
171+
st_b.sign = __clc_as_uint(b) & 0x80000000;
172+
st_c.sign = __clc_as_uint(c) & 0x80000000;
173+
174+
// Multiplication.
175+
// Move the product to the highest bits to maximize precision
176+
// mantissa is 24 bits => product is 48 bits, 2bits non-fraction.
177+
// Add one bit for future addition overflow,
178+
// add another bit to detect subtraction underflow
179+
struct fp st_mul;
180+
st_mul.sign = st_a.sign ^ st_b.sign;
181+
st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
182+
st_mul.exponent =
183+
!u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;
184+
185+
// FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
186+
if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
187+
return c;
188+
189+
// Mantissa is 23 fractional bits, shift it the same way as product mantissa
// NOTE(review): the `ul` suffix makes this a ulong literal even though this
// file otherwise avoids ulong; it is converted to uint at each u2_* call.
190+
#define C_ADJUST 37ul
191+
192+
// both exponents are bias adjusted
193+
int exp_diff = st_mul.exponent - st_c.exponent;
194+
195+
st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
196+
uint2 cutoff_bits = u2_set_u(0);
197+
// (1 << |exp_diff|) - 1, formed by adding the all-ones pair (i.e. -1).
uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), __clc_abs(exp_diff)),
198+
u2_set(0xffffffff, 0xffffffff));
199+
// Align the operand with the smaller exponent, remembering the bits that are
// shifted out in cutoff_bits for the later borrow and rounding decisions.
if (exp_diff > 0) {
200+
cutoff_bits =
201+
exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
202+
st_c.mantissa =
203+
exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
204+
} else {
205+
cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
206+
: u2_and(st_mul.mantissa, cutoff_mask);
207+
st_mul.mantissa =
208+
-exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
209+
}
210+
211+
struct fp st_fma;
212+
st_fma.sign = st_mul.sign;
213+
st_fma.exponent = __clc_max(st_mul.exponent, st_c.exponent);
214+
if (st_c.sign == st_mul.sign) {
215+
st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
216+
} else {
217+
// cutoff bits borrow one
218+
st_fma.mantissa =
219+
u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
220+
(!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
221+
? u2_set(0xffffffff, 0xffffffff)
222+
: u2_set_u(0)));
223+
}
224+
225+
// underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
226+
if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
227+
st_fma.mantissa = u2_inv(st_fma.mantissa);
228+
st_fma.sign = st_mul.sign ^ 0x80000000;
229+
}
230+
231+
// detect overflow/underflow
232+
int overflow_bits = 3 - u2_clz(st_fma.mantissa);
233+
234+
// adjust exponent
235+
st_fma.exponent += overflow_bits;
236+
237+
// handle underflow
238+
if (overflow_bits < 0) {
239+
st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
240+
overflow_bits = 0;
241+
}
242+
243+
// rounding
244+
uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
245+
u2_set(0xffffffff, 0xffffffff));
246+
// Bits below the result LSB, with a sticky bit ORed in when anything was cut
// off during alignment.
uint2 trunc_bits =
247+
u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
248+
uint2 last_bit =
249+
u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
250+
// Exactly half of the result LSB: the tie point for rounding.
uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);
251+
252+
// round to nearest even
253+
if (u2_gt(trunc_bits, grs_bits) ||
254+
(u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
255+
st_fma.mantissa =
256+
u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
257+
}
258+
259+
// Shift mantissa back to bit 23
260+
st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);
261+
262+
// Detect rounding overflow
263+
if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
264+
++st_fma.exponent;
265+
st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
266+
}
267+
268+
if (u2_zero(st_fma.mantissa)) {
269+
return 0.0f;
270+
}
271+
272+
// Floating point range limit
273+
if (st_fma.exponent > 127) {
274+
return __clc_as_float(__clc_as_uint(INFINITY) | st_fma.sign);
275+
}
276+
277+
// Flush denormals
278+
if (st_fma.exponent <= -127) {
279+
return __clc_as_float(st_fma.sign);
280+
}
281+
282+
// Repack: sign bit, re-biased exponent, and the 23 fraction bits (the
// explicit leading one is masked off).
return __clc_as_float(st_fma.sign | ((st_fma.exponent + 127) << 23) |
283+
((uint)st_fma.mantissa.lo & 0x7fffff));
284+
}
285+
286+
// Presumably expands to the float vector overloads of __clc_sw_fma built on
// the scalar definition in this file. The macro is not defined in this view
// (likely in clc/clcmacro.h) - confirm.
_CLC_TERNARY_VECTORIZE(_CLC_DEF _CLC_OVERLOAD, float, __clc_sw_fma, float,
287+
float, float)

libclc/clc/lib/generic/SOURCES

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ integer/clc_upsample.cl
2020
math/clc_ceil.cl
2121
math/clc_copysign.cl
2222
math/clc_fabs.cl
23+
math/clc_fma.cl
2324
math/clc_floor.cl
2425
math/clc_frexp.cl
2526
math/clc_mad.cl
2627
math/clc_modf.cl
2728
math/clc_nextafter.cl
2829
math/clc_rint.cl
30+
math/clc_sw_fma.cl
2931
math/clc_trunc.cl
3032
relational/clc_all.cl
3133
relational/clc_any.cl
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#include <clc/internal/clc.h>
2+
#include <clc/internal/math/clc_sw_fma.h>
3+
#include <clc/math/math.h>
4+
5+
// Instantiate the __clc_fma definition from clc_fma.inc through gentype.inc.
// NOTE(review): presumably one expansion per supported floating-point
// gentype; gentype.inc is not visible here - confirm.
#define __CLC_BODY <clc_fma.inc>
6+
#include <clc/math/gentype.inc>
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Generic __clc_fma body, written against __CLC_GENTYPE.
// For 32-bit floating-point types it calls the software implementation
// __clc_sw_fma when __CLC_HAVE_HW_FMA32() evaluates to false; in every other
// case it returns __builtin_elementwise_fma(a, b, c).
_CLC_DEF _CLC_OVERLOAD __CLC_GENTYPE __clc_fma(__CLC_GENTYPE a, __CLC_GENTYPE b,
2+
__CLC_GENTYPE c) {
3+
#if __CLC_FPSIZE == 32
4+
// __CLC_HAVE_HW_FMA32 is written as a call, so it need not be a compile-time
// constant.
if (!__CLC_HAVE_HW_FMA32())
5+
return __clc_sw_fma(a, b, c);
6+
#endif
7+
return __builtin_elementwise_fma(a, b, c);
8+
}

0 commit comments

Comments
 (0)