Skip to content

Commit c3518fb

Browse files
committed
[Flang][OpenMP] Improve entry block argument creation and binding
Commit cherry-picked from PR llvm#110267. Will be removed when rebasing PR stack on top of a more recent amd-trunk-dev branch.
1 parent 15949dd commit c3518fb

File tree

7 files changed

+578
-615
lines changed

7 files changed

+578
-615
lines changed

flang/include/flang/Lower/OpenMP/Utils.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,7 @@ void insertChildMapInfoIntoParent(
153153
Fortran::lower::StatementContext &stmtCtx,
154154
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
155155
llvm::SmallVectorImpl<mlir::Value> &mapOperands,
156-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes,
157-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
158-
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols);
156+
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> &mapSymbols);
159157

160158
mlir::Type getLoopVarType(lower::AbstractConverter &converter,
161159
std::size_t loopVarTypeSize);

flang/lib/Lower/OpenMP/ClauseProcessor.cpp

Lines changed: 22 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -166,15 +166,11 @@ getIfClauseOperand(lower::AbstractConverter &converter,
166166
static void addUseDeviceClause(
167167
lower::AbstractConverter &converter, const omp::ObjectList &objects,
168168
llvm::SmallVectorImpl<mlir::Value> &operands,
169-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
170-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
171169
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) {
172170
genObjectList(objects, converter, operands);
173-
for (mlir::Value &operand : operands) {
171+
for (mlir::Value &operand : operands)
174172
checkMapType(operand.getLoc(), operand.getType());
175-
useDeviceTypes.push_back(operand.getType());
176-
useDeviceLocs.push_back(operand.getLoc());
177-
}
173+
178174
for (const omp::Object &object : objects)
179175
useDeviceSyms.push_back(object.sym());
180176
}
@@ -832,14 +828,12 @@ bool ClauseProcessor::processDepend(mlir::omp::DependClauseOps &result) const {
832828

833829
bool ClauseProcessor::processHasDeviceAddr(
834830
mlir::omp::HasDeviceAddrClauseOps &result,
835-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
836-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
837-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
831+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
838832
return findRepeatableClause<omp::clause::HasDeviceAddr>(
839833
[&](const omp::clause::HasDeviceAddr &devAddrClause,
840834
const parser::CharBlock &) {
841835
addUseDeviceClause(converter, devAddrClause.v, result.hasDeviceAddrVars,
842-
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
836+
isDeviceSyms);
843837
});
844838
}
845839

@@ -864,14 +858,12 @@ bool ClauseProcessor::processIf(
864858

865859
bool ClauseProcessor::processIsDevicePtr(
866860
mlir::omp::IsDevicePtrClauseOps &result,
867-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
868-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
869-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const {
861+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const {
870862
return findRepeatableClause<omp::clause::IsDevicePtr>(
871863
[&](const omp::clause::IsDevicePtr &devPtrClause,
872864
const parser::CharBlock &) {
873865
addUseDeviceClause(converter, devPtrClause.v, result.isDevicePtrVars,
874-
isDeviceTypes, isDeviceLocs, isDeviceSymbols);
866+
isDeviceSyms);
875867
});
876868
}
877869

@@ -891,9 +883,7 @@ void ClauseProcessor::processMapObjects(
891883
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
892884
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
893885
llvm::SmallVectorImpl<mlir::Value> &mapVars,
894-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
895-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
896-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
886+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const {
897887
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
898888

899889
for (const omp::Object &object : objects) {
@@ -948,21 +938,15 @@ void ClauseProcessor::processMapObjects(
948938
object, parentMemberIndices[parentObj.value()], mapOp, semaCtx);
949939
} else {
950940
mapVars.push_back(mapOp);
951-
mapSyms->push_back(object.sym());
952-
if (mapSymTypes)
953-
mapSymTypes->push_back(baseOp.getType());
954-
if (mapSymLocs)
955-
mapSymLocs->push_back(baseOp.getLoc());
941+
mapSyms.push_back(object.sym());
956942
}
957943
}
958944
}
959945

960946
bool ClauseProcessor::processMap(
961947
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
962948
mlir::omp::MapClauseOps &result,
963-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
964-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs,
965-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes) const {
949+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms) const {
966950
// We always require tracking of symbols, even if the caller does not,
967951
// so we create an optionally used local set of symbols when the mapSyms
968952
// argument is not present.
@@ -1018,13 +1002,11 @@ bool ClauseProcessor::processMap(
10181002

10191003
processMapObjects(stmtCtx, clauseLocation,
10201004
std::get<omp::ObjectList>(clause.t), mapTypeBits,
1021-
parentMemberIndices, result.mapVars, ptrMapSyms,
1022-
mapSymLocs, mapSymTypes);
1005+
parentMemberIndices, result.mapVars, *ptrMapSyms);
10231006
});
10241007

10251008
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1026-
result.mapVars, mapSymTypes, mapSymLocs,
1027-
ptrMapSyms);
1009+
result.mapVars, *ptrMapSyms);
10281010
return clauseFound;
10291011
}
10301012

