Skip to content

Commit 6ffdcfa

Browse files
authored
[flang] Lower REDUCE intrinsic with DIM argument (#94771)
This is a follow-up patch to #94652; it handles the lowering of the REDUCE intrinsic with a DIM argument and a non-scalar result.
1 parent 1934208 commit 6ffdcfa

File tree

4 files changed

+443
-1
lines changed

4 files changed

+443
-1
lines changed

flang/include/flang/Optimizer/Builder/Runtime/Reduction.h

+7
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ mlir::Value genReduce(fir::FirOpBuilder &builder, mlir::Location loc,
240240
mlir::Value maskBox, mlir::Value identity,
241241
mlir::Value ordered);
242242

243+
/// Generate call to `Reduce` intrinsic runtime routine. This is the version
244+
/// that takes arrays of any rank with a dim argument specified.
245+
void genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
246+
mlir::Value arrayBox, mlir::Value operation, mlir::Value dim,
247+
mlir::Value maskBox, mlir::Value identity,
248+
mlir::Value ordered, mlir::Value resultBox);
249+
243250
} // namespace fir::runtime
244251

245252
#endif // FORTRAN_OPTIMIZER_BUILDER_RUNTIME_REDUCTION_H

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -5790,7 +5790,17 @@ IntrinsicLibrary::genReduce(mlir::Type resultType,
57905790
return fir::runtime::genReduce(builder, loc, array, operation, mask,
57915791
identity, ordered);
57925792
}
5793-
TODO(loc, "reduce with array result");
5793+
// Handle cases that have an array result.
5794+
// Create mutable fir.box to be passed to the runtime for the result.
5795+
mlir::Type resultArrayType = builder.getVarLenSeqTy(resultType, rank - 1);
5796+
fir::MutableBoxValue resultMutableBox =
5797+
fir::factory::createTempMutableBox(builder, loc, resultArrayType);
5798+
mlir::Value resultIrBox =
5799+
fir::factory::getMutableIRBox(builder, loc, resultMutableBox);
5800+
mlir::Value dim = fir::getBase(args[2]);
5801+
fir::runtime::genReduceDim(builder, loc, array, operation, dim, mask,
5802+
identity, ordered, resultIrBox);
5803+
return readAndAddCleanUp(resultMutableBox, resultType, "REDUCE");
57945804
}
57955805

57965806
// REPEAT

flang/lib/Optimizer/Builder/Runtime/Reduction.cpp

+204
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,50 @@ struct ForcedReduceReal16 {
505505
}
506506
};
507507

508+
/// Placeholder for DIM real*10 version of Reduce Intrinsic
509+
struct ForcedReduceReal10Dim {
510+
static constexpr const char *name =
511+
ExpandAndQuoteKey(RTNAME(ReduceReal10Dim));
512+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
513+
return [](mlir::MLIRContext *ctx) {
514+
auto ty = mlir::FloatType::getF80(ctx);
515+
auto boxTy =
516+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
517+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
518+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
519+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
520+
auto refTy = fir::ReferenceType::get(ty);
521+
auto refBoxTy = fir::ReferenceType::get(boxTy);
522+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
523+
return mlir::FunctionType::get(
524+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
525+
{});
526+
};
527+
}
528+
};
529+
530+
/// Placeholder for DIM real*16 version of Reduce Intrinsic
531+
struct ForcedReduceReal16Dim {
532+
static constexpr const char *name =
533+
ExpandAndQuoteKey(RTNAME(ReduceReal16Dim));
534+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
535+
return [](mlir::MLIRContext *ctx) {
536+
auto ty = mlir::FloatType::getF128(ctx);
537+
auto boxTy =
538+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
539+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
540+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
541+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
542+
auto refTy = fir::ReferenceType::get(ty);
543+
auto refBoxTy = fir::ReferenceType::get(boxTy);
544+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
545+
return mlir::FunctionType::get(
546+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
547+
{});
548+
};
549+
}
550+
};
551+
508552
/// Placeholder for integer*16 version of Reduce Intrinsic
509553
struct ForcedReduceInteger16 {
510554
static constexpr const char *name =
@@ -525,6 +569,28 @@ struct ForcedReduceInteger16 {
525569
}
526570
};
527571

