Skip to content

Commit dd22085

Browse files
authored
[flang][runtime] Split MATMUL[_TRANSPOSE] into separate entries. (#97406)
Device compilation is much faster for separate MATMUL[_TRANPOSE] entries than for a single one that covers all data types. The lowering changes and the removal of the generic entries will follow.
1 parent a9c44fd commit dd22085

File tree

7 files changed

+646
-2
lines changed

7 files changed

+646
-2
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
//===-- include/flang/Runtime/matmul-instances.inc --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
// Helper macros to instantiate MATMUL/MATMUL_TRANSPOSE definitions
9+
// for different data types of the input arguments.
10+
//===----------------------------------------------------------------------===//
11+
12+
#ifndef MATMUL_INSTANCE
13+
#error "Define MATMUL_INSTANCE before including this file"
14+
#endif
15+
16+
#ifndef MATMUL_DIRECT_INSTANCE
17+
#error "Define MATMUL_DIRECT_INSTANCE before including this file"
18+
#endif
19+
20+
// clang-format off
21+
22+
#define FOREACH_MATMUL_TYPE_PAIR(macro) \
23+
macro(Integer, 1, Integer, 1) \
24+
macro(Integer, 1, Integer, 2) \
25+
macro(Integer, 1, Integer, 4) \
26+
macro(Integer, 1, Integer, 8) \
27+
macro(Integer, 2, Integer, 1) \
28+
macro(Integer, 2, Integer, 2) \
29+
macro(Integer, 2, Integer, 4) \
30+
macro(Integer, 2, Integer, 8) \
31+
macro(Integer, 4, Integer, 1) \
32+
macro(Integer, 4, Integer, 2) \
33+
macro(Integer, 4, Integer, 4) \
34+
macro(Integer, 4, Integer, 8) \
35+
macro(Integer, 8, Integer, 1) \
36+
macro(Integer, 8, Integer, 2) \
37+
macro(Integer, 8, Integer, 4) \
38+
macro(Integer, 8, Integer, 8) \
39+
macro(Integer, 1, Real, 4) \
40+
macro(Integer, 1, Real, 8) \
41+
macro(Integer, 2, Real, 4) \
42+
macro(Integer, 2, Real, 8) \
43+
macro(Integer, 4, Real, 4) \
44+
macro(Integer, 4, Real, 8) \
45+
macro(Integer, 8, Real, 4) \
46+
macro(Integer, 8, Real, 8) \
47+
macro(Integer, 1, Complex, 4) \
48+
macro(Integer, 1, Complex, 8) \
49+
macro(Integer, 2, Complex, 4) \
50+
macro(Integer, 2, Complex, 8) \
51+
macro(Integer, 4, Complex, 4) \
52+
macro(Integer, 4, Complex, 8) \
53+
macro(Integer, 8, Complex, 4) \
54+
macro(Integer, 8, Complex, 8) \
55+
macro(Real, 4, Integer, 1) \
56+
macro(Real, 4, Integer, 2) \
57+
macro(Real, 4, Integer, 4) \
58+
macro(Real, 4, Integer, 8) \
59+
macro(Real, 8, Integer, 1) \
60+
macro(Real, 8, Integer, 2) \
61+
macro(Real, 8, Integer, 4) \
62+
macro(Real, 8, Integer, 8) \
63+
macro(Real, 4, Real, 4) \
64+
macro(Real, 4, Real, 8) \
65+
macro(Real, 8, Real, 4) \
66+
macro(Real, 8, Real, 8) \
67+
macro(Real, 4, Complex, 4) \
68+
macro(Real, 4, Complex, 8) \
69+
macro(Real, 8, Complex, 4) \
70+
macro(Real, 8, Complex, 8) \
71+
macro(Complex, 4, Integer, 1) \
72+
macro(Complex, 4, Integer, 2) \
73+
macro(Complex, 4, Integer, 4) \
74+
macro(Complex, 4, Integer, 8) \
75+
macro(Complex, 8, Integer, 1) \
76+
macro(Complex, 8, Integer, 2) \
77+
macro(Complex, 8, Integer, 4) \
78+
macro(Complex, 8, Integer, 8) \
79+
macro(Complex, 4, Real, 4) \
80+
macro(Complex, 4, Real, 8) \
81+
macro(Complex, 8, Real, 4) \
82+
macro(Complex, 8, Real, 8) \
83+
macro(Complex, 4, Complex, 4) \
84+
macro(Complex, 4, Complex, 8) \
85+
macro(Complex, 8, Complex, 4) \
86+
macro(Complex, 8, Complex, 8) \
87+
88+
FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
89+
FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
90+
91+
#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
92+
#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
93+
macro(Integer, 16, Integer, 1) \
94+
macro(Integer, 16, Integer, 2) \
95+
macro(Integer, 16, Integer, 4) \
96+
macro(Integer, 16, Integer, 8) \
97+
macro(Integer, 16, Integer, 16) \
98+
macro(Integer, 16, Real, 4) \
99+
macro(Integer, 16, Real, 8) \
100+
macro(Integer, 16, Complex, 4) \
101+
macro(Integer, 16, Complex, 8) \
102+
macro(Real, 4, Integer, 16) \
103+
macro(Real, 8, Integer, 16) \
104+
macro(Complex, 4, Integer, 16) \
105+
macro(Complex, 8, Integer, 16) \
106+
107+
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
108+
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
109+
110+
#if LDBL_MANT_DIG == 64
111+
MATMUL_INSTANCE(Integer, 16, Real, 10)
112+
MATMUL_INSTANCE(Integer, 16, Complex, 10)
113+
MATMUL_INSTANCE(Real, 10, Integer, 16)
114+
MATMUL_INSTANCE(Complex, 10, Integer, 16)
115+
MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 10)
116+
MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
117+
MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
118+
MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
119+
#endif
120+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
121+
MATMUL_INSTANCE(Integer, 16, Real, 16)
122+
MATMUL_INSTANCE(Integer, 16, Complex, 16)
123+
MATMUL_INSTANCE(Real, 16, Integer, 16)
124+
MATMUL_INSTANCE(Complex, 16, Integer, 16)
125+
MATMUL_DIRECT_INSTANCE(Integer, 16, Real, 16)
126+
MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
127+
MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
128+
MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
129+
#endif
130+
#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
131+
132+
#if LDBL_MANT_DIG == 64
133+
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
134+
macro(Integer, 1, Real, 10) \
135+
macro(Integer, 1, Complex, 10) \
136+
macro(Integer, 2, Real, 10) \
137+
macro(Integer, 2, Complex, 10) \
138+
macro(Integer, 4, Real, 10) \
139+
macro(Integer, 4, Complex, 10) \
140+
macro(Integer, 8, Real, 10) \
141+
macro(Integer, 8, Complex, 10) \
142+
macro(Real, 4, Real, 10) \
143+
macro(Real, 4, Complex, 10) \
144+
macro(Real, 8, Real, 10) \
145+
macro(Real, 8, Complex, 10) \
146+
macro(Real, 10, Integer, 1) \
147+
macro(Real, 10, Integer, 2) \
148+
macro(Real, 10, Integer, 4) \
149+
macro(Real, 10, Integer, 8) \
150+
macro(Real, 10, Real, 4) \
151+
macro(Real, 10, Real, 8) \
152+
macro(Real, 10, Real, 10) \
153+
macro(Real, 10, Complex, 4) \
154+
macro(Real, 10, Complex, 8) \
155+
macro(Real, 10, Complex, 10) \
156+
macro(Complex, 4, Real, 10) \
157+
macro(Complex, 4, Complex, 10) \
158+
macro(Complex, 8, Real, 10) \
159+
macro(Complex, 8, Complex, 10) \
160+
macro(Complex, 10, Integer, 1) \
161+
macro(Complex, 10, Integer, 2) \
162+
macro(Complex, 10, Integer, 4) \
163+
macro(Complex, 10, Integer, 8) \
164+
macro(Complex, 10, Real, 4) \
165+
macro(Complex, 10, Real, 8) \
166+
macro(Complex, 10, Real, 10) \
167+
macro(Complex, 10, Complex, 4) \
168+
macro(Complex, 10, Complex, 8) \
169+
macro(Complex, 10, Complex, 10) \
170+
171+
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
172+
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
173+
174+
#if HAS_FLOAT128
175+
MATMUL_INSTANCE(Real, 10, Real, 16)
176+
MATMUL_INSTANCE(Real, 10, Complex, 16)
177+
MATMUL_INSTANCE(Real, 16, Real, 10)
178+
MATMUL_INSTANCE(Real, 16, Complex, 10)
179+
MATMUL_INSTANCE(Complex, 10, Real, 16)
180+
MATMUL_INSTANCE(Complex, 10, Complex, 16)
181+
MATMUL_INSTANCE(Complex, 16, Real, 10)
182+
MATMUL_INSTANCE(Complex, 16, Complex, 10)
183+
MATMUL_DIRECT_INSTANCE(Real, 10, Real, 16)
184+
MATMUL_DIRECT_INSTANCE(Real, 10, Complex, 16)
185+
MATMUL_DIRECT_INSTANCE(Real, 16, Real, 10)
186+
MATMUL_DIRECT_INSTANCE(Real, 16, Complex, 10)
187+
MATMUL_DIRECT_INSTANCE(Complex, 10, Real, 16)
188+
MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
189+
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
190+
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
191+
#endif
192+
#endif // LDBL_MANT_DIG == 64
193+
194+
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
195+
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
196+
macro(Integer, 1, Real, 16) \
197+
macro(Integer, 1, Complex, 16) \
198+
macro(Integer, 2, Real, 16) \
199+
macro(Integer, 2, Complex, 16) \
200+
macro(Integer, 4, Real, 16) \
201+
macro(Integer, 4, Complex, 16) \
202+
macro(Integer, 8, Real, 16) \
203+
macro(Integer, 8, Complex, 16) \
204+
macro(Real, 4, Real, 16) \
205+
macro(Real, 4, Complex, 16) \
206+
macro(Real, 8, Real, 16) \
207+
macro(Real, 8, Complex, 16) \
208+
macro(Real, 16, Integer, 1) \
209+
macro(Real, 16, Integer, 2) \
210+
macro(Real, 16, Integer, 4) \
211+
macro(Real, 16, Integer, 8) \
212+
macro(Real, 16, Real, 4) \
213+
macro(Real, 16, Real, 8) \
214+
macro(Real, 16, Real, 16) \
215+
macro(Real, 16, Complex, 4) \
216+
macro(Real, 16, Complex, 8) \
217+
macro(Real, 16, Complex, 16) \
218+
macro(Complex, 4, Real, 16) \
219+
macro(Complex, 4, Complex, 16) \
220+
macro(Complex, 8, Real, 16) \
221+
macro(Complex, 8, Complex, 16) \
222+
macro(Complex, 16, Integer, 1) \
223+
macro(Complex, 16, Integer, 2) \
224+
macro(Complex, 16, Integer, 4) \
225+
macro(Complex, 16, Integer, 8) \
226+
macro(Complex, 16, Real, 4) \
227+
macro(Complex, 16, Real, 8) \
228+
macro(Complex, 16, Real, 16) \
229+
macro(Complex, 16, Complex, 4) \
230+
macro(Complex, 16, Complex, 8) \
231+
macro(Complex, 16, Complex, 16) \
232+
233+
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
234+
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
235+
#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
236+
237+
#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
238+
macro(Logical, 1, Logical, 1) \
239+
macro(Logical, 1, Logical, 2) \
240+
macro(Logical, 1, Logical, 4) \
241+
macro(Logical, 1, Logical, 8) \
242+
macro(Logical, 2, Logical, 1) \
243+
macro(Logical, 2, Logical, 2) \
244+
macro(Logical, 2, Logical, 4) \
245+
macro(Logical, 2, Logical, 8) \
246+
macro(Logical, 4, Logical, 1) \
247+
macro(Logical, 4, Logical, 2) \
248+
macro(Logical, 4, Logical, 4) \
249+
macro(Logical, 4, Logical, 8) \
250+
macro(Logical, 8, Logical, 1) \
251+
macro(Logical, 8, Logical, 2) \
252+
macro(Logical, 8, Logical, 4) \
253+
macro(Logical, 8, Logical, 8) \
254+
255+
FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_INSTANCE)
256+
FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
257+
258+
#undef MATMUL_INSTANCE
259+
#undef MATMUL_DIRECT_INSTANCE
260+
261+
// clang-format on

