Skip to content

Commit f11e08f

Browse files
authored
[flang] Generate fir.do_loop reduce from DO CONCURRENT REDUCE clause (#94718)
Derived from #92480. This PR updates the lowering process of DO CONCURRENT to support F'2023 REDUCE clause. The structure `IncrementLoopInfo` is extended to have both reduction operations and symbols in `reduceSymList`. The function `getConcurrentControl` constructs `reduceSymList` for the innermost loop. Finally, `genFIRIncrementLoopBegin` builds `fir.do_loop` with reduction operands.
1 parent 65d3009 commit f11e08f

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ struct IncrementLoopInfo {
104104

105105
bool hasLocalitySpecs() const {
106106
return !localSymList.empty() || !localInitSymList.empty() ||
107-
!sharedSymList.empty();
107+
!reduceSymList.empty() || !sharedSymList.empty();
108108
}
109109

110110
// Data members common to both structured and unstructured loops.
@@ -116,6 +116,9 @@ struct IncrementLoopInfo {
116116
bool isUnordered; // do concurrent, forall
117117
llvm::SmallVector<const Fortran::semantics::Symbol *> localSymList;
118118
llvm::SmallVector<const Fortran::semantics::Symbol *> localInitSymList;
119+
llvm::SmallVector<
120+
std::pair<fir::ReduceOperationEnum, const Fortran::semantics::Symbol *>>
121+
reduceSymList;
119122
llvm::SmallVector<const Fortran::semantics::Symbol *> sharedSymList;
120123
mlir::Value loopVariable = nullptr;
121124

@@ -1741,6 +1744,35 @@ class FirConverter : public Fortran::lower::AbstractConverter {
17411744
builder->create<fir::UnreachableOp>(loc);
17421745
}
17431746

1747+
fir::ReduceOperationEnum
1748+
getReduceOperationEnum(const Fortran::parser::ReductionOperator &rOpr) {
1749+
switch (rOpr.v) {
1750+
case Fortran::parser::ReductionOperator::Operator::Plus:
1751+
return fir::ReduceOperationEnum::Add;
1752+
case Fortran::parser::ReductionOperator::Operator::Multiply:
1753+
return fir::ReduceOperationEnum::Multiply;
1754+
case Fortran::parser::ReductionOperator::Operator::And:
1755+
return fir::ReduceOperationEnum::AND;
1756+
case Fortran::parser::ReductionOperator::Operator::Or:
1757+
return fir::ReduceOperationEnum::OR;
1758+
case Fortran::parser::ReductionOperator::Operator::Eqv:
1759+
return fir::ReduceOperationEnum::EQV;
1760+
case Fortran::parser::ReductionOperator::Operator::Neqv:
1761+
return fir::ReduceOperationEnum::NEQV;
1762+
case Fortran::parser::ReductionOperator::Operator::Max:
1763+
return fir::ReduceOperationEnum::MAX;
1764+
case Fortran::parser::ReductionOperator::Operator::Min:
1765+
return fir::ReduceOperationEnum::MIN;
1766+
case Fortran::parser::ReductionOperator::Operator::Iand:
1767+
return fir::ReduceOperationEnum::IAND;
1768+
case Fortran::parser::ReductionOperator::Operator::Ior:
1769+
return fir::ReduceOperationEnum::IOR;
1770+
case Fortran::parser::ReductionOperator::Operator::Ieor:
1771+
return fir::ReduceOperationEnum::EIOR;
1772+
}
1773+
llvm_unreachable("illegal reduction operator");
1774+
}
1775+
17441776
/// Collect DO CONCURRENT or FORALL loop control information.
17451777
IncrementLoopNestInfo getConcurrentControl(
17461778
const Fortran::parser::ConcurrentHeader &header,
@@ -1763,6 +1795,16 @@ class FirConverter : public Fortran::lower::AbstractConverter {
17631795
std::get_if<Fortran::parser::LocalitySpec::LocalInit>(&x.u))
17641796
for (const Fortran::parser::Name &x : localInitList->v)
17651797
info.localInitSymList.push_back(x.symbol);
1798+
if (const auto *reduceList =
1799+
std::get_if<Fortran::parser::LocalitySpec::Reduce>(&x.u)) {
1800+
fir::ReduceOperationEnum reduce_operation = getReduceOperationEnum(
1801+
std::get<Fortran::parser::ReductionOperator>(reduceList->t));
1802+
for (const Fortran::parser::Name &x :
1803+
std::get<std::list<Fortran::parser::Name>>(reduceList->t)) {
1804+
info.reduceSymList.push_back(
1805+
std::make_pair(reduce_operation, x.symbol));
1806+
}
1807+
}
17661808
if (const auto *sharedList =
17671809
std::get_if<Fortran::parser::LocalitySpec::Shared>(&x.u))
17681810
for (const Fortran::parser::Name &x : sharedList->v)
@@ -1955,9 +1997,23 @@ class FirConverter : public Fortran::lower::AbstractConverter {
19551997
mlir::Type loopVarType = info.getLoopVariableType();
19561998
mlir::Value loopValue;
19571999
if (info.isUnordered) {
2000+
llvm::SmallVector<mlir::Value> reduceOperands;
2001+
llvm::SmallVector<mlir::Attribute> reduceAttrs;
2002+
// Create DO CONCURRENT reduce operands and attributes
2003+
for (const auto reduceSym : info.reduceSymList) {
2004+
const fir::ReduceOperationEnum reduce_operation = reduceSym.first;
2005+
const Fortran::semantics::Symbol *sym = reduceSym.second;
2006+
fir::ExtendedValue exv = getSymbolExtendedValue(*sym, nullptr);
2007+
reduceOperands.push_back(fir::getBase(exv));
2008+
auto reduce_attr =
2009+
fir::ReduceAttr::get(builder->getContext(), reduce_operation);
2010+
reduceAttrs.push_back(reduce_attr);
2011+
}
19582012
// The loop variable value is explicitly updated.
19592013
info.doLoop = builder->create<fir::DoLoopOp>(
1960-
loc, lowerValue, upperValue, stepValue, /*unordered=*/true);
2014+
loc, lowerValue, upperValue, stepValue, /*unordered=*/true,
2015+
/*finalCountValue=*/false, /*iterArgs=*/std::nullopt,
2016+
llvm::ArrayRef<mlir::Value>(reduceOperands), reduceAttrs);
19612017
builder->setInsertionPointToStart(info.doLoop.getBody());
19622018
loopValue = builder->createConvert(loc, loopVarType,
19632019
info.doLoop.getInductionVar());

flang/test/Lower/loops3.f90

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
! Test do concurrent reduction
2+
! RUN: bbc -emit-fir -hlfir=false -o - %s | FileCheck %s
3+
4+
! CHECK-LABEL: loop_test
5+
subroutine loop_test
6+
integer(4) :: i, j, k, tmp, sum = 0
7+
real :: m
8+
9+
i = 100
10+
j = 200
11+
k = 300
12+
13+
! CHECK: %[[VAL_0:.*]] = fir.alloca f32 {bindc_name = "m", uniq_name = "_QFloop_testEm"}
14+
! CHECK: %[[VAL_1:.*]] = fir.address_of(@_QFloop_testEsum) : !fir.ref<i32>
15+
! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
16+
! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered {
17+
! CHECK: fir.do_loop %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} unordered reduce(#fir.reduce_attr<add> -> %[[VAL_1:.*]] : !fir.ref<i32>, #fir.reduce_attr<max> -> %[[VAL_0:.*]] : !fir.ref<f32>) {
18+
do concurrent (i=1:5, j=1:5, k=1:5) local(tmp) reduce(+:sum) reduce(max:m)
19+
tmp = i + j + k
20+
sum = tmp + sum
21+
m = max(m, sum)
22+
enddo
23+
end subroutine loop_test

0 commit comments

Comments
 (0)