572+
/// Placeholder for DIM integer*16 version of Reduce Intrinsic
573+
struct ForcedReduceInteger16Dim {
574+
static constexpr const char *name =
575+
ExpandAndQuoteKey(RTNAME(ReduceInteger16Dim));
576+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
577+
return [](mlir::MLIRContext *ctx) {
578+
auto ty = mlir::IntegerType::get(ctx, 128);
579+
auto boxTy =
580+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
581+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
582+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
583+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
584+
auto refTy = fir::ReferenceType::get(ty);
585+
auto refBoxTy = fir::ReferenceType::get(boxTy);
586+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
587+
return mlir::FunctionType::get(
588+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
589+
{});
590+
};
591+
}
592+
};
593+
528594
/// Placeholder for complex(10) version of Reduce Intrinsic
529595
struct ForcedReduceComplex10 {
530596
static constexpr const char *name =
@@ -546,6 +612,28 @@ struct ForcedReduceComplex10 {
546612
}
547613
};
548614

615+
/// Placeholder for Dim complex(10) version of Reduce Intrinsic
616+
struct ForcedReduceComplex10Dim {
617+
static constexpr const char *name =
618+
ExpandAndQuoteKey(RTNAME(CppReduceComplex10Dim));
619+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
620+
return [](mlir::MLIRContext *ctx) {
621+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF80(ctx));
622+
auto boxTy =
623+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
624+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
625+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
626+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
627+
auto refTy = fir::ReferenceType::get(ty);
628+
auto refBoxTy = fir::ReferenceType::get(boxTy);
629+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
630+
return mlir::FunctionType::get(
631+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
632+
{});
633+
};
634+
}
635+
};
636+
549637
/// Placeholder for complex(16) version of Reduce Intrinsic
550638
struct ForcedReduceComplex16 {
551639
static constexpr const char *name =
@@ -567,6 +655,28 @@ struct ForcedReduceComplex16 {
567655
}
568656
};
569657

