Skip to content

Commit baf6725

Browse files
authored
[flang][runtime] Support NORM2 for REAL(16) with FortranFloat128Math lib. (#83219)
Changed the lowering to call Norm2DimReal16 for REAL(16). Added the corresponding entry point to FortranFloat128Math, which required some restructuring in the related templates.
1 parent c1b8c6c commit baf6725

File tree

9 files changed

+256
-99
lines changed

9 files changed

+256
-99
lines changed

flang/include/flang/Runtime/reduction.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,9 +364,12 @@ double RTDECL(Norm2_8)(
364364
#if LDBL_MANT_DIG == 64
365365
long double RTDECL(Norm2_10)(
366366
const Descriptor &, const char *source, int line, int dim = 0);
367-
#elif LDBL_MANT_DIG == 113
367+
#endif
368+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
368369
// NORM2 total reduction for REAL(16); source/line identify the call site.
// NOTE(review): the enclosing guard now also covers HAS_FLOAT128, but the
// return type is still spelled `long double`, while the definition in
// Float128Math/norm2.cpp returns CppTypeFor<TypeCategory::Real, 16>. If that
// type is not `long double` on HAS_FLOAT128 targets, this declaration and the
// definition disagree - TODO confirm.
long double RTDECL(Norm2_16)(
369370
const Descriptor &, const char *source, int line, int dim = 0);
371+
// NORM2 with a DIM argument for REAL(16) arrays. The first descriptor receives
// the result, the second is the source array; source/line identify the call
// site.
void RTDECL(Norm2DimReal16)(
372+
Descriptor &, const Descriptor &, int dim, const char *source, int line);
370373
#endif
371374
void RTDECL(Norm2Dim)(
372375
Descriptor &, const Descriptor &, int dim, const char *source, int line);

flang/lib/Optimizer/Builder/Runtime/Reduction.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,22 @@ struct ForcedNorm2Real16 {
149149
}
150150
};
151151

152+
/// Placeholder for real*16 version of Norm2Dim Intrinsic
153+
struct ForcedNorm2DimReal16 {
154+
static constexpr const char *name = ExpandAndQuoteKey(RTNAME(Norm2DimReal16));
155+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
156+
return [](mlir::MLIRContext *ctx) {
157+
auto boxTy =
158+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
159+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
160+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
161+
return mlir::FunctionType::get(
162+
ctx, {fir::ReferenceType::get(boxTy), boxTy, intTy, strTy, intTy},
163+
mlir::NoneType::get(ctx));
164+
};
165+
}
166+
};
167+
152168
/// Placeholder for real*10 version of Product Intrinsic
153169
struct ForcedProductReal10 {
154170
static constexpr const char *name = ExpandAndQuoteKey(RTNAME(ProductReal10));
@@ -876,7 +892,14 @@ mlir::Value fir::runtime::genMinval(fir::FirOpBuilder &builder,
876892
void fir::runtime::genNorm2Dim(fir::FirOpBuilder &builder, mlir::Location loc,
877893
mlir::Value resultBox, mlir::Value arrayBox,
878894
mlir::Value dim) {
879-
auto func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2Dim)>(loc, builder);
895+
mlir::func::FuncOp func;
896+
auto ty = arrayBox.getType();
897+
auto arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
898+
auto eleTy = arrTy.cast<fir::SequenceType>().getEleTy();
899+
if (eleTy.isF128())
900+
func = fir::runtime::getRuntimeFunc<ForcedNorm2DimReal16>(loc, builder);
901+
else
902+
func = fir::runtime::getRuntimeFunc<mkRTKey(Norm2Dim)>(loc, builder);
880903
auto fTy = func.getFunctionType();
881904
auto sourceFile = fir::factory::locationToFilename(builder, loc);
882905
auto sourceLine =

flang/runtime/Float128Math/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ set(sources
6969
log.cpp
7070
log10.cpp
7171
lround.cpp
72+
norm2.cpp
7273
pow.cpp
7374
round.cpp
7475
sin.cpp

flang/runtime/Float128Math/math-entries.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ namespace Fortran::runtime {
5454
};
5555

5656
// Define fallback callers.
57+
DEFINE_FALLBACK(Abs)
5758
DEFINE_FALLBACK(Acos)
5859
DEFINE_FALLBACK(Acosh)
5960
DEFINE_FALLBACK(Asin)
@@ -99,6 +100,7 @@ DEFINE_FALLBACK(Yn)
99100
// Use STD math functions. They provide IEEE-754 128-bit float
100101
// support either via 'long double' or __float128.
101102
// The Bessel's functions are not present in STD namespace.
103+
DEFINE_SIMPLE_ALIAS(Abs, std::abs)
102104
DEFINE_SIMPLE_ALIAS(Acos, std::acos)
103105
DEFINE_SIMPLE_ALIAS(Acosh, std::acosh)
104106
DEFINE_SIMPLE_ALIAS(Asin, std::asin)
@@ -155,6 +157,7 @@ DEFINE_SIMPLE_ALIAS(Yn, ynl)
155157
#elif HAS_QUADMATHLIB
156158
// Define wrapper callers for libquadmath.
157159
#include "quadmath.h"
160+
DEFINE_SIMPLE_ALIAS(Abs, fabsq)
158161
DEFINE_SIMPLE_ALIAS(Acos, acosq)
159162
DEFINE_SIMPLE_ALIAS(Acosh, acoshq)
160163
DEFINE_SIMPLE_ALIAS(Asin, asinq)
@@ -191,6 +194,19 @@ DEFINE_SIMPLE_ALIAS(Y0, y0q)
191194
DEFINE_SIMPLE_ALIAS(Y1, y1q)
192195
DEFINE_SIMPLE_ALIAS(Yn, ynq)
193196
#endif
197+
198+
extern "C" {
199+
// Declarations of the entry points that might be referenced
200+
// within the Float128Math library itself.
201+
// Note that not all of these entry points are actually
202+
// defined in this library. Some of them are used just
203+
// as template parameters to call the corresponding callee directly.
204+
// REAL(16) -> REAL(16); named as a template argument by ABSTy in norm2.cpp.
CppTypeFor<TypeCategory::Real, 16> RTDECL(AbsF128)(
205+
CppTypeFor<TypeCategory::Real, 16> x);
206+
// REAL(16) -> REAL(16); named as a template argument by SQRTTy in norm2.cpp.
CppTypeFor<TypeCategory::Real, 16> RTDECL(SqrtF128)(
207+
CppTypeFor<TypeCategory::Real, 16> x);
208+
} // extern "C"
209+
194210
} // namespace Fortran::runtime
195211

196212
#endif // FORTRAN_RUNTIME_FLOAT128MATH_MATH_ENTRIES_H_

flang/runtime/Float128Math/norm2.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
//===-- runtime/Float128Math/norm2.cpp ------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "math-entries.h"
10+
#include "reduction-templates.h"
11+
#include <cmath>
12+
13+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
14+
15+
namespace {
16+
using namespace Fortran::runtime;
17+
18+
using AccumType = Norm2AccumType<16>;
19+
20+
struct ABSTy {
21+
static AccumType compute(AccumType x) {
22+
return Sqrt<RTNAME(AbsF128)>::invoke(x);
23+
}
24+
};
25+
26+
struct SQRTTy {
27+
static AccumType compute(AccumType x) {
28+
return Sqrt<RTNAME(SqrtF128)>::invoke(x);
29+
}
30+
};
31+
32+
using Float128Norm2Accumulator = Norm2Accumulator<16, ABSTy, SQRTTy>;
33+
} // namespace
34+
35+
namespace Fortran::runtime {
36+
extern "C" {
37+
38+
// Total NORM2 reduction of a REAL(16) array, returned by value. Accumulation
// uses the Float128Math-based accumulator defined at the top of this file.
// NOTE(review): extrema.cpp also defines Norm2_16 when LDBL_MANT_DIG == 113,
// and this file is compiled under `LDBL_MANT_DIG == 113 || HAS_FLOAT128`;
// confirm the two definitions can never end up in the same link.
CppTypeFor<TypeCategory::Real, 16> RTDEF(Norm2_16)(
    const Descriptor &x, const char *source, int line, int dim) {
  ::Float128Norm2Accumulator norm2Accumulator{x};
  // The nullptr argument is presumably "no MASK=" - confirm against
  // GetTotalReduction's signature.
  return GetTotalReduction<TypeCategory::Real, 16>(
      x, source, line, dim, nullptr, norm2Accumulator, "NORM2");
}
44+
45+
// NORM2 along `dim` for a REAL(16) array; the reduction is delivered through
// `result`. `source`/`line` seed the terminator used for diagnostics.
void RTDEF(Norm2DimReal16)(Descriptor &result, const Descriptor &x, int dim,
    const char *source, int line) {
  Terminator terminator{source, line};
  // The checked expressions are kept verbatim: RUNTIME_CHECK presumably
  // reports the predicate text on failure - confirm before renaming `type`.
  const auto type{x.type().GetCategoryAndKind()};
  RUNTIME_CHECK(terminator, type);
  RUNTIME_CHECK(
      terminator, type->first == TypeCategory::Real && type->second == 16);
  using Accumulator = ::Float128Norm2Accumulator;
  const Descriptor *const noMask{nullptr}; // NORM2 is reduced without a mask
  DoMaxMinNorm2<TypeCategory::Real, 16, Accumulator>(
      result, x, dim, noMask, "NORM2", terminator);
}
55+
56+
} // extern "C"
57+
} // namespace Fortran::runtime
58+
59+
#endif

flang/runtime/extrema.cpp

Lines changed: 12 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -528,35 +528,6 @@ inline RT_API_ATTRS CppTypeFor<CAT, KIND> TotalNumericMaxOrMin(
528528
NumericExtremumAccumulator<CAT, KIND, IS_MAXVAL>{x}, intrinsic);
529529
}
530530

531-
// Shared driver for reductions that use an ACCUMULATOR constructed from the
// source array `x`. When dim == 0 or x has rank 1 it allocates a rank-0
// `result` and performs a total reduction; otherwise it delegates to
// PartialReduction along `dim`. `intrinsic` is the name used in diagnostics.
template <TypeCategory CAT, int KIND, typename ACCUMULATOR>
532-
static RT_API_ATTRS void DoMaxMinNorm2(Descriptor &result, const Descriptor &x,
533-
int dim, const Descriptor *mask, const char *intrinsic,
534-
Terminator &terminator) {
535-
using Type = CppTypeFor<CAT, KIND>;
536-
ACCUMULATOR accumulator{x};
537-
if (dim == 0 || x.rank() == 1) {
538-
// Total reduction
539-
540-
// Element size of the destination descriptor is the same
541-
// as the element size of the source.
542-
result.Establish(x.type(), x.ElementBytes(), nullptr, 0, nullptr,
543-
CFI_attribute_allocatable);
544-
// A non-zero allocation status is fatal.
if (int stat{result.Allocate()}) {
545-
terminator.Crash(
546-
"%s: could not allocate memory for result; STAT=%d", intrinsic, stat);
547-
}
548-
DoTotalReduction<Type>(x, dim, mask, accumulator, intrinsic, terminator);
549-
// Store the accumulated value into the freshly allocated element.
accumulator.GetResult(result.OffsetElement<Type>());
550-
} else {
551-
// Partial reduction
552-
553-
// Element size of the destination descriptor is the same
554-
// as the element size of the source.
555-
PartialReduction<ACCUMULATOR, CAT, KIND>(result, x, x.ElementBytes(), dim,
556-
mask, terminator, intrinsic, accumulator);
557-
}
558-
}
559-
560531
template <TypeCategory CAT, bool IS_MAXVAL> struct MaxOrMinHelper {
561532
template <int KIND> struct Functor {
562533
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
@@ -802,66 +773,11 @@ RT_EXT_API_GROUP_END
802773

803774
// NORM2
804775

805-
RT_VAR_GROUP_BEGIN
806-
807-
// Use at least double precision for accumulators.
808-
// Don't use __float128, it doesn't work with abs() or sqrt() yet.
809-
// Upper bound for the accumulation kind, chosen from the mantissa width of
// `long double`: 16 for 113 bits, 10 for 64 bits, otherwise 8.
static constexpr RT_CONST_VAR_ATTRS int largestLDKind {
810-
#if LDBL_MANT_DIG == 113
811-
16
812-
#elif LDBL_MANT_DIG == 64
813-
10
814-
#else
815-
8
816-
#endif
817-
};
818-
819-
RT_VAR_GROUP_END
820-
821-
// Accumulator for NORM2 over REAL(KIND) elements. It keeps the largest
// magnitude seen so far (max_) and the sum of squares of the elements scaled
// by that magnitude (sum_); the result is max_ * sqrt(1 + sum_). Accumulation
// is done in AccumType, whose kind is KIND clamped to [8, largestLDKind].
template <int KIND> class Norm2Accumulator {
822-
public:
823-
using Type = CppTypeFor<TypeCategory::Real, KIND>;
824-
using AccumType =
825-
CppTypeFor<TypeCategory::Real, std::clamp(KIND, 8, largestLDKind)>;
826-
explicit RT_API_ATTRS Norm2Accumulator(const Descriptor &array)
827-
: array_{array} {}
828-
// Resets the running state so the accumulator can be reused.
RT_API_ATTRS void Reinitialize() { max_ = sum_ = 0; }
829-
// Writes the norm, converted back to Type, through p.
template <typename A>
830-
RT_API_ATTRS void GetResult(A *p, int /*zeroBasedDim*/ = -1) const {
831-
// m * sqrt(1 + sum((others(:)/m)**2))
832-
*p = static_cast<Type>(max_ * std::sqrt(1 + sum_));
833-
}
834-
// Folds one element into the running state; always returns true.
RT_API_ATTRS bool Accumulate(Type x) {
835-
auto absX{std::abs(static_cast<AccumType>(x))};
836-
// While max_ is still zero, the element simply becomes the maximum.
if (!max_) {
837-
max_ = absX;
838-
} else if (absX > max_) {
839-
auto t{max_ / absX}; // < 1.0
840-
auto tsq{t * t};
841-
sum_ *= tsq; // scale sum to reflect change to the max
842-
sum_ += tsq; // include a term for the previous max
843-
max_ = absX;
844-
} else { // absX <= max_
845-
auto t{absX / max_};
846-
sum_ += t * t;
847-
}
848-
return true;
849-
}
850-
// Accumulates the element of the source array at subscripts `at`.
template <typename A>
851-
RT_API_ATTRS bool AccumulateAt(const SubscriptValue at[]) {
852-
return Accumulate(*array_.Element<A>(at));
853-
}
854-
855-
private:
856-
const Descriptor &array_;
857-
AccumType max_{0}; // value (m) with largest magnitude
858-
AccumType sum_{0}; // sum((others(:)/m)**2)
859-
};
860-
861776
// Per-KIND functor for the kind dispatch in Norm2Dim
// (ApplyFloatingPointKind<Norm2Helper, ...>): forwards its arguments to
// DoMaxMinNorm2 for REAL(KIND) with the intrinsic name "NORM2".
template <int KIND> struct Norm2Helper {
862777
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x, int dim,
863778
const Descriptor *mask, Terminator &terminator) const {
864-
DoMaxMinNorm2<TypeCategory::Real, KIND, Norm2Accumulator<KIND>>(
779+
DoMaxMinNorm2<TypeCategory::Real, KIND,
780+
typename Norm2AccumulatorGetter<KIND>::Type>(
865781
result, x, dim, mask, "NORM2", terminator);
866782
}
867783
};
@@ -872,26 +788,27 @@ RT_EXT_API_GROUP_BEGIN
872788
// TODO: REAL(2 & 3)
873789
// Total NORM2 reduction of a REAL(4) array, returned by value; source/line
// identify the call site.
CppTypeFor<TypeCategory::Real, 4> RTDEF(Norm2_4)(
874790
const Descriptor &x, const char *source, int line, int dim) {
875-
return GetTotalReduction<TypeCategory::Real, 4>(
876-
x, source, line, dim, nullptr, Norm2Accumulator<4>{x}, "NORM2");
791+
return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, nullptr,
792+
Norm2AccumulatorGetter<4>::create(x), "NORM2");
877793
}
878794
// Total NORM2 reduction of a REAL(8) array, returned by value; source/line
// identify the call site.
CppTypeFor<TypeCategory::Real, 8> RTDEF(Norm2_8)(
879795
const Descriptor &x, const char *source, int line, int dim) {
880-
return GetTotalReduction<TypeCategory::Real, 8>(
881-
x, source, line, dim, nullptr, Norm2Accumulator<8>{x}, "NORM2");
796+
return GetTotalReduction<TypeCategory::Real, 8>(x, source, line, dim, nullptr,
797+
Norm2AccumulatorGetter<8>::create(x), "NORM2");
882798
}
883799
#if LDBL_MANT_DIG == 64
884800
// Total NORM2 reduction of a REAL(10) array, returned by value. Compiled only
// when long double has a 64-bit mantissa (see the enclosing #if).
CppTypeFor<TypeCategory::Real, 10> RTDEF(Norm2_10)(
885801
const Descriptor &x, const char *source, int line, int dim) {
886-
return GetTotalReduction<TypeCategory::Real, 10>(
887-
x, source, line, dim, nullptr, Norm2Accumulator<10>{x}, "NORM2");
802+
return GetTotalReduction<TypeCategory::Real, 10>(x, source, line, dim,
803+
nullptr, Norm2AccumulatorGetter<10>::create(x), "NORM2");
888804
}
889805
#endif
890806
#if LDBL_MANT_DIG == 113
807+
// The __float128 implementation resides in FortranFloat128Math library.
891808
// Total NORM2 reduction of a REAL(16) array, returned by value. Compiled here
// only when long double has a 113-bit mantissa (see the enclosing #if).
CppTypeFor<TypeCategory::Real, 16> RTDEF(Norm2_16)(
892809
const Descriptor &x, const char *source, int line, int dim) {
893-
return GetTotalReduction<TypeCategory::Real, 16>(
894-
x, source, line, dim, nullptr, Norm2Accumulator<16>{x}, "NORM2");
810+
return GetTotalReduction<TypeCategory::Real, 16>(x, source, line, dim,
811+
nullptr, Norm2AccumulatorGetter<16>::create(x), "NORM2");
895812
}
896813
#endif
897814

@@ -901,7 +818,7 @@ void RTDEF(Norm2Dim)(Descriptor &result, const Descriptor &x, int dim,
901818
auto type{x.type().GetCategoryAndKind()};
902819
RUNTIME_CHECK(terminator, type);
903820
if (type->first == TypeCategory::Real) {
904-
ApplyFloatingPointKind<Norm2Helper, void>(
821+
ApplyFloatingPointKind<Norm2Helper, void, true>(
905822
type->second, terminator, result, x, dim, nullptr, terminator);
906823
} else {
907824
terminator.Crash("NORM2: bad type code %d", x.type().raw());

0 commit comments

Comments
 (0)