flang/include/flang/Runtime/matmul-transpose.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#ifndef FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
1212
#define FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_
13+
#include "flang/Common/float128.h"
14+
#include "flang/Common/uint128.h"
1315
#include "flang/Runtime/entry-names.h"
1416
namespace Fortran::runtime {
1517
class Descriptor;
@@ -25,6 +27,21 @@ void RTDECL(MatmulTranspose)(Descriptor &, const Descriptor &,
2527
// and have a valid base address.
2628
void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
2729
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
30+
31+
// MATMUL(TRANSPOSE()) versions specialized by the categories of the operand
32+
// types. The KIND and shape information is taken from the argument's
33+
// descriptors.
34+
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
35+
void RTDECL(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
36+
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
37+
int line);
38+
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
39+
void RTDECL(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
40+
Descriptor & result, const Descriptor &x, const Descriptor &y, \
41+
const char *sourceFile, int line);
42+
43+
#include "matmul-instances.inc"
44+
2845
} // extern "C"
2946
} // namespace Fortran::runtime
3047
#endif // FORTRAN_RUNTIME_MATMUL_TRANSPOSE_H_

flang/include/flang/Runtime/matmul.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
#ifndef FORTRAN_RUNTIME_MATMUL_H_
1212
#define FORTRAN_RUNTIME_MATMUL_H_
13+
#include "flang/Common/float128.h"
14+
#include "flang/Common/uint128.h"
1315
#include "flang/Runtime/entry-names.h"
1416
namespace Fortran::runtime {
1517
class Descriptor;
@@ -24,6 +26,21 @@ void RTDECL(Matmul)(Descriptor &, const Descriptor &, const Descriptor &,
2426
// and have a valid base address.
2527
void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
2628
const Descriptor &, const char *sourceFile = nullptr, int line = 0);
29+
30+
// MATMUL versions specialized by the categories of the operand types.
31+
// The KIND and shape information is taken from the argument's
32+
// descriptors.
33+
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
34+
void RTDECL(Matmul##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
35+
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
36+
int line);
37+
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
38+
void RTDECL(MatmulDirect##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
39+
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
40+
int line);
41+
42+
#include "matmul-instances.inc"
43+
2744
} // extern "C"
2845
} // namespace Fortran::runtime
2946
#endif // FORTRAN_RUNTIME_MATMUL_H_