658+
/// Placeholder for Dim complex(16) version of Reduce Intrinsic
659+
struct ForcedReduceComplex16Dim {
660+
static constexpr const char *name =
661+
ExpandAndQuoteKey(RTNAME(CppReduceComplex16Dim));
662+
static constexpr fir::runtime::FuncTypeBuilderFunc getTypeModel() {
663+
return [](mlir::MLIRContext *ctx) {
664+
auto ty = mlir::ComplexType::get(mlir::FloatType::getF128(ctx));
665+
auto boxTy =
666+
fir::runtime::getModel<const Fortran::runtime::Descriptor &>()(ctx);
667+
auto opTy = mlir::FunctionType::get(ctx, {ty, ty}, ty);
668+
auto strTy = fir::ReferenceType::get(mlir::IntegerType::get(ctx, 8));
669+
auto intTy = mlir::IntegerType::get(ctx, 8 * sizeof(int));
670+
auto refTy = fir::ReferenceType::get(ty);
671+
auto refBoxTy = fir::ReferenceType::get(boxTy);
672+
auto i1Ty = mlir::IntegerType::get(ctx, 1);
673+
return mlir::FunctionType::get(
674+
ctx, {refBoxTy, boxTy, opTy, strTy, intTy, intTy, boxTy, refTy, i1Ty},
675+
{});
676+
};
677+
}
678+
};
679+
570680
/// Generate call to specialized runtime function that takes a mask and
571681
/// dim argument. The All, Any, and Count intrinsics use this pattern.
572682
template <typename FN>
@@ -1461,3 +1571,97 @@ mlir::Value fir::runtime::genReduce(fir::FirOpBuilder &builder,
14611571
maskBox, identity, ordered);
14621572
return builder.create<fir::CallOp>(loc, func, args).getResult(0);
14631573
}
1574+
1575+
/// Emit a call to the DIM form of the `Reduce` runtime entry points.
///
/// The entry point is chosen from the element type of \p arrayBox (real,
/// integer, complex, logical, character or derived type). Element types not
/// matched by any case are reported through fir::intrinsicTypeTODO.
///
/// The call receives, in order: \p resultBox, \p arrayBox, the address taken
/// from \p operation, the source file name, the source line, \p dim,
/// \p maskBox, \p identity and \p ordered. The call's result is not used.
void fir::runtime::genReduceDim(fir::FirOpBuilder &builder, mlir::Location loc,
                                mlir::Value arrayBox, mlir::Value operation,
                                mlir::Value dim, mlir::Value maskBox,
                                mlir::Value identity, mlir::Value ordered,
                                mlir::Value resultBox) {
  // Strip the box/pointer wrapper, then the sequence, to get the element type.
  mlir::Type boxedTy = fir::dyn_cast_ptrOrBoxEleTy(arrayBox.getType());
  mlir::Type eleTy = mlir::cast<fir::SequenceType>(boxedTy).getEleTy();

  mlir::MLIRContext *context = builder.getContext();
  fir::factory::CharacterExprHelper characterHelper{builder, loc};

  // Small predicates so each dispatch case below reads as one condition.
  const auto isIntegerOfKind = [&](unsigned kind) {
    return eleTy.isInteger(builder.getKindMap().getIntegerBitsize(kind));
  };
  const auto isComplexOfKind = [&](unsigned kind) {
    return eleTy == fir::ComplexType::get(context, kind);
  };
  const auto isLogicalOfKind = [&](unsigned kind) {
    return eleTy == fir::LogicalType::get(context, kind);
  };
  const auto isCharacterOfKind = [&](unsigned kind) {
    // The kind is only queried once the type is known to be a character.
    return fir::isa_char(eleTy) &&
           characterHelper.getCharacterKind(eleTy) == kind;
  };

  mlir::func::FuncOp runtimeFunc;
  // Real kinds.
  if (eleTy.isF16()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal2Dim)>(loc, builder);
  } else if (eleTy.isBF16()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal3Dim)>(loc, builder);
  } else if (eleTy.isF32()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal4Dim)>(loc, builder);
  } else if (eleTy.isF64()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceReal8Dim)>(loc, builder);
  } else if (eleTy.isF80()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<ForcedReduceReal10Dim>(loc, builder);
  } else if (eleTy.isF128()) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<ForcedReduceReal16Dim>(loc, builder);
  }
  // Integer kinds.
  else if (isIntegerOfKind(1)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger1Dim)>(loc, builder);
  } else if (isIntegerOfKind(2)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger2Dim)>(loc, builder);
  } else if (isIntegerOfKind(4)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger4Dim)>(loc, builder);
  } else if (isIntegerOfKind(8)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceInteger8Dim)>(loc, builder);
  } else if (isIntegerOfKind(16)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<ForcedReduceInteger16Dim>(loc, builder);
  }
  // Complex kinds.
  else if (isComplexOfKind(2)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex2Dim)>(
        loc, builder);
  } else if (isComplexOfKind(3)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex3Dim)>(
        loc, builder);
  } else if (isComplexOfKind(4)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex4Dim)>(
        loc, builder);
  } else if (isComplexOfKind(8)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(CppReduceComplex8Dim)>(
        loc, builder);
  } else if (isComplexOfKind(10)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<ForcedReduceComplex10Dim>(loc, builder);
  } else if (isComplexOfKind(16)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<ForcedReduceComplex16Dim>(loc, builder);
  }
  // Logical kinds.
  else if (isLogicalOfKind(1)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical1Dim)>(loc, builder);
  } else if (isLogicalOfKind(2)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical2Dim)>(loc, builder);
  } else if (isLogicalOfKind(4)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical4Dim)>(loc, builder);
  } else if (isLogicalOfKind(8)) {
    runtimeFunc =
        fir::runtime::getRuntimeFunc<mkRTKey(ReduceLogical8Dim)>(loc, builder);
  }
  // Character kinds.
  else if (isCharacterOfKind(1)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter1Dim)>(
        loc, builder);
  } else if (isCharacterOfKind(2)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter2Dim)>(
        loc, builder);
  } else if (isCharacterOfKind(4)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(ReduceCharacter4Dim)>(
        loc, builder);
  }
  // Derived types, then the catch-all for anything not handled above.
  else if (fir::isa_derived(eleTy)) {
    runtimeFunc = fir::runtime::getRuntimeFunc<mkRTKey(ReduceDerivedTypeDim)>(
        loc, builder);
  } else {
    fir::intrinsicTypeTODO(builder, eleTy, loc, "REDUCE");
  }

  auto runtimeFuncTy = runtimeFunc.getFunctionType();
  auto fileName = fir::factory::locationToFilename(builder, loc);

  // Input 4 of the runtime signature is the source line, input 2 the
  // operation; both values are created with the type the callee expects.
  auto lineNumber =
      fir::factory::locationToLineNo(builder, loc, runtimeFuncTy.getInput(4));
  auto operationAddr = builder.create<fir::BoxAddrOp>(
      loc, runtimeFuncTy.getInput(2), operation);
  auto callArgs = fir::runtime::createArguments(
      builder, loc, runtimeFuncTy, resultBox, arrayBox, operationAddr, fileName,
      lineNumber, dim, maskBox, identity, ordered);
  builder.create<fir::CallOp>(loc, runtimeFunc, callArgs);
}

0 commit comments

Comments
 (0)