@@ -1044,7 +1026,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
10441026

10451027
processMapObjects(stmtCtx, clauseLocation, std::get<ObjectList>(clause.t),
10461028
mapTypeBits, parentMemberIndices, result.mapVars,
1047-
&mapSymbols);
1029+
mapSymbols);
10481030
};
10491031

10501032
bool clauseFound = findRepeatableClause<omp::clause::To>(callbackFn);
@@ -1053,7 +1035,7 @@ bool ClauseProcessor::processMotionClauses(lower::StatementContext &stmtCtx,
10531035

10541036
insertChildMapInfoIntoParent(
10551037
converter, semaCtx, stmtCtx, parentMemberIndices, result.mapVars,
1056-
/*mapSymTypes=*/nullptr, /*mapSymLocs=*/nullptr, &mapSymbols);
1038+
mapSymbols);
10571039
return clauseFound;
10581040
}
10591041

@@ -1071,34 +1053,24 @@ bool ClauseProcessor::processNontemporal(
10711053

10721054
bool ClauseProcessor::processReduction(
10731055
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
1074-
llvm::SmallVectorImpl<mlir::Type> *outReductionTypes,
1075-
llvm::SmallVectorImpl<const semantics::Symbol *> *outReductionSyms) const {
1056+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
10761057
return findRepeatableClause<omp::clause::Reduction>(
10771058
[&](const omp::clause::Reduction &clause, const parser::CharBlock &) {
10781059
llvm::SmallVector<mlir::Value> reductionVars;
10791060
llvm::SmallVector<bool> reduceVarByRef;
10801061
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
10811062
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
10821063
ReductionProcessor rp;
1083-
rp.addDeclareReduction(
1084-
currentLocation, converter, clause, reductionVars, reduceVarByRef,
1085-
reductionDeclSymbols, outReductionSyms ? &reductionSyms : nullptr);
1064+
rp.addDeclareReduction(currentLocation, converter, clause,
1065+
reductionVars, reduceVarByRef,
1066+
reductionDeclSymbols, reductionSyms);
10861067

10871068
// Copy local lists into the output.
10881069
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
10891070
llvm::copy(reduceVarByRef, std::back_inserter(result.reductionByref));
10901071
llvm::copy(reductionDeclSymbols,
10911072
std::back_inserter(result.reductionSyms));
1092-
1093-
if (outReductionTypes) {
1094-
outReductionTypes->reserve(outReductionTypes->size() +
1095-
reductionVars.size());
1096-
llvm::transform(reductionVars, std::back_inserter(*outReductionTypes),
1097-
[](mlir::Value v) { return v.getType(); });
1098-
}
1099-
1100-
if (outReductionSyms)
1101-
llvm::copy(reductionSyms, std::back_inserter(*outReductionSyms));
1073+
llvm::copy(reductionSyms, std::back_inserter(outReductionSyms));
11021074
});
11031075
}
11041076

@@ -1124,8 +1096,6 @@ bool ClauseProcessor::processEnter(
11241096

11251097
bool ClauseProcessor::processUseDeviceAddr(
11261098
lower::StatementContext &stmtCtx, mlir::omp::UseDeviceAddrClauseOps &result,
1127-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1128-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11291099
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
11301100
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
11311101
bool clauseFound = findRepeatableClause<omp::clause::UseDeviceAddr>(
@@ -1137,19 +1107,16 @@ bool ClauseProcessor::processUseDeviceAddr(
11371107
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
11381108
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
11391109
parentMemberIndices, result.useDeviceAddrVars,
1140-
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
1110+
useDeviceSyms);
11411111
});
11421112

11431113
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1144-
result.useDeviceAddrVars, &useDeviceTypes,
1145-
&useDeviceLocs, &useDeviceSyms);
1114+
result.useDeviceAddrVars, useDeviceSyms);
11461115
return clauseFound;
11471116
}
11481117