flang/runtime/matmul-transpose.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,30 @@ template <bool IS_ALLOCATING> struct MatmulTranspose {
384384
x, y, terminator, yCatKind->first, yCatKind->second);
385385
}
386386
};
387+
388+
template <bool IS_ALLOCATING, TypeCategory XCAT, int XKIND, TypeCategory YCAT,
389+
int YKIND>
390+
struct MatmulTransposeHelper {
391+
using ResultDescriptor =
392+
std::conditional_t<IS_ALLOCATING, Descriptor, const Descriptor>;
393+
RT_API_ATTRS void operator()(ResultDescriptor &result, const Descriptor &x,
394+
const Descriptor &y, const char *sourceFile, int line) const {
395+
Terminator terminator{sourceFile, line};
396+
auto xCatKind{x.type().GetCategoryAndKind()};
397+
auto yCatKind{y.type().GetCategoryAndKind()};
398+
RUNTIME_CHECK(terminator, xCatKind.has_value() && yCatKind.has_value());
399+
RUNTIME_CHECK(terminator, xCatKind->first == XCAT);
400+
RUNTIME_CHECK(terminator, yCatKind->first == YCAT);
401+
if constexpr (constexpr auto resultType{
402+
GetResultType(XCAT, XKIND, YCAT, YKIND)}) {
403+
return DoMatmulTranspose<IS_ALLOCATING, resultType->first,
404+
resultType->second, CppTypeFor<XCAT, XKIND>, CppTypeFor<YCAT, YKIND>>(
405+
result, x, y, terminator);
406+
}
407+
terminator.Crash("MATMUL-TRANSPOSE: bad operand types (%d(%d), %d(%d))",
408+
static_cast<int>(XCAT), XKIND, static_cast<int>(YCAT), YKIND);
409+
}
410+
};
387411
} // namespace
388412

