Skip to content

Commit 8ce1aed

Browse files
authored
[flang] Lower MATMUL to type specific runtime calls. (#97547)
Lower MATMUL to the new runtime entries added in #97406.
1 parent 32f7672 commit 8ce1aed

File tree

15 files changed

+230
-392
lines changed

15 files changed

+230
-392
lines changed

flang/include/flang/Optimizer/Support/Utils.h

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ inline std::string mlirTypeToString(mlir::Type type) {
8484
return result;
8585
}
8686

87-
inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
88-
mlir::Type type, mlir::Location loc,
89-
const llvm::Twine &name) {
87+
inline std::string mlirTypeToIntrinsicFortran(fir::FirOpBuilder &builder,
88+
mlir::Type type,
89+
mlir::Location loc,
90+
const llvm::Twine &name) {
9091
if (type.isF16())
9192
return "REAL(KIND=2)";
9293
else if (type.isBF16())
@@ -123,6 +124,14 @@ inline std::string numericMlirTypeToFortran(fir::FirOpBuilder &builder,
123124
return "COMPLEX(KIND=10)";
124125
else if (type == fir::ComplexType::get(builder.getContext(), 16))
125126
return "COMPLEX(KIND=16)";
127+
else if (type == fir::LogicalType::get(builder.getContext(), 1))
128+
return "LOGICAL(KIND=1)";
129+
else if (type == fir::LogicalType::get(builder.getContext(), 2))
130+
return "LOGICAL(KIND=2)";
131+
else if (type == fir::LogicalType::get(builder.getContext(), 4))
132+
return "LOGICAL(KIND=4)";
133+
else if (type == fir::LogicalType::get(builder.getContext(), 8))
134+
return "LOGICAL(KIND=8)";
126135
else
127136
fir::emitFatalError(loc, "unsupported type in " + name + ": " +
128137
fir::mlirTypeToString(type));
@@ -133,10 +142,54 @@ inline void intrinsicTypeTODO(fir::FirOpBuilder &builder, mlir::Type type,
133142
const llvm::Twine &intrinsicName) {
134143
TODO(loc,
135144
"intrinsic: " +
136-
fir::numericMlirTypeToFortran(builder, type, loc, intrinsicName) +
145+
fir::mlirTypeToIntrinsicFortran(builder, type, loc, intrinsicName) +
137146
" in " + intrinsicName);
138147
}
139148

149+
/// Emit a "not yet implemented" TODO diagnostic for an intrinsic that is not
/// supported for a given pair of argument types.
///
/// The message has the form
///   "intrinsic: {<type1>, <type2>} in <intrinsicName>"
/// where each type is rendered as a Fortran type string (for example
/// "REAL(KIND=2)") by fir::mlirTypeToIntrinsicFortran.
///
/// \param builder builder forwarded to fir::mlirTypeToIntrinsicFortran.
/// \param type1 type of the first argument; printed first.
/// \param type2 type of the second argument; printed second.
/// \param loc source location attached to the diagnostic.
/// \param intrinsicName intrinsic name shown in the message.
inline void intrinsicTypeTODO2(fir::FirOpBuilder &builder, mlir::Type type1,
                               mlir::Type type2, mlir::Location loc,
                               const llvm::Twine &intrinsicName) {
  // Report both operand types, in argument order, so the diagnostic identifies
  // the exact unsupported combination.
  TODO(loc,
       "intrinsic: {" +
           fir::mlirTypeToIntrinsicFortran(builder, type1, loc, intrinsicName) +
           ", " +
           fir::mlirTypeToIntrinsicFortran(builder, type2, loc, intrinsicName) +
           "} in " + intrinsicName);
}
159+
160+
inline std::pair<Fortran::common::TypeCategory, KindMapping::KindTy>
161+
mlirTypeToCategoryKind(mlir::Location loc, mlir::Type type) {
162+
if (type.isF16())
163+
return {Fortran::common::TypeCategory::Real, 2};
164+
else if (type.isBF16())
165+
return {Fortran::common::TypeCategory::Real, 3};
166+
else if (type.isF32())
167+
return {Fortran::common::TypeCategory::Real, 4};
168+
else if (type.isF64())
169+
return {Fortran::common::TypeCategory::Real, 8};
170+
else if (type.isF80())
171+
return {Fortran::common::TypeCategory::Real, 10};
172+
else if (type.isF128())
173+
return {Fortran::common::TypeCategory::Real, 16};
174+
else if (type.isInteger(8))
175+
return {Fortran::common::TypeCategory::Integer, 1};
176+
else if (type.isInteger(16))
177+
return {Fortran::common::TypeCategory::Integer, 2};
178+
else if (type.isInteger(32))
179+
return {Fortran::common::TypeCategory::Integer, 4};
180+
else if (type.isInteger(64))
181+
return {Fortran::common::TypeCategory::Integer, 8};
182+
else if (type.isInteger(128))
183+
return {Fortran::common::TypeCategory::Integer, 16};
184+
else if (auto complexType = mlir::dyn_cast<fir::ComplexType>(type))
185+
return {Fortran::common::TypeCategory::Complex, complexType.getFKind()};
186+
else if (auto logicalType = mlir::dyn_cast<fir::LogicalType>(type))
187+
return {Fortran::common::TypeCategory::Logical, logicalType.getFKind()};
188+
else
189+
fir::emitFatalError(loc,
190+
"unsupported type: " + fir::mlirTypeToString(type));
191+
}
192+
140193
/// Find the fir.type_info that was created for this \p recordType in \p module,
141194
/// if any. \p symbolTable can be provided to speed-up the lookup. This tool
142195
/// will match record type even if they have been "altered" in type conversion

flang/include/flang/Runtime/matmul-instances.inc

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
#error "Define MATMUL_DIRECT_INSTANCE before including this file"
1818
#endif
1919

20+
// The includer must define MATMUL_FORCE_ALL_TYPES. When it is non-zero, the
// type-pair sections of this file are enabled regardless of the host-dependent
// conditions they are otherwise guarded by (__SIZEOF_INT128__, LDBL_MANT_DIG,
// HAS_FLOAT128).
#ifndef MATMUL_FORCE_ALL_TYPES
21+
#error "Define MATMUL_FORCE_ALL_TYPES to 0 or 1 before including this file"
22+
#endif
23+
2024
// clang-format off
2125

2226
#define FOREACH_MATMUL_TYPE_PAIR(macro) \
@@ -88,7 +92,7 @@
8892
FOREACH_MATMUL_TYPE_PAIR(MATMUL_INSTANCE)
8993
FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
9094

91-
#if defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
95+
#if MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
9296
#define FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(macro) \
9397
macro(Integer, 16, Integer, 1) \
9498
macro(Integer, 16, Integer, 2) \
@@ -107,7 +111,7 @@ FOREACH_MATMUL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
107111
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_INSTANCE)
108112
FOREACH_MATMUL_TYPE_PAIR_WITH_INT16(MATMUL_DIRECT_INSTANCE)
109113

110-
#if LDBL_MANT_DIG == 64
114+
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
111115
MATMUL_INSTANCE(Integer, 16, Real, 10)
112116
MATMUL_INSTANCE(Integer, 16, Complex, 10)
113117
MATMUL_INSTANCE(Real, 10, Integer, 16)
@@ -117,7 +121,7 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 10)
117121
MATMUL_DIRECT_INSTANCE(Real, 10, Integer, 16)
118122
MATMUL_DIRECT_INSTANCE(Complex, 10, Integer, 16)
119123
#endif
120-
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
124+
#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
121125
MATMUL_INSTANCE(Integer, 16, Real, 16)
122126
MATMUL_INSTANCE(Integer, 16, Complex, 16)
123127
MATMUL_INSTANCE(Real, 16, Integer, 16)
@@ -127,9 +131,9 @@ MATMUL_DIRECT_INSTANCE(Integer, 16, Complex, 16)
127131
MATMUL_DIRECT_INSTANCE(Real, 16, Integer, 16)
128132
MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
129133
#endif
130-
#endif // defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T
134+
#endif // MATMUL_FORCE_ALL_TYPES || (defined __SIZEOF_INT128__ && !AVOID_NATIVE_UINT128_T)
131135

132-
#if LDBL_MANT_DIG == 64
136+
#if MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
133137
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(macro) \
134138
macro(Integer, 1, Real, 10) \
135139
macro(Integer, 1, Complex, 10) \
@@ -171,7 +175,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Integer, 16)
171175
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_INSTANCE)
172176
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL10(MATMUL_DIRECT_INSTANCE)
173177

