@@ -2456,9 +2456,16 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
2456
2456
mlir::OperationState &result, mlir::Value lb,
2457
2457
mlir::Value ub, mlir::Value step, bool unordered,
2458
2458
bool finalCountValue, mlir::ValueRange iterArgs,
2459
+ mlir::ValueRange reduceOperands,
2460
+ llvm::ArrayRef<mlir::Attribute> reduceAttrs,
2459
2461
llvm::ArrayRef<mlir::NamedAttribute> attributes) {
2460
2462
result.addOperands ({lb, ub, step});
2463
+ result.addOperands (reduceOperands);
2461
2464
result.addOperands (iterArgs);
2465
+ result.addAttribute (getOperandSegmentSizeAttr (),
2466
+ builder.getDenseI32ArrayAttr (
2467
+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2468
+ static_cast <int32_t >(iterArgs.size ())}));
2462
2469
if (finalCountValue) {
2463
2470
result.addTypes (builder.getIndexType ());
2464
2471
result.addAttribute (getFinalValueAttrName (result.name ),
@@ -2477,6 +2484,9 @@ void fir::DoLoopOp::build(mlir::OpBuilder &builder,
2477
2484
if (unordered)
2478
2485
result.addAttribute (getUnorderedAttrName (result.name ),
2479
2486
builder.getUnitAttr ());
2487
+ if (!reduceAttrs.empty ())
2488
+ result.addAttribute (getReduceAttrsAttrName (result.name ),
2489
+ builder.getArrayAttr (reduceAttrs));
2480
2490
result.addAttributes (attributes);
2481
2491
}
2482
2492
@@ -2502,24 +2512,51 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2502
2512
if (mlir::succeeded (parser.parseOptionalKeyword (" unordered" )))
2503
2513
result.addAttribute (" unordered" , builder.getUnitAttr ());
2504
2514
2515
+ // Parse the reduction arguments.
2516
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
2517
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
2518
+ if (succeeded (parser.parseOptionalKeyword (" reduce" ))) {
2519
+ // Parse reduction attributes and variables.
2520
+ llvm::SmallVector<ReduceAttr> attributes;
2521
+ if (failed (parser.parseCommaSeparatedList (
2522
+ mlir::AsmParser::Delimiter::Paren, [&]() {
2523
+ if (parser.parseAttribute (attributes.emplace_back ()) ||
2524
+ parser.parseArrow () ||
2525
+ parser.parseOperand (reduceOperands.emplace_back ()) ||
2526
+ parser.parseColonType (reduceArgTypes.emplace_back ()))
2527
+ return mlir::failure ();
2528
+ return mlir::success ();
2529
+ })))
2530
+ return mlir::failure ();
2531
+ // Resolve input operands.
2532
+ for (auto operand_type : llvm::zip (reduceOperands, reduceArgTypes))
2533
+ if (parser.resolveOperand (std::get<0 >(operand_type),
2534
+ std::get<1 >(operand_type), result.operands ))
2535
+ return mlir::failure ();
2536
+ llvm::SmallVector<mlir::Attribute> arrayAttr (attributes.begin (),
2537
+ attributes.end ());
2538
+ result.addAttribute (getReduceAttrsAttrName (result.name ),
2539
+ builder.getArrayAttr (arrayAttr));
2540
+ }
2541
+
2505
2542
// Parse the optional initial iteration arguments.
2506
2543
llvm::SmallVector<mlir::OpAsmParser::Argument> regionArgs;
2507
- llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands ;
2544
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> iterOperands ;
2508
2545
llvm::SmallVector<mlir::Type> argTypes;
2509
2546
bool prependCount = false ;
2510
2547
regionArgs.push_back (inductionVariable);
2511
2548
2512
2549
if (succeeded (parser.parseOptionalKeyword (" iter_args" ))) {
2513
2550
// Parse assignment list and results type list.
2514
- if (parser.parseAssignmentList (regionArgs, operands ) ||
2551
+ if (parser.parseAssignmentList (regionArgs, iterOperands ) ||
2515
2552
parser.parseArrowTypeList (result.types ))
2516
2553
return mlir::failure ();
2517
- if (result.types .size () == operands .size () + 1 )
2554
+ if (result.types .size () == iterOperands .size () + 1 )
2518
2555
prependCount = true ;
2519
2556
// Resolve input operands.
2520
2557
llvm::ArrayRef<mlir::Type> resTypes = result.types ;
2521
- for (auto operand_type :
2522
- llvm::zip (operands , prependCount ? resTypes.drop_front () : resTypes))
2558
+ for (auto operand_type : llvm::zip (
2559
+ iterOperands , prependCount ? resTypes.drop_front () : resTypes))
2523
2560
if (parser.resolveOperand (std::get<0 >(operand_type),
2524
2561
std::get<1 >(operand_type), result.operands ))
2525
2562
return mlir::failure ();
@@ -2530,6 +2567,12 @@ mlir::ParseResult fir::DoLoopOp::parse(mlir::OpAsmParser &parser,
2530
2567
prependCount = true ;
2531
2568
}
2532
2569
2570
+ // Set the operandSegmentSizes attribute
2571
+ result.addAttribute (getOperandSegmentSizeAttr (),
2572
+ builder.getDenseI32ArrayAttr (
2573
+ {1 , 1 , 1 , static_cast <int32_t >(reduceOperands.size ()),
2574
+ static_cast <int32_t >(iterOperands.size ())}));
2575
+
2533
2576
if (parser.parseOptionalAttrDictWithKeyword (result.attributes ))
2534
2577
return mlir::failure ();
2535
2578
@@ -2606,6 +2649,10 @@ mlir::LogicalResult fir::DoLoopOp::verify() {
2606
2649
2607
2650
i++;
2608
2651
}
2652
+ auto reduceAttrs = getReduceAttrsAttr ();
2653
+ if (getNumReduceOperands () != (reduceAttrs ? reduceAttrs.size () : 0 ))
2654
+ return emitOpError (
2655
+ " mismatch in number of reduction variables and reduction attributes" );
2609
2656
return mlir::success ();
2610
2657
}
2611
2658
@@ -2615,6 +2662,17 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
2615
2662
<< getUpperBound () << " step " << getStep ();
2616
2663
if (getUnordered ())
2617
2664
p << " unordered" ;
2665
+ if (hasReduceOperands ()) {
2666
+ p << " reduce(" ;
2667
+ auto attrs = getReduceAttrsAttr ();
2668
+ auto operands = getReduceOperands ();
2669
+ llvm::interleaveComma (llvm::zip (attrs, operands), p, [&](auto it) {
2670
+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
2671
+ << std::get<1 >(it).getType ();
2672
+ });
2673
+ p << ' )' ;
2674
+ printBlockTerminators = true ;
2675
+ }
2618
2676
if (hasIterOperands ()) {
2619
2677
p << " iter_args(" ;
2620
2678
auto regionArgs = getRegionIterArgs ();
@@ -2628,8 +2686,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
2628
2686
p << " -> " << getResultTypes ();
2629
2687
printBlockTerminators = true ;
2630
2688
}
2631
- p.printOptionalAttrDictWithKeyword ((*this )->getAttrs (),
2632
- {" unordered" , " finalValue" });
2689
+ p.printOptionalAttrDictWithKeyword (
2690
+ (*this )->getAttrs (),
2691
+ {" unordered" , " finalValue" , " reduceAttrs" , " operandSegmentSizes" });
2633
2692
p << ' ' ;
2634
2693
p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false ,
2635
2694
printBlockTerminators);
0 commit comments