@@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
48
48
return attrs.empty () ? nullptr : ArrayAttr::get (context, attrs);
49
49
}
50
50
51
+ static DenseBoolArrayAttr
52
+ makeDenseBoolArrayAttr (MLIRContext *ctx, const ArrayRef<bool > boolArray) {
53
+ return boolArray.empty () ? nullptr : DenseBoolArrayAttr::get (ctx, boolArray);
54
+ }
55
+
51
56
namespace {
52
57
struct MemRefPointerLikeModel
53
58
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -499,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
499
504
return success ();
500
505
})))
501
506
return failure ();
502
- isByRef = DenseBoolArrayAttr::get (parser.getContext (), isByRefVec);
507
+ isByRef = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
503
508
504
509
auto *argsBegin = regionPrivateArgs.begin ();
505
510
MutableArrayRef argsSubrange (argsBegin + regionArgOffset,
@@ -591,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
591
596
mlir::SmallVector<bool > isByRefVec;
592
597
isByRefVec.resize (privateVarTypes.size (), false );
593
598
DenseBoolArrayAttr isByRef =
594
- DenseBoolArrayAttr::get (op->getContext (), isByRefVec);
599
+ makeDenseBoolArrayAttr (op->getContext (), isByRefVec);
595
600
596
601
printClauseWithRegionArgs (p, op, argsSubrange, " private" ,
597
602
privateVarOperands, privateVarTypes, isByRef,
@@ -607,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
607
612
static ParseResult
608
613
parseReductionVarList (OpAsmParser &parser,
609
614
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
610
- SmallVectorImpl<Type> &types,
615
+ SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
611
616
ArrayAttr &redcuctionSymbols) {
612
617
SmallVector<SymbolRefAttr> reductionVec;
618
+ SmallVector<bool > isByRefVec;
613
619
if (failed (parser.parseCommaSeparatedList ([&]() {
620
+ ParseResult optionalByref = parser.parseOptionalKeyword (" byref" );
614
621
if (parser.parseAttribute (reductionVec.emplace_back ()) ||
615
622
parser.parseArrow () ||
616
623
parser.parseOperand (operands.emplace_back ()) ||
617
624
parser.parseColonType (types.emplace_back ()))
618
625
return failure ();
626
+ isByRefVec.push_back (optionalByref.succeeded ());
619
627
return success ();
620
628
})))
621
629
return failure ();
630
+ isByRef = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
622
631
SmallVector<Attribute> reductions (reductionVec.begin (), reductionVec.end ());
623
632
redcuctionSymbols = ArrayAttr::get (parser.getContext (), reductions);
624
633
return success ();
@@ -628,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
628
637
static void printReductionVarList (OpAsmPrinter &p, Operation *op,
629
638
OperandRange reductionVars,
630
639
TypeRange reductionTypes,
640
+ std::optional<DenseBoolArrayAttr> isByRef,
631
641
std::optional<ArrayAttr> reductions) {
632
- for (unsigned i = 0 , e = reductions->size (); i < e; ++i) {
642
+ auto getByRef = [&](unsigned i) -> const char * {
643
+ if (!isByRef || !*isByRef)
644
+ return " " ;
645
+ assert (isByRef->empty () || i < isByRef->size ());
646
+ if (!isByRef->empty () && (*isByRef)[i])
647
+ return " byref " ;
648
+ return " " ;
649
+ };
650
+
651
+ for (unsigned i = 0 , e = reductionVars.size (); i < e; ++i) {
633
652
if (i != 0 )
634
653
p << " , " ;
635
- p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
654
+ p << getByRef (i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
636
655
<< reductionVars[i].getType ();
637
656
}
638
657
}
@@ -641,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
641
660
static LogicalResult
642
661
verifyReductionVarList (Operation *op, std::optional<ArrayAttr> reductions,
643
662
OperandRange reductionVars,
644
- std::optional<ArrayRef<bool >> byRef = std::nullopt ) {
663
+ std::optional<ArrayRef<bool >> byRef) {
645
664
if (!reductionVars.empty ()) {
646
665
if (!reductions || reductions->size () != reductionVars.size ())
647
666
return op->emitOpError ()
648
667
<< " expected as many reduction symbol references "
649
668
" as reduction variables" ;
650
- if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
651
- assert (byRef);
652
- else
653
- assert (!byRef); // TODO: support byref reductions on other operations
654
669
if (byRef && byRef->size () != reductionVars.size ())
655
670
return op->emitError () << " expected as many reduction variable by "
656
671
" reference attributes as reduction variables" ;
@@ -1492,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
1492
1507
ParallelOp::build (builder, state, clauses.ifVar , clauses.numThreadsVar ,
1493
1508
clauses.allocateVars , clauses.allocatorVars ,
1494
1509
clauses.reductionVars ,
1495
- DenseBoolArrayAttr::get (ctx, clauses.reduceVarByRef ),
1510
+ makeDenseBoolArrayAttr (ctx, clauses.reductionVarsByRef ),
1496
1511
makeArrayAttr (ctx, clauses.reductionDeclSymbols ),
1497
1512
clauses.procBindKindAttr , clauses.privateVars ,
1498
1513
makeArrayAttr (ctx, clauses.privatizers ));
@@ -1590,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
1590
1605
clauses.numTeamsUpperVar , clauses.ifVar ,
1591
1606
clauses.threadLimitVar , clauses.allocateVars ,
1592
1607
clauses.allocatorVars , clauses.reductionVars ,
1608
+ makeDenseBoolArrayAttr (ctx, clauses.reductionVarsByRef ),
1593
1609
makeArrayAttr (ctx, clauses.reductionDeclSymbols ));
1594
1610
}
1595
1611
@@ -1621,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
1621
1637
return emitError (
1622
1638
" expected equal sizes for allocate and allocator variables" );
1623
1639
1624
- return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1640
+ return verifyReductionVarList (*this , getReductions (), getReductionVars (),
1641
+ getReductionVarsByref ());
1625
1642
}
1626
1643
1627
1644
// ===----------------------------------------------------------------------===//
@@ -1633,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
1633
1650
MLIRContext *ctx = builder.getContext ();
1634
1651
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
1635
1652
SectionsOp::build (builder, state, clauses.reductionVars ,
1653
+ makeDenseBoolArrayAttr (ctx, clauses.reductionVarsByRef ),
1636
1654
makeArrayAttr (ctx, clauses.reductionDeclSymbols ),
1637
1655
clauses.allocateVars , clauses.allocatorVars ,
1638
1656
clauses.nowaitAttr );
@@ -1643,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
1643
1661
return emitError (
1644
1662
" expected equal sizes for allocate and allocator variables" );
1645
1663
1646
- return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1664
+ return verifyReductionVarList (*this , getReductions (), getReductionVars (),
1665
+ getReductionVarsByref ());
1647
1666
}
1648
1667
1649
1668
LogicalResult SectionsOp::verifyRegions () {
@@ -1733,7 +1752,7 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
1733
1752
// privatizers.
1734
1753
WsloopOp::build (builder, state, clauses.linearVars , clauses.linearStepVars ,
1735
1754
clauses.reductionVars ,
1736
- DenseBoolArrayAttr::get (ctx, clauses.reduceVarByRef ),
1755
+ makeDenseBoolArrayAttr (ctx, clauses.reductionVarsByRef ),
1737
1756
makeArrayAttr (ctx, clauses.reductionDeclSymbols ),
1738
1757
clauses.scheduleValAttr , clauses.scheduleChunkVar ,
1739
1758
clauses.scheduleModAttr , clauses.scheduleSimdAttr ,
@@ -1934,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
1934
1953
TaskOp::build (
1935
1954
builder, state, clauses.ifVar , clauses.finalVar , clauses.untiedAttr ,
1936
1955
clauses.mergeableAttr , clauses.inReductionVars ,
1956
+ makeDenseBoolArrayAttr (ctx, clauses.inReductionVarsByRef ),
1937
1957
makeArrayAttr (ctx, clauses.inReductionDeclSymbols ), clauses.priorityVar ,
1938
1958
makeArrayAttr (ctx, clauses.dependTypeAttrs ), clauses.dependVars ,
1939
1959
clauses.allocateVars , clauses.allocatorVars );
@@ -1945,7 +1965,8 @@ LogicalResult TaskOp::verify() {
1945
1965
return failed (verifyDependVars)
1946
1966
? verifyDependVars
1947
1967
: verifyReductionVarList (*this , getInReductions (),
1948
- getInReductionVars ());
1968
+ getInReductionVars (),
1969
+ getInReductionVarsByref ());
1949
1970
}
1950
1971
1951
1972
// ===----------------------------------------------------------------------===//
@@ -1955,14 +1976,17 @@ LogicalResult TaskOp::verify() {
1955
1976
void TaskgroupOp::build (OpBuilder &builder, OperationState &state,
1956
1977
const TaskgroupClauseOps &clauses) {
1957
1978
MLIRContext *ctx = builder.getContext ();
1958
- TaskgroupOp::build (builder, state, clauses.taskReductionVars ,
1959
- makeArrayAttr (ctx, clauses.taskReductionDeclSymbols ),
1960
- clauses.allocateVars , clauses.allocatorVars );
1979
+ TaskgroupOp::build (
1980
+ builder, state, clauses.taskReductionVars ,
1981
+ makeDenseBoolArrayAttr (ctx, clauses.taskReductionVarsByRef ),
1982
+ makeArrayAttr (ctx, clauses.taskReductionDeclSymbols ),
1983
+ clauses.allocateVars , clauses.allocatorVars );
1961
1984
}
1962
1985
1963
1986
LogicalResult TaskgroupOp::verify () {
1964
1987
return verifyReductionVarList (*this , getTaskReductions (),
1965
- getTaskReductionVars ());
1988
+ getTaskReductionVars (),
1989
+ getTaskReductionVarsByref ());
1966
1990
}
1967
1991
1968
1992
// ===----------------------------------------------------------------------===//
@@ -1976,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
1976
2000
TaskloopOp::build (
1977
2001
builder, state, clauses.ifVar , clauses.finalVar , clauses.untiedAttr ,
1978
2002
clauses.mergeableAttr , clauses.inReductionVars ,
2003
+ makeDenseBoolArrayAttr (ctx, clauses.inReductionVarsByRef ),
1979
2004
makeArrayAttr (ctx, clauses.inReductionDeclSymbols ), clauses.reductionVars ,
2005
+ makeDenseBoolArrayAttr (ctx, clauses.reductionVarsByRef ),
1980
2006
makeArrayAttr (ctx, clauses.reductionDeclSymbols ), clauses.priorityVar ,
1981
2007
clauses.allocateVars , clauses.allocatorVars , clauses.grainsizeVar ,
1982
2008
clauses.numTasksVar , clauses.nogroupAttr );
@@ -1994,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
1994
2020
if (getAllocateVars ().size () != getAllocatorsVars ().size ())
1995
2021
return emitError (
1996
2022
" expected equal sizes for allocate and allocator variables" );
1997
- if (failed (
1998
- verifyReductionVarList (* this , getReductions (), getReductionVars ())) ||
2023
+ if (failed (verifyReductionVarList (* this , getReductions (), getReductionVars (),
2024
+ getReductionVarsByref ())) ||
1999
2025
failed (verifyReductionVarList (*this , getInReductions (),
2000
- getInReductionVars ())))
2026
+ getInReductionVars (),
2027
+ getInReductionVarsByref ())))
2001
2028
return failure ();
2002
2029
2003
2030
if (!getReductionVars ().empty () && getNogroup ())
0 commit comments