174-
#if HAS_FLOAT128
178+
#if MATMUL_FORCE_ALL_TYPES || HAS_FLOAT128
175179
MATMUL_INSTANCE(Real, 10, Real, 16)
176180
MATMUL_INSTANCE(Real, 10, Complex, 16)
177181
MATMUL_INSTANCE(Real, 16, Real, 10)
@@ -189,9 +193,9 @@ MATMUL_DIRECT_INSTANCE(Complex, 10, Complex, 16)
189193
MATMUL_DIRECT_INSTANCE(Complex, 16, Real, 10)
190194
MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
191195
#endif
192-
#endif // LDBL_MANT_DIG == 64
196+
#endif // MATMUL_FORCE_ALL_TYPES || LDBL_MANT_DIG == 64
193197

194-
#if LDBL_MANT_DIG == 113 || HAS_FLOAT128
198+
#if MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
195199
#define FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(macro) \
196200
macro(Integer, 1, Real, 16) \
197201
macro(Integer, 1, Complex, 16) \
@@ -232,7 +236,7 @@ MATMUL_DIRECT_INSTANCE(Complex, 16, Complex, 10)
232236

233237
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_INSTANCE)
234238
FOREACH_MATMUL_TYPE_PAIR_WITH_REAL16(MATMUL_DIRECT_INSTANCE)
235-
#endif // LDBL_MANT_DIG == 113 || HAS_FLOAT128
239+
#endif // MATMUL_FORCE_ALL_TYPES || (LDBL_MANT_DIG == 113 || HAS_FLOAT128)
236240

