Skip to content

Commit 88cdd99

Browse files
authored
[flang] Add reduction semantics to fir.do_loop (#93934)
Derived from #92480. This PR introduces reduction semantics into loops for DO CONCURRENT REDUCE. The `fir.do_loop` operation now implicitly carries the `operandSegmentSizes` attribute (it is elided from the printed form) and takes a variable-length list of reduction operands, each paired with its reduction operation given as a `fir.reduce_attr`. For the sake of compatibility, the new arguments of `fir.do_loop`'s builder are added after `iterArgs` and have default values. In the assembly format, the `iter_args` clause must be placed directly in front of the declaration of result types, so the new clause for reduction variables (`reduce`) is put before it, in the middle of the arguments.
1 parent 9eac38a commit 88cdd99

File tree

5 files changed

+142
-17
lines changed

5 files changed

+142
-17
lines changed

flang/include/flang/Optimizer/Dialect/FIRAttr.td

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,36 @@ def fir_BoxFieldAttr : I32EnumAttr<
6767
let cppNamespace = "fir";
6868
}
6969

70+
def fir_ReduceOperationEnum : I32BitEnumAttr<"ReduceOperationEnum",
71+
"intrinsic operations and functions supported by DO CONCURRENT REDUCE",
72+
[
73+
I32BitEnumAttrCaseBit<"Add", 0, "add">,
74+
I32BitEnumAttrCaseBit<"Multiply", 1, "multiply">,
75+
I32BitEnumAttrCaseBit<"AND", 2, "and">,
76+
I32BitEnumAttrCaseBit<"OR", 3, "or">,
77+
I32BitEnumAttrCaseBit<"EQV", 4, "eqv">,
78+
I32BitEnumAttrCaseBit<"NEQV", 5, "neqv">,
79+
I32BitEnumAttrCaseBit<"MAX", 6, "max">,
80+
I32BitEnumAttrCaseBit<"MIN", 7, "min">,
81+
I32BitEnumAttrCaseBit<"IAND", 8, "iand">,
82+
I32BitEnumAttrCaseBit<"IOR", 9, "ior">,
83+
I32BitEnumAttrCaseBit<"EIOR", 10, "eior">
84+
]> {
85+
let separator = ", ";
86+
let cppNamespace = "::fir";
87+
let printBitEnumPrimaryGroups = 1;
88+
}
89+
90+
def fir_ReduceAttr : fir_Attr<"Reduce"> {
91+
let mnemonic = "reduce_attr";
92+
93+
let parameters = (ins
94+
"ReduceOperationEnum":$reduce_operation
95+
);
96+
97+
let assemblyFormat = "`<` $reduce_operation `>`";
98+
}
99+
70100
// mlir::SideEffects::Resource for modelling operations which add debugging information
71101
def DebuggingResource : Resource<"::fir::DebuggingResource">;
72102

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2125,8 +2125,8 @@ class region_Op<string mnemonic, list<Trait> traits = []> :
21252125
let hasVerifier = 1;
21262126
}
21272127

2128-
def fir_DoLoopOp : region_Op<"do_loop",
2129-
[DeclareOpInterfaceMethods<LoopLikeOpInterface,
2128+
def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
2129+
DeclareOpInterfaceMethods<LoopLikeOpInterface,
21302130
["getYieldedValuesMutable"]>]> {
21312131
let summary = "generalized loop operation";
21322132
let description = [{
@@ -2156,9 +2156,11 @@ def fir_DoLoopOp : region_Op<"do_loop",
21562156
Index:$lowerBound,
21572157
Index:$upperBound,
21582158
Index:$step,
2159+
Variadic<AnyType>:$reduceOperands,
21592160
Variadic<AnyType>:$initArgs,
21602161
OptionalAttr<UnitAttr>:$unordered,
2161-
OptionalAttr<UnitAttr>:$finalValue
2162+
OptionalAttr<UnitAttr>:$finalValue,
2163+
OptionalAttr<ArrayAttr>:$reduceAttrs
21622164
);
21632165
let results = (outs Variadic<AnyType>:$results);
21642166
let regions = (region SizedRegion<1>:$region);
@@ -2169,6 +2171,8 @@ def fir_DoLoopOp : region_Op<"do_loop",
21692171
"mlir::Value":$step, CArg<"bool", "false">:$unordered,
21702172
CArg<"bool", "false">:$finalCountValue,
21712173
CArg<"mlir::ValueRange", "std::nullopt">:$iterArgs,
2174+
CArg<"mlir::ValueRange", "std::nullopt">:$reduceOperands,
2175+
CArg<"llvm::ArrayRef<mlir::Attribute>", "{}">:$reduceAttrs,
21722176
CArg<"llvm::ArrayRef<mlir::NamedAttribute>", "{}">:$attributes)>
21732177
];
21742178

