Skip to content

Commit d699d9d

Browse files
authored
[flang][runtime] Support SUM/PRODUCT/DOT_PRODUCT reductions for REAL(16). (llvm#83169)
The reduction implementations rely on trivial operations that are supported by the build compiler's runtime, so they can be enabled whenever the build compiler provides 128-bit float support. std::conj used by DOT_PRODUCT is a template implementation in most environments, so it should not introduce a dependency on any 128-bit float support library. I am not going to test it in all the build environments before merging. If it fails for someone, I will deal with it.
1 parent 04e8653 commit d699d9d

File tree

8 files changed

+81
-31
lines changed

8 files changed

+81
-31
lines changed

flang/include/flang/Common/float128.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,20 @@
4949
#endif /* (defined(__FLOAT128__) || defined(__SIZEOF_FLOAT128__)) && \
5050
!defined(_LIBCPP_VERSION) && !defined(__CUDA_ARCH__) */
5151

52+
/* Define pure C CFloat128Type and CFloat128ComplexType. */
53+
#if LDBL_MANT_DIG == 113
54+
typedef long double CFloat128Type;
55+
typedef long double _Complex CFloat128ComplexType;
56+
#elif HAS_FLOAT128
57+
typedef __float128 CFloat128Type;
58+
/*
59+
* Use mode() attribute supported by GCC and Clang.
60+
* Adjust it for other compilers as needed.
61+
*/
62+
#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
63+
typedef _Complex float __attribute__((mode(TC))) CFloat128ComplexType;
64+
#else
65+
typedef _Complex float __attribute__((mode(KC))) CFloat128ComplexType;
66+
#endif
67+
#endif
5268
#endif /* FORTRAN_COMMON_FLOAT128_H_ */

flang/include/flang/Runtime/reduction.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,11 @@ void RTDECL(CppSumComplex8)(std::complex<double> &, const Descriptor &,
9292
void RTDECL(CppSumComplex10)(std::complex<long double> &, const Descriptor &,
9393
const char *source, int line, int dim = 0,
9494
const Descriptor *mask = nullptr);
95-
void RTDECL(CppSumComplex16)(std::complex<long double> &, const Descriptor &,
96-
const char *source, int line, int dim = 0,
95+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
96+
void RTDECL(CppSumComplex16)(std::complex<CppFloat128Type> &,
97+
const Descriptor &, const char *source, int line, int dim = 0,
9798
const Descriptor *mask = nullptr);
99+
#endif
98100

99101
void RTDECL(SumDim)(Descriptor &result, const Descriptor &array, int dim,
100102
const char *source, int line, const Descriptor *mask = nullptr);
@@ -145,12 +147,16 @@ void RTDECL(CppProductComplex4)(std::complex<float> &, const Descriptor &,
145147
void RTDECL(CppProductComplex8)(std::complex<double> &, const Descriptor &,
146148
const char *source, int line, int dim = 0,
147149
const Descriptor *mask = nullptr);
150+
#if LDBL_MANT_DIG == 64
148151
void RTDECL(CppProductComplex10)(std::complex<long double> &,
149152
const Descriptor &, const char *source, int line, int dim = 0,
150153
const Descriptor *mask = nullptr);
151-
void RTDECL(CppProductComplex16)(std::complex<long double> &,
154+
#endif
155+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
156+
void RTDECL(CppProductComplex16)(std::complex<CppFloat128Type> &,
152157
const Descriptor &, const char *source, int line, int dim = 0,
153158
const Descriptor *mask = nullptr);
159+
#endif
154160

155161
void RTDECL(ProductDim)(Descriptor &result, const Descriptor &array, int dim,
156162
const char *source, int line, const Descriptor *mask = nullptr);

flang/runtime/Float128Math/cabs.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@ namespace Fortran::runtime {
1212
extern "C" {
1313

1414
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
15-
// FIXME: the argument should be CppTypeFor<TypeCategory::Complex, 16>,
16-
// and it should be translated into the underlying library's
17-
// corresponding complex128 type.
18-
CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(ComplexF128 x) {
15+
// NOTE: Flang calls the runtime APIs using C _Complex ABI
16+
CppTypeFor<TypeCategory::Real, 16> RTDEF(CAbsF128)(CFloat128ComplexType x) {
1917
return CAbs<RTNAME(CAbsF128)>::invoke(x);
2018
}
2119
#endif

flang/runtime/Float128Math/math-entries.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,6 @@ DEFINE_FALLBACK(Y0)
9191
DEFINE_FALLBACK(Y1)
9292
DEFINE_FALLBACK(Yn)
9393

94-
// Define ComplexF128 type that is compatible with
95-
// the type of results/arguments of libquadmath.
96-
// TODO: this may need more work for other libraries/compilers.
97-
#if !defined(_ARCH_PPC) || defined(__LONG_DOUBLE_IEEE128__)
98-
typedef _Complex float __attribute__((mode(TC))) ComplexF128;
99-
#else
100-
typedef _Complex float __attribute__((mode(KC))) ComplexF128;
101-
#endif
102-
10394
#if HAS_LIBM
10495
// Define wrapper callers for libm.
10596
#include <ccomplex>

flang/runtime/complex-reduction.c

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ struct CppComplexDouble {
1919
struct CppComplexLongDouble {
2020
long double r, i;
2121
};
22+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
23+
struct CppComplexFloat128 {
24+
CFloat128Type r, i;
25+
};
26+
#endif
2227

2328
/* Not all environments define CMPLXF, CMPLX, CMPLXL. */
2429

@@ -70,6 +75,27 @@ static long_double_Complex_t CMPLXL(long double r, long double i) {
7075
#endif
7176
#endif
7277

78+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
79+
/*
80+
* GCC 7.4.0 (currently minimum GCC version for llvm builds)
81+
* supports __builtin_complex. For Clang, require >=12.0.
82+
* Otherwise, rely on the memory layout compatibility.
83+
*/
84+
#if (defined(__clang_major__) && (__clang_major__ >= 12)) || defined(__GNUC__)
85+
#define CMPLXF128 __builtin_complex
86+
#else
87+
static CFloat128ComplexType CMPLXF128(CFloat128Type r, CFloat128Type i) {
88+
union {
89+
struct CppComplexFloat128 x;
90+
CFloat128ComplexType result;
91+
} u;
92+
u.x.r = r;
93+
u.x.i = i;
94+
return u.result;
95+
}
96+
#endif
97+
#endif
98+
7399
/* RTNAME(SumComplex4) calls RTNAME(CppSumComplex4) with the same arguments
74100
* and converts the members of its C++ complex result to C _Complex.
75101
*/
@@ -93,9 +119,10 @@ ADAPT_REDUCTION(SumComplex8, double_Complex_t, CppComplexDouble, CMPLX,
93119
#if LDBL_MANT_DIG == 64
94120
ADAPT_REDUCTION(SumComplex10, long_double_Complex_t, CppComplexLongDouble,
95121
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
96-
#elif LDBL_MANT_DIG == 113
97-
ADAPT_REDUCTION(SumComplex16, long_double_Complex_t, CppComplexLongDouble,
98-
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
122+
#endif
123+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
124+
ADAPT_REDUCTION(SumComplex16, CFloat128ComplexType, CppComplexFloat128,
125+
CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
99126
#endif
100127

101128
/* PRODUCT() */
@@ -106,9 +133,10 @@ ADAPT_REDUCTION(ProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
106133
#if LDBL_MANT_DIG == 64
107134
ADAPT_REDUCTION(ProductComplex10, long_double_Complex_t, CppComplexLongDouble,
108135
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
109-
#elif LDBL_MANT_DIG == 113
110-
ADAPT_REDUCTION(ProductComplex16, long_double_Complex_t, CppComplexLongDouble,
111-
CMPLXL, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
136+
#endif
137+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
138+
ADAPT_REDUCTION(ProductComplex16, CFloat128ComplexType, CppComplexFloat128,
139+
CMPLXF128, REDUCTION_ARGS, REDUCTION_ARG_NAMES)
112140
#endif
113141

114142
/* DOT_PRODUCT() */
@@ -119,7 +147,8 @@ ADAPT_REDUCTION(DotProductComplex8, double_Complex_t, CppComplexDouble, CMPLX,
119147
#if LDBL_MANT_DIG == 64
120148
ADAPT_REDUCTION(DotProductComplex10, long_double_Complex_t,
121149
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
122-
#elif LDBL_MANT_DIG == 113
123-
ADAPT_REDUCTION(DotProductComplex16, long_double_Complex_t,
124-
CppComplexLongDouble, CMPLXL, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
150+
#endif
151+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
152+
ADAPT_REDUCTION(DotProductComplex16, CFloat128ComplexType, CppComplexFloat128,
153+
CMPLXF128, DOT_PRODUCT_ARGS, DOT_PRODUCT_ARG_NAMES)
125154
#endif

flang/runtime/complex-reduction.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#ifndef FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
1616
#define FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_
1717

18+
#include "flang/Common/float128.h"
1819
#include "flang/Runtime/entry-names.h"
1920
#include <complex.h>
2021

@@ -40,14 +41,18 @@ float_Complex_t RTNAME(SumComplex3)(REDUCTION_ARGS);
4041
float_Complex_t RTNAME(SumComplex4)(REDUCTION_ARGS);
4142
double_Complex_t RTNAME(SumComplex8)(REDUCTION_ARGS);
4243
long_double_Complex_t RTNAME(SumComplex10)(REDUCTION_ARGS);
43-
long_double_Complex_t RTNAME(SumComplex16)(REDUCTION_ARGS);
44+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
45+
CFloat128ComplexType RTNAME(SumComplex16)(REDUCTION_ARGS);
46+
#endif
4447

4548
float_Complex_t RTNAME(ProductComplex2)(REDUCTION_ARGS);
4649
float_Complex_t RTNAME(ProductComplex3)(REDUCTION_ARGS);
4750
float_Complex_t RTNAME(ProductComplex4)(REDUCTION_ARGS);
4851
double_Complex_t RTNAME(ProductComplex8)(REDUCTION_ARGS);
4952
long_double_Complex_t RTNAME(ProductComplex10)(REDUCTION_ARGS);
50-
long_double_Complex_t RTNAME(ProductComplex16)(REDUCTION_ARGS);
53+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
54+
CFloat128ComplexType RTNAME(ProductComplex16)(REDUCTION_ARGS);
55+
#endif
5156

5257
#define DOT_PRODUCT_ARGS \
5358
const struct CppDescriptor *x, const struct CppDescriptor *y, \
@@ -60,6 +65,8 @@ float_Complex_t RTNAME(DotProductComplex3)(DOT_PRODUCT_ARGS);
6065
float_Complex_t RTNAME(DotProductComplex4)(DOT_PRODUCT_ARGS);
6166
double_Complex_t RTNAME(DotProductComplex8)(DOT_PRODUCT_ARGS);
6267
long_double_Complex_t RTNAME(DotProductComplex10)(DOT_PRODUCT_ARGS);
63-
long_double_Complex_t RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
68+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
69+
CFloat128ComplexType RTNAME(DotProductComplex16)(DOT_PRODUCT_ARGS);
70+
#endif
6471

6572
#endif // FORTRAN_RUNTIME_COMPLEX_REDUCTION_H_

flang/runtime/product.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ CppTypeFor<TypeCategory::Real, 10> RTDEF(ProductReal10)(const Descriptor &x,
123123
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
124124
"PRODUCT");
125125
}
126-
#elif LDBL_MANT_DIG == 113
126+
#endif
127+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
127128
CppTypeFor<TypeCategory::Real, 16> RTDEF(ProductReal16)(const Descriptor &x,
128129
const char *source, int line, int dim, const Descriptor *mask) {
129130
return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim, mask,
@@ -154,7 +155,8 @@ void RTDEF(CppProductComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
154155
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 10>>{x},
155156
"PRODUCT");
156157
}
157-
#elif LDBL_MANT_DIG == 113
158+
#endif
159+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
158160
void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
159161
const Descriptor &x, const char *source, int line, int dim,
160162
const Descriptor *mask) {

flang/runtime/sum.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ void RTDEF(CppSumComplex10)(CppTypeFor<TypeCategory::Complex, 10> &result,
175175
result = GetTotalReduction<TypeCategory::Complex, 10>(
176176
x, source, line, dim, mask, ComplexSumAccumulator<long double>{x}, "SUM");
177177
}
178-
#elif LDBL_MANT_DIG == 113
178+
#endif
179+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
179180
void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
180181
const Descriptor &x, const char *source, int line, int dim,
181182
const Descriptor *mask) {

0 commit comments

Comments
 (0)