237241
#define FOREACH_MATMUL_LOGICAL_TYPE_PAIR(macro) \
238242
macro(Logical, 1, Logical, 1) \
@@ -257,5 +261,6 @@ FOREACH_MATMUL_LOGICAL_TYPE_PAIR(MATMUL_DIRECT_INSTANCE)
257261

258262
#undef MATMUL_INSTANCE
259263
#undef MATMUL_DIRECT_INSTANCE
264+
#undef MATMUL_FORCE_ALL_TYPES
260265

261266
// clang-format on

flang/include/flang/Runtime/matmul-transpose.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ void RTDECL(MatmulTransposeDirect)(const Descriptor &, const Descriptor &,
4040
Descriptor & result, const Descriptor &x, const Descriptor &y, \
4141
const char *sourceFile, int line);
4242

43+
// Do not force every type pair: with 0, matmul-instances.inc only expands the
// sections whose host-dependent conditions hold.
#define MATMUL_FORCE_ALL_TYPES 0
44+
4345
#include "matmul-instances.inc"
4446

4547
} // extern "C"

flang/include/flang/Runtime/matmul.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ void RTDECL(MatmulDirect)(const Descriptor &, const Descriptor &,
3939
const Descriptor &x, const Descriptor &y, const char *sourceFile, \
4040
int line);
4141

42+
// Do not force every type pair: with 0, matmul-instances.inc only expands the
// sections whose host-dependent conditions hold.
#define MATMUL_FORCE_ALL_TYPES 0
43+
4244
#include "matmul-instances.inc"
4345