11491118
bool ClauseProcessor::processUseDevicePtr(
11501119
lower::StatementContext &stmtCtx, mlir::omp::UseDevicePtrClauseOps &result,
1151-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
1152-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
11531120
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const {
11541121
std::map<Object, OmpMapParentAndMemberData> parentMemberIndices;
11551122

@@ -1162,12 +1129,11 @@ bool ClauseProcessor::processUseDevicePtr(
11621129
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
11631130
processMapObjects(stmtCtx, location, clause.v, mapTypeBits,
11641131
parentMemberIndices, result.useDevicePtrVars,
1165-
&useDeviceSyms, &useDeviceLocs, &useDeviceTypes);
1132+
useDeviceSyms);
11661133
});
11671134

11681135
insertChildMapInfoIntoParent(converter, semaCtx, stmtCtx, parentMemberIndices,
1169-
result.useDevicePtrVars, &useDeviceTypes,
1170-
&useDeviceLocs, &useDeviceSyms);
1136+
result.useDevicePtrVars, useDeviceSyms);
11711137
return clauseFound;
11721138
}
11731139

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,7 @@ class ClauseProcessor {
6868
mlir::omp::FinalClauseOps &result) const;
6969
bool processHasDeviceAddr(
7070
mlir::omp::HasDeviceAddrClauseOps &result,
71-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
72-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
73-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
71+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
7472
bool processHint(mlir::omp::HintClauseOps &result) const;
7573
bool processMergeable(mlir::omp::MergeableClauseOps &result) const;
7674
bool processNowait(mlir::omp::NowaitClauseOps &result) const;
@@ -104,43 +102,33 @@ class ClauseProcessor {
104102
mlir::omp::IfClauseOps &result) const;
105103
bool processIsDevicePtr(
106104
mlir::omp::IsDevicePtrClauseOps &result,
107-
llvm::SmallVectorImpl<mlir::Type> &isDeviceTypes,
108-
llvm::SmallVectorImpl<mlir::Location> &isDeviceLocs,
109-
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSymbols) const;
105+
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
110106
bool
111107
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
112108

113109
// This method is used to process a map clause.
114-
// The optional parameters - mapSymTypes, mapSymLocs & mapSyms are used to
115-
// store the original type, location and Fortran symbol for the map operands.
116-
// They may be used later on to create the block_arguments for some of the
117-
// target directives that require it.
118-
bool processMap(
119-
mlir::Location currentLocation, lower::StatementContext &stmtCtx,
120-
mlir::omp::MapClauseOps &result,
121-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms = nullptr,
122-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
123-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
110+
// The optional parameter mapSyms is used to store the original Fortran symbol
111+
// for the map operands. It may be used later on to create the block_arguments
112+
// for some of the directives that require it.
113+
bool processMap(mlir::Location currentLocation,
114+
lower::StatementContext &stmtCtx,
115+
mlir::omp::MapClauseOps &result,
116+
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms =
117+
nullptr) const;
124118
bool processMotionClauses(lower::StatementContext &stmtCtx,
125119
mlir::omp::MapClauseOps &result);
126120
bool processNontemporal(mlir::omp::NontemporalClauseOps &result) const;
127121
bool processReduction(
128122
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
129-
llvm::SmallVectorImpl<mlir::Type> *reductionTypes = nullptr,
130-
llvm::SmallVectorImpl<const semantics::Symbol *> *reductionSyms =
131-
nullptr) const;
123+
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
132124
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
133125
bool processUseDeviceAddr(
134126
lower::StatementContext &stmtCtx,
135127
mlir::omp::UseDeviceAddrClauseOps &result,
136-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
137-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
138128
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
139129
bool processUseDevicePtr(
140130
lower::StatementContext &stmtCtx,
141131
mlir::omp::UseDevicePtrClauseOps &result,
142-
llvm::SmallVectorImpl<mlir::Type> &useDeviceTypes,
143-
llvm::SmallVectorImpl<mlir::Location> &useDeviceLocs,
144132
llvm::SmallVectorImpl<const semantics::Symbol *> &useDeviceSyms) const;
145133

146134
// Call this method for these clauses that should be supported but are not
@@ -180,9 +168,7 @@ class ClauseProcessor {
180168
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits,
181169
std::map<Object, OmpMapParentAndMemberData> &parentMemberIndices,
182170
llvm::SmallVectorImpl<mlir::Value> &mapVars,
183-
llvm::SmallVectorImpl<const semantics::Symbol *> *mapSyms,
184-
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
185-
llvm::SmallVectorImpl<mlir::Type> *mapSymTypes = nullptr) const;
171+
llvm::SmallVectorImpl<const semantics::Symbol *> &mapSyms) const;
186172

187173
lower::AbstractConverter &converter;
188174
semantics::SemanticsContext &semaCtx;

0 commit comments

Comments
 (0)