389413
namespace Fortran::runtime {
@@ -399,6 +423,24 @@ void RTDEF(MatmulTransposeDirect)(const Descriptor &result, const Descriptor &x,
399423
MatmulTranspose<false>{}(result, x, y, sourceFile, line);
400424
}
401425

426+
#define MATMUL_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
427+
void RTDEF(MatmulTranspose##XCAT##XKIND##YCAT##YKIND)(Descriptor & result, \
428+
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
429+
int line) { \
430+
MatmulTransposeHelper<true, TypeCategory::XCAT, XKIND, TypeCategory::YCAT, \
431+
YKIND>{}(result, x, y, sourceFile, line); \
432+
}
433+
434+
#define MATMUL_DIRECT_INSTANCE(XCAT, XKIND, YCAT, YKIND) \
435+
void RTDEF(MatmulTransposeDirect##XCAT##XKIND##YCAT##YKIND)( \
436+
Descriptor & result, const Descriptor &x, const Descriptor &y, \
437+
const char *sourceFile, int line) { \
438+
MatmulTransposeHelper<false, TypeCategory::XCAT, XKIND, \
439+
TypeCategory::YCAT, YKIND>{}(result, x, y, sourceFile, line); \
440+
}
441+
442+
#include "flang/Runtime/matmul-instances.inc"
443+
402444
RT_EXT_API_GROUP_END
403445
} // extern "C"
404446
} // namespace Fortran::runtime

0 commit comments

Comments
 (0)