4446
} // extern "C"

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -701,18 +701,19 @@ prettyPrintIntrinsicName(fir::FirOpBuilder &builder, mlir::Location loc,
701701
if (name == "pow") {
702702
assert(funcType.getNumInputs() == 2 && "power operator has two arguments");
703703
std::string displayName{" ** "};
704-
sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
705-
displayName)
704+
sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
705+
displayName)
706706
<< displayName
707-
<< numericMlirTypeToFortran(builder, funcType.getInput(1), loc,
708-
displayName);
707+
<< mlirTypeToIntrinsicFortran(builder, funcType.getInput(1), loc,
708+
displayName);
709709
} else {
710710
sstream << name.upper() << "(";
711711
if (funcType.getNumInputs() > 0)
712-
sstream << numericMlirTypeToFortran(builder, funcType.getInput(0), loc,
713-
name);
712+
sstream << mlirTypeToIntrinsicFortran(builder, funcType.getInput(0), loc,
713+
name);
714714
for (mlir::Type argType : funcType.getInputs().drop_front()) {
715-
sstream << ", " << numericMlirTypeToFortran(builder, argType, loc, name);
715+
sstream << ", "
716+
<< mlirTypeToIntrinsicFortran(builder, argType, loc, name);
716717
}
717718
sstream << ")";
718719
}

flang/lib/Optimizer/Builder/Runtime/Transformational.cpp

Lines changed: 92 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -329,11 +329,64 @@ void fir::runtime::genEoshiftVector(fir::FirOpBuilder &builder,
329329
builder.create<fir::CallOp>(loc, eoshiftFunc, args);
330330
}
331331

332+
/// Define ForcedMatmul<ACAT><AKIND><BCAT><BKIND> models.
333+
struct ForcedMatmulTypeModel {
334+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
335+
return [](mlir::MLIRContext *ctx) {
336+
auto boxRefTy =
337+
fir::runtime::getModel<Fortran::runtime::Descriptor &>()(ctx);
338+
auto boxTy =
339+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
340+
auto strTy = fir::runtime::getModel<const char *>()(ctx);
341+
auto intTy = fir::runtime::getModel<int>()(ctx);
342+
auto voidTy = fir::runtime::getModel<void>()(ctx);
343+
return mlir::FunctionType::get(
344+
ctx, {boxRefTy, boxTy, boxTy, strTy, intTy}, {voidTy});
345+
};
346+
}
347+
};
348+
349+
// For every (ACAT, AKIND, BCAT, BKIND) combination listed in
// matmul-instances.inc, declare a struct
// ForcedMatmul<ACAT><AKIND><BCAT><BKIND> that inherits the function type from
// ForcedMatmulTypeModel and whose `name` member is built from
// RTNAME(Matmul<ACAT><AKIND><BCAT><BKIND>) (presumably the quoted runtime
// symbol name; ExpandAndQuoteKey is defined elsewhere).
#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
350+
struct ForcedMatmul##ACAT##AKIND##BCAT##BKIND \
351+
: public ForcedMatmulTypeModel { \
352+
static constexpr const char *name = \
353+
ExpandAndQuoteKey(RTNAME(Matmul##ACAT##AKIND##BCAT##BKIND)); \
354+
};
355+
356+
// The "direct" instances expand to nothing here.
#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
357+
// Enable every type-pair section of the .inc file, independently of the
// host-dependent conditions it otherwise checks.
#define MATMUL_FORCE_ALL_TYPES 1
358+
359+
// The .inc file #undefs MATMUL_INSTANCE, MATMUL_DIRECT_INSTANCE and
// MATMUL_FORCE_ALL_TYPES at its end, so they can be redefined before the next
// inclusion.
#include "flang/Runtime/matmul-instances.inc"
360+
332361
/// Generate call to Matmul intrinsic runtime routine.
333362
void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
334363
mlir::Value resultBox, mlir::Value matrixABox,
335364
mlir::Value matrixBBox) {
336-
auto func = fir::runtime::getRuntimeFunc<mkRTKey(Matmul)>(loc, builder);
365+
mlir::func::FuncOp func;
366+
auto boxATy = matrixABox.getType();
367+
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
368+
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
369+
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
370+
auto boxBTy = matrixBBox.getType();
371+
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
372+
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
373+
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
374+
375+
#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
376+
if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
377+
bCat == TypeCategory::BCAT && bKind == BKIND) { \
378+
func = \
379+
fir::runtime::getRuntimeFunc<ForcedMatmul##ACAT##AKIND##BCAT##BKIND>( \
380+
loc, builder); \
381+
}
382+
383+
#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
384+
#define MATMUL_FORCE_ALL_TYPES 1
385+
#include "flang/Runtime/matmul-instances.inc"
386+
387+
if (!func) {
388+
fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc, "MATMUL");
389+
}
337390
auto fTy = func.getFunctionType();
338391
auto sourceFile = fir::factory::locationToFilename(builder, loc);
339392
auto sourceLine =
@@ -344,13 +397,48 @@ void fir::runtime::genMatmul(fir::FirOpBuilder &builder, mlir::Location loc,
344397
builder.create<fir::CallOp>(loc, func, args);
345398
}
346399