@@ -2181,11 +2185,12 @@ def fir_DoLoopOp : region_Op<"do_loop",
21812185
return getBody()->getArguments().drop_front();
21822186
}
21832187
mlir::Operation::operand_range getIterOperands() {
2184-
return getOperands().drop_front(getNumControlOperands());
2188+
return getOperands()
2189+
.drop_front(getNumControlOperands() + getNumReduceOperands());
21852190
}
21862191
llvm::MutableArrayRef<mlir::OpOperand> getInitsMutable() {
2187-
return
2188-
getOperation()->getOpOperands().drop_front(getNumControlOperands());
2192+
return getOperation()->getOpOperands()
2193+
.drop_front(getNumControlOperands() + getNumReduceOperands());
21892194
}
21902195

21912196
void setLowerBound(mlir::Value bound) { (*this)->setOperand(0, bound); }
@@ -2200,11 +2205,25 @@ def fir_DoLoopOp : region_Op<"do_loop",
22002205
unsigned getNumControlOperands() { return 3; }
22012206
/// Does the operation hold operands for loop-carried values
22022207
bool hasIterOperands() {
2203-
return (*this)->getNumOperands() > getNumControlOperands();
2208+
return getNumIterOperands() > 0;
2209+
}
2210+
/// Does the operation hold operands for reduction variables
2211+
bool hasReduceOperands() {
2212+
return getNumReduceOperands() > 0;
2213+
}
2214+
/// Get the number of operands in the operand segment at index `idx`, as recorded in the operand segment sizes attribute
2215+
unsigned getNumOperands(unsigned idx) {
2216+
auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
2217+
getOperandSegmentSizeAttr());
2218+
return static_cast<unsigned>(segments[idx]);
2219+
}
2220+
/// Get the number of reduction operands
2221+
unsigned getNumReduceOperands() {
2222+
return getNumOperands(3);
22042223
}
22052224
/// Get Number of loop-carried values
22062225
unsigned getNumIterOperands() {
2207-
return (*this)->getNumOperands() - getNumControlOperands();
2226+
return getNumOperands(4);
22082227
}
22092228

22102229
/// Get the body of the loop

flang/lib/Optimizer/Dialect/FIRAttr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,6 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
297297

298298
void FIROpsDialect::registerAttributes() {
299299
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
300-
LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
301-
UpperBoundAttr>();
300+
LowerBoundAttr, PointIntervalAttr, RealAttr, ReduceAttr,
301+
SubclassAttr, UpperBoundAttr>();
302302
}

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
24562456
mlir::OperationState &result, mlir::Value lb,
24572457
mlir::Value ub, mlir::Value step, bool unordered,
24582458
bool finalCountValue, mlir::ValueRange iterArgs,
2459+
mlir::ValueRange reduceOperands,
2460+
llvm::ArrayRef<mlir::Attribute> reduceAttrs,
24592461
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
24602462
result.addOperands({lb, ub, step});
2463+
result.addOperands(reduceOperands);
24612464
result.addOperands(iterArgs);
2465+
result.addAttribute(getOperandSegmentSizeAttr(),
2466+
builder.getDenseI32ArrayAttr(
2467+
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2468+
static_cast<int32_t>(iterArgs.size())}));
24622469
if (finalCountValue) {
24632470
result.addTypes(builder.getIndexType());
24642471
result.addAttribute(getFinalValueAttrName(result.name),
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
24772484
if (unordered)
24782485
result.addAttribute(getUnorderedAttrName(result.name),
24792486
builder.getUnitAttr());
2487+
if (!reduceAttrs.empty())
2488+
result.addAttribute(getReduceAttrsAttrName(result.name),
2489+
builder.getArrayAttr(reduceAttrs));
24802490
result.addAttributes(attributes);
24812491
}
24822492

@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25022512
if (mlir::succeeded(parser.parseOptionalKeyword("unordered")))
25032513
result.addAttribute("unordered", builder.getUnitAttr());
25042514

2515+
// Parse the reduction arguments.
2516+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
2517+
llvm::SmallVector<mlir::Type> reduceArgTypes;
2518+
if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2519+
// Parse reduction attributes and variables.
2520+
llvm::SmallVector<ReduceAttr> attributes;
2521+
if (failed(parser.parseCommaSeparatedList(
2522+
mlir::AsmParser::Delimiter::Paren, [&]() {
2523+
if (parser.parseAttribute(attributes.emplace_back()) ||
2524+
parser.parseArrow() ||
2525+
parser.parseOperand(reduceOperands.emplace_back()) ||
2526+
parser.parseColonType(reduceArgTypes.emplace_back()))
2527+
return mlir::failure();
2528+
return mlir::success();
2529+
})))
2530+
return mlir::failure();
2531+
// Resolve input operands.
2532+
for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
2533+
if (parser.resolveOperand(std::get<0>(operand_type),
2534+
std::get<1>(operand_type), result.operands))
2535+
return mlir::failure();
2536+
llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2537+
attributes.end());
2538+
result.addAttribute(getReduceAttrsAttrName(result.name),
2539+
builder.getArrayAttr(arrayAttr));
2540+
}
2541+
25052542
// Parse the optional initial iteration arguments.
25062543
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2507-
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
2544+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands;
25082545
llvm::SmallVector<mlir::Type> argTypes;
25092546
bool prependCount = false;
25102547
regionArgs.push_back(inductionVariable);
25112548

