Skip to content

Commit 38fd018

Browse files
authored
[flang] Lower REDUCE intrinsic for reduction op with args by value (#95353)
#95297 Updates the runtime entry points to distinguish between reduction operation with arguments passed by value or by reference. Add lowering to support the arguments passed by value.
1 parent b7599da commit 38fd018

File tree

5 files changed

+674
-75
lines changed

5 files changed

+674
-75
lines changed

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ using FuncTypeBuilderFunc = mlir::FunctionType (*)(mlir::MLIRContext *);
6464
}; \
6565
}
6666

67+
#define REDUCTION_VALUE_OPERATION_MODEL(T) \
68+
template <> \
69+
constexpr TypeBuilderFunc \
70+
getModel<Fortran::runtime::ValueReductionOperation<T>>() { \
71+
return [](mlir::MLIRContext *context) -> mlir::Type { \
72+
TypeBuilderFunc f{getModel<T>()}; \
73+
auto refTy = fir::ReferenceType::get(f(context)); \
74+
return mlir::FunctionType::get(context, {f(context), f(context)}, \
75+
refTy); \
76+
}; \
77+
}
78+
6779
#define REDUCTION_CHAR_OPERATION_MODEL(T) \
6880
template <> \
6981
constexpr TypeBuilderFunc \
@@ -481,17 +493,27 @@ constexpr TypeBuilderFunc getModel<void>() {
481493
}
482494

483495
REDUCTION_REF_OPERATION_MODEL(std::int8_t)
496+
REDUCTION_VALUE_OPERATION_MODEL(std::int8_t)
484497
REDUCTION_REF_OPERATION_MODEL(std::int16_t)
498+
REDUCTION_VALUE_OPERATION_MODEL(std::int16_t)
485499
REDUCTION_REF_OPERATION_MODEL(std::int32_t)
500+
REDUCTION_VALUE_OPERATION_MODEL(std::int32_t)
486501
REDUCTION_REF_OPERATION_MODEL(std::int64_t)
502+
REDUCTION_VALUE_OPERATION_MODEL(std::int64_t)
487503
REDUCTION_REF_OPERATION_MODEL(Fortran::common::int128_t)
504+
REDUCTION_VALUE_OPERATION_MODEL(Fortran::common::int128_t)
488505

489506
REDUCTION_REF_OPERATION_MODEL(float)
507+
REDUCTION_VALUE_OPERATION_MODEL(float)
490508
REDUCTION_REF_OPERATION_MODEL(double)
509+
REDUCTION_VALUE_OPERATION_MODEL(double)
491510
REDUCTION_REF_OPERATION_MODEL(long double)
511+
REDUCTION_VALUE_OPERATION_MODEL(long double)
492512

493513
REDUCTION_REF_OPERATION_MODEL(std::complex<float>)
514+
REDUCTION_VALUE_OPERATION_MODEL(std::complex<float>)
494515
REDUCTION_REF_OPERATION_MODEL(std::complex<double>)
516+
REDUCTION_VALUE_OPERATION_MODEL(std::complex<double>)
495517

496518
REDUCTION_CHAR_OPERATION_MODEL(char)
497519
REDUCTION_CHAR_OPERATION_MODEL(char16_t)

flang/include/flang/Optimizer/Builder/Runtime/Reduction.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,23 +229,23 @@ void genIParityDim(fir::FirOpBuilder &builder, mlir::Location loc,
229229
/// result value. This is used for COMPLEX, CHARACTER and DERIVED TYPES.
230230
void genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
231231
mlir::Value arrayBox, mlir::Value operation, mlir::Value maskBox,
232-
mlir::Value identity, mlir::Value ordered,
233-
mlir::Value resultBox);
232+
mlir::Value identity, mlir::Value ordered, mlir::Value resultBox,
233+
bool argByRef);
234234

235235
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
236236
/// that does not take a dim argument and return a scalare result. This is used
237237
/// for REAL, INTEGER and LOGICAL TYPES.
238238
mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
239239
mlir::Value arrayBox, mlir::Value operation,
240240
mlir::Value maskBox, mlir::Value identity,
241-
mlir::Value ordered);
241+
mlir::Value ordered, bool argByRef);
242242

243243
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
244244
/// that takes arrays of any rank with a dim argument specified.
245245
void genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
246246
mlir::Value arrayBox, mlir::Value operation, mlir::Value dim,
247247
mlir::Value maskBox, mlir::Value identity,
248-
mlir::Value ordered, mlir::Value resultBox);
248+
mlir::Value ordered, mlir::Value resultBox, bool argByRef);
249249

250250
} // namespace fir::runtime
251251

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5745,6 +5745,14 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
57455745
int rank = arrayTmp.rank();
57465746
assert(rank >= 1);
57475747

5748+
// Arguements to the reduction operation are passed by reference or value?
5749+
bool argByRef = true;
5750+
if (auto embox =
5751+
mlir::dyn_cast_or_null<fir::EmboxProcOp>(operation.getDefiningOp())) {
5752+
auto fctTy = mlir::dyn_cast<mlir::FunctionType>(embox.getFunc().getType());
5753+
argByRef = mlir::isa<fir::ReferenceType>(fctTy.getInput(0));
5754+
}
5755+
57485756
mlir::Type ty = array.getType();
57495757
mlir::Type arrTy = fir::dyn_cast_ptrOrBoxEleTy(ty);
57505758
mlir::Type eleTy = mlir::cast<fir::SequenceType>(arrTy).getEleTy();
@@ -5772,7 +5780,7 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
57725780
if (fir::isa_complex(eleTy) || fir::isa_derived(eleTy)) {
57735781
mlir::Value result = builder.createTemporary(loc, eleTy);
57745782
fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
5775-
ordered, result);
5783+
ordered, result, argByRef);
57765784
if (fir::isa_derived(eleTy))
57775785
return result;
57785786
return builder.create<fir::LoadOp>(loc, result);
@@ -5789,11 +5797,11 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
57895797
charTy.getLen());
57905798
fir::CharBoxValue temp = charHelper.createCharacterTemp(eleTy, len);
57915799
fir::runtime::genReduce(builder, loc, array, operation, mask, identity,
5792-
ordered, temp.getBuffer());
5800+
ordered, temp.getBuffer(), argByRef);
57935801
return temp;
57945802
}
57955803
return fir::runtime::genReduce(builder, loc, array, operation, mask,
5796-
identity, ordered);
5804+
identity, ordered, argByRef);
57975805
}
57985806
// Handle cases that have an array result.
57995807
// Create mutable fir.box to be passed to the runtime for the result.
@@ -5804,7 +5812,7 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
58045812
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
58055813
mlir::Value dim = fir::getBase(args[2]);
58065814
fir::runtime::genReduceDim(builder, loc, array, operation, dim, mask,
5807-
identity, ordered, resultIrBox);
5815+
identity, ordered, resultIrBox, argByRef);
58085816
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
58095817
}
58105818

0 commit comments

Comments
 (0)