347-
/// Generate call to MatmulTranspose intrinsic runtime routine.
400+
/// Define ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> models.
401+
// For every (ACAT, AKIND, BCAT, BKIND) combination listed in
// matmul-instances.inc, declare a struct
// ForcedMatmulTranspose<ACAT><AKIND><BCAT><BKIND> that inherits the function
// type from ForcedMatmulTypeModel and whose `name` member is built from
// RTNAME(MatmulTranspose<ACAT><AKIND><BCAT><BKIND>) (presumably the quoted
// runtime symbol name; ExpandAndQuoteKey is defined elsewhere).
#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
402+
struct ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND \
403+
: public ForcedMatmulTypeModel { \
404+
static constexpr const char *name = \
405+
ExpandAndQuoteKey(RTNAME(MatmulTranspose##ACAT##AKIND##BCAT##BKIND)); \
406+
};
407+
408+
// The "direct" instances expand to nothing here.
#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
409+
// Enable every type-pair section of the .inc file, independently of the
// host-dependent conditions it otherwise checks.
#define MATMUL_FORCE_ALL_TYPES 1
410+
411+
// The .inc file #undefs MATMUL_INSTANCE, MATMUL_DIRECT_INSTANCE and
// MATMUL_FORCE_ALL_TYPES at its end, so they can be redefined before the next
// inclusion.
#include "flang/Runtime/matmul-instances.inc"
412+
348413
void fir::runtime::genMatmulTranspose(fir::FirOpBuilder &builder,
349414
mlir::Location loc, mlir::Value resultBox,
350415
mlir::Value matrixABox,
351416
mlir::Value matrixBBox) {
352-
auto func =
353-
fir::runtime::getRuntimeFunc<mkRTKey(MatmulTranspose)>(loc, builder);
417+
mlir::func::FuncOp func;
418+
auto boxATy = matrixABox.getType();
419+
auto arrATy = fir::dyn_cast_ptrOrBoxEleTy(boxATy);
420+
auto arrAEleTy = mlir::cast<fir::SequenceType>(arrATy).getEleTy();
421+
auto [aCat, aKind] = fir::mlirTypeToCategoryKind(loc, arrAEleTy);
422+
auto boxBTy = matrixBBox.getType();
423+
auto arrBTy = fir::dyn_cast_ptrOrBoxEleTy(boxBTy);
424+
auto arrBEleTy = mlir::cast<fir::SequenceType>(arrBTy).getEleTy();
425+
auto [bCat, bKind] = fir::mlirTypeToCategoryKind(loc, arrBEleTy);
426+
427+
#define MATMUL_INSTANCE(ACAT, AKIND, BCAT, BKIND) \
428+
if (!func && aCat == TypeCategory::ACAT && aKind == AKIND && \
429+
bCat == TypeCategory::BCAT && bKind == BKIND) { \
430+
func = fir::runtime::getRuntimeFunc< \
431+
ForcedMatmulTranspose##ACAT##AKIND##BCAT##BKIND>(loc, builder); \
432+
}
433+
434+
#define MATMUL_DIRECT_INSTANCE(ACAT, AKIND, BCAT, BKIND)
435+
#define MATMUL_FORCE_ALL_TYPES 1
436+
#include "flang/Runtime/matmul-instances.inc"
437+
438+
if (!func) {
439+
fir::intrinsicTypeTODO2(builder, arrAEleTy, arrBEleTy, loc,
440+
"MATMUL-TRANSPOSE");
441+
}
354442
auto fTy = func.getFunctionType();
355443
auto sourceFile = fir::factory::locationToFilename(builder, loc);
356444
auto sourceLine =

0 commit comments

Comments
 (0)