25122549
if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
25132550
// Parse assignment list and results type list.
2514-
if (parser.parseAssignmentList(regionArgs, operands) ||
2551+
if (parser.parseAssignmentList(regionArgs, iterOperands) ||
25152552
parser.parseArrowTypeList(result.types))
25162553
return mlir::failure();
2517-
if (result.types.size() == operands.size() + 1)
2554+
if (result.types.size() == iterOperands.size() + 1)
25182555
prependCount = true;
25192556
// Resolve input operands.
25202557
llvm::ArrayRef<mlir::Type> resTypes = result.types;
2521-
for (auto operand_type :
2522-
llvm::zip(operands, prependCount ? resTypes.drop_front() : resTypes))
2558+
for (auto operand_type : llvm::zip(
2559+
iterOperands, prependCount ? resTypes.drop_front() : resTypes))
25232560
if (parser.resolveOperand(std::get<0>(operand_type),
25242561
std::get<1>(operand_type), result.operands))
25252562
return mlir::failure();
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
25302567
prependCount = true;
25312568
}
25322569

2570+
// Set the operandSegmentSizes attribute
2571+
result.addAttribute(getOperandSegmentSizeAttr(),
2572+
builder.getDenseI32ArrayAttr(
2573+
{1, 1, 1, static_cast<int32_t>(reduceOperands.size()),
2574+
static_cast<int32_t>(iterOperands.size())}));
2575+
25332576
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
25342577
return mlir::failure();
25352578

@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
26062649

26072650
i++;
26082651
}
2652+
auto reduceAttrs = getReduceAttrsAttr();
2653+
if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
2654+
return emitOpError(
2655+
"mismatch in number of reduction variables and reduction attributes");
26092656
return mlir::success();
26102657
}
26112658

@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
26152662
<< getUpperBound() << " step " << getStep();
26162663
if (getUnordered())
26172664
p << " unordered";
2665+
if (hasReduceOperands()) {
2666+
p << " reduce(";
2667+
auto attrs = getReduceAttrsAttr();
2668+
auto operands = getReduceOperands();
2669+
llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
2670+
p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
2671+
<< std::get<1>(it).getType();
2672+
});
2673+
p << ')';
2674+
printBlockTerminators = true;
2675+
}
26182676
if (hasIterOperands()) {
26192677
p << " iter_args(";
26202678
auto regionArgs = getRegionIterArgs();
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
26282686
p << " -> " << getResultTypes();
26292687
printBlockTerminators = true;
26302688
}
2631-
p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
2632-
{"unordered", "finalValue"});
2689+
p.printOptionalAttrDictWithKeyword(
2690+
(*this)->getAttrs(),
2691+
{"unordered", "finalValue", "reduceAttrs", "operandSegmentSizes"});
26332692
p << ' ';
26342693
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
26352694
printBlockTerminators);

flang/test/Fir/loop03.fir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Test the reduction semantics of fir.do_loop
2+
// RUN: fir-opt %s | FileCheck %s
3+
4+
func.func @reduction() {
5+
%bound = arith.constant 10 : index
6+
%step = arith.constant 1 : index
7+
%sum = fir.alloca i32
8+
// CHECK: %[[VAL_0:.*]] = fir.alloca i32
9+
// CHECK: fir.do_loop %[[VAL_1:.*]] = %[[VAL_2:.*]] to %[[VAL_3:.*]] step %[[VAL_4:.*]] unordered reduce(#fir.reduce_attr<add> -> %[[VAL_0]] : !fir.ref<i32>) {
10+
fir.do_loop %iv = %step to %bound step %step unordered reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
11+
%index = fir.convert %iv : (index) -> i32
12+
%1 = fir.load %sum : !fir.ref<i32>
13+
%2 = arith.addi %index, %1 : i32
14+
fir.store %2 to %sum : !fir.ref<i32>
15+
}
16+
return
17+
}

0 commit comments

Comments
 (0)