Skip to content

Commit 88e1520

Browse files
committed
[mlir][OpenMP] Annotate private vars with map_idx when needed
This PR extends the MLIR representation for `omp.target` ops by adding a `map_idx` to `private` vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the `map_idx` attribute is either not present at all or its value is `-1`.
1 parent 752dbd6 commit 88e1520

File tree

8 files changed

+159
-41
lines changed

8 files changed

+159
-41
lines changed

flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//===- MapsForPrivatizedSymbols.cpp
2-
//-----------------------------------------===//
1+
//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
2827
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
2928
#include "flang/Optimizer/HLFIR/HLFIROps.h"
3029
#include "flang/Optimizer/OpenMP/Passes.h"
30+
3131
#include "mlir/Dialect/Func/IR/FuncOps.h"
3232
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
33+
#include "mlir/Dialect/OpenMP/Utils.h"
3334
#include "mlir/IR/BuiltinAttributes.h"
3435
#include "mlir/IR/SymbolTable.h"
3536
#include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
124125
if (targetOp.getPrivateVars().empty())
125126
return;
126127
OperandRange privVars = targetOp.getPrivateVars();
128+
llvm::SmallVector<int64_t> privVarMapIdx;
129+
127130
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
128131
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
129132
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
133136
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
134137
targetOp, privatizerName);
135138
if (!privatizerNeedsMap(privatizer)) {
139+
privVarMapIdx.push_back(-1);
136140
continue;
137141
}
142+
143+
privVarMapIdx.push_back(targetOp.getMapVars().size() +
144+
mapInfoOps.size());
145+
138146
builder.setInsertionPoint(targetOp);
139147
Location loc = targetOp.getLoc();
140148
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
141149
mapInfoOps.push_back(mapInfoOp);
150+
142151
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
143152
LLVM_DEBUG(mapInfoOp.dump());
144153
}
145154
if (!mapInfoOps.empty()) {
146155
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
156+
targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
157+
privVarMapIdx, targetOp.getContext()));
147158
}
148159
});
149160
if (!mapInfoOpsForTarget.empty()) {

flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,12 @@ end subroutine target_allocatable
171171
! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
172172
! CHECK-SAME !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
173173
! CHECK-SAME: private(
174-
! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
175-
! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
176-
! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
177-
! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
178-
! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
179-
! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
174+
! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
175+
! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
176+
! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
177+
! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
178+
! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
179+
! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
180180
! CHECK-SAME: !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
181181
! CHECK-NOT: fir.alloca
182182
! CHECK: hlfir.declare %[[ALLOC_ARG]]

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
12271227
a device, if it is 0 then the target region is executed on the host device.
12281228
}] # clausesDescription;
12291229

1230+
let arguments = !con(clausesArgs,
1231+
(ins OptionalAttr<I64ArrayAttr>:$private_maps));
1232+
12301233
let builders = [
12311234
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
12321235
];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
12391242
custom<InReductionMapPrivateRegion>(
12401243
$region, $in_reduction_vars, type($in_reduction_vars),
12411244
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
1242-
$private_vars, type($private_vars), $private_syms) attr-dict
1245+
$private_vars, type($private_vars), $private_syms, $private_maps)
1246+
attr-dict
12431247
}];
12441248

12451249
let hasVerifier = 1;
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
10+
#define MLIR_DIALECT_OPENMP_UTILS_H_
11+
12+
#include "mlir/IR/BuiltinAttributes.h"
13+
14+
namespace mlir::omp::utils {
15+
mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
16+
mlir::MLIRContext *context);
17+
} // namespace mlir::omp::utils
18+
19+
#endif // MLIR_DIALECT_OPENMP_UTILS_H_

mlir/lib/Dialect/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIROpenMPDialect
22
IR/OpenMPDialect.cpp
3+
IR/Utils.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 69 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/Func/IR/FuncOps.h"
1616
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
1717
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18+
#include "mlir/Dialect/OpenMP/Utils.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/BuiltinAttributes.h"
2021
#include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
487488
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488489
llvm::SmallVectorImpl<Type> &types;
489490
ArrayAttr &syms;
491+
ArrayAttr *mapIndices;
490492
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491-
SmallVectorImpl<Type> &types, ArrayAttr &syms)
492-
: vars(vars), types(types), syms(syms) {}
493+
SmallVectorImpl<Type> &types, ArrayAttr &syms,
494+
ArrayAttr *mapIndices = nullptr)
495+
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493496
};
494497
struct ReductionParseArgs {
495498
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
517520
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518521
SmallVectorImpl<Type> &types,
519522
SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
520-
ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
523+
ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
524+
DenseBoolArrayAttr *byref = nullptr) {
521525
SmallVector<SymbolRefAttr> symbolVec;
526+
SmallVector<int64_t> mapIndicesVec;
522527
SmallVector<bool> isByRefVec;
523528
unsigned regionArgOffset = regionPrivateArgs.size();
524529

@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
538543
parser.parseArgument(regionPrivateArgs.emplace_back()))
539544
return failure();
540545

546+
if (mapIndices) {
547+
if (parser.parseOptionalLSquare().succeeded()) {
548+
if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
549+
parser.parseInteger(mapIndicesVec.emplace_back()) ||
550+
parser.parseRSquare())
551+
return failure();
552+
} else
553+
mapIndicesVec.push_back(-1);
554+
}
555+
541556
return success();
542557
}))
543558
return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
571586
*symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
572587
}
573588

589+
if (!mapIndicesVec.empty())
590+
*mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
591+
574592
if (byref)
575593
*byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
576594

@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595613
static ParseResult parseBlockArgClause(
596614
OpAsmParser &parser,
597615
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598-
StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
616+
StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
599617
if (succeeded(parser.parseOptionalKeyword(keyword))) {
600-
if (!reductionArgs)
618+
if (!privateArgs)
601619
return failure();
602620

603-
if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
604-
reductionArgs->types, entryBlockArgs,
605-
&reductionArgs->syms)))
621+
if (failed(parseClauseWithRegionArgs(
622+
parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
623+
&privateArgs->syms, privateArgs->mapIndices)))
606624
return failure();
607625
}
608626
return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618636

619637
if (failed(parseClauseWithRegionArgs(
620638
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
621-
&reductionArgs->syms, &reductionArgs->byref)))
639+
&reductionArgs->syms, /*mapIndices=*/nullptr,
640+
&reductionArgs->byref)))
622641
return failure();
623642
}
624643
return success();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674693
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675694
SmallVectorImpl<Type> &mapTypes,
676695
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677-
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696+
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697+
ArrayAttr &privateMaps) {
678698
AllRegionParseArgs args;
679699
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
680700
inReductionByref, inReductionSyms);
681701
args.mapArgs.emplace(mapVars, mapTypes);
682-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
702+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
703+
&privateMaps);
683704
return parseBlockArgRegion(parser, region, args);
684705
}
685706

@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776797
ValueRange vars;
777798
TypeRange types;
778799
ArrayAttr syms;
779-
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
780-
: vars(vars), types(types), syms(syms) {}
800+
ArrayAttr mapIndices;
801+
PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
802+
ArrayAttr mapIndices)
803+
: vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781804
};
782805
struct ReductionPrintArgs {
783806
ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804827
ValueRange argsSubrange,
805828
ValueRange operands, TypeRange types,
806829
ArrayAttr symbols = nullptr,
830+
ArrayAttr mapIndices = nullptr,
807831
DenseBoolArrayAttr byref = nullptr) {
808832
if (argsSubrange.empty())
809833
return;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815839
symbols = ArrayAttr::get(ctx, values);
816840
}
817841

842+
if (!mapIndices) {
843+
llvm::SmallVector<Attribute> values(operands.size(), nullptr);
844+
mapIndices = ArrayAttr::get(ctx, values);
845+
}
846+
818847
if (!byref) {
819848
mlir::SmallVector<bool> values(operands.size(), false);
820849
byref = DenseBoolArrayAttr::get(ctx, values);
821850
}
822851

823-
llvm::interleaveComma(
824-
llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
825-
[&p](auto t) {
826-
auto [op, arg, sym, isByRef] = t;
827-
if (isByRef)
828-
p << "byref ";
829-
if (sym)
830-
p << sym << " ";
831-
p << op << " -> " << arg;
832-
});
852+
llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
853+
mapIndices, byref.asArrayRef()),
854+
p, [&p](auto t) {
855+
auto [op, arg, sym, map, isByRef] = t;
856+
if (isByRef)
857+
p << "byref ";
858+
if (sym)
859+
p << sym << " ";
860+
861+
p << op << " -> " << arg;
862+
863+
if (map)
864+
p << " [map_idx="
865+
<< llvm::cast<IntegerAttr>(map).getInt() << "]";
866+
});
833867
p << " : ";
834868
llvm::interleaveComma(types, p);
835869
p << ") ";
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849883
if (privateArgs)
850884
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
851885
privateArgs->vars, privateArgs->types,
852-
privateArgs->syms);
886+
privateArgs->syms, privateArgs->mapIndices);
853887
}
854888

855889
static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859893
if (reductionArgs)
860894
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
861895
reductionArgs->vars, reductionArgs->types,
862-
reductionArgs->syms, reductionArgs->byref);
896+
reductionArgs->syms, nullptr,
897+
reductionArgs->byref);
863898
}
864899

865900
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891926
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
892927
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893928
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894-
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929+
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930+
ArrayAttr privateMaps) {
895931
AllRegionPrintArgs args;
896932
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
897933
inReductionByref, inReductionSyms);
898934
args.mapArgs.emplace(mapVars, mapTypes);
899-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
935+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
900936
printBlockArgRegion(p, op, region, args);
901937
}
902938

@@ -908,7 +944,7 @@ static void printInReductionPrivateRegion(
908944
AllRegionPrintArgs args;
909945
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
910946
inReductionByref, inReductionSyms);
911-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
947+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
912948
printBlockArgRegion(p, op, region, args);
913949
}
914950

@@ -921,7 +957,7 @@ static void printInReductionPrivateReductionRegion(
921957
AllRegionPrintArgs args;
922958
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
923959
inReductionByref, inReductionSyms);
924-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
960+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
925961
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
926962
reductionSyms);
927963
printBlockArgRegion(p, op, region, args);
@@ -931,7 +967,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
931967
ValueRange privateVars, TypeRange privateTypes,
932968
ArrayAttr privateSyms) {
933969
AllRegionPrintArgs args;
934-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
970+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
935971
printBlockArgRegion(p, op, region, args);
936972
}
937973

@@ -941,7 +977,7 @@ static void printPrivateReductionRegion(
941977
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942978
ArrayAttr reductionSyms) {
943979
AllRegionPrintArgs args;
944-
args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
980+
args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
945981
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
946982
reductionSyms);
947983
printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1692,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16561692
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
16571693
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
16581694
clauses.mapVars, clauses.nowait, clauses.privateVars,
1659-
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
1695+
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1696+
/*private_maps=*/nullptr);
16601697
}
16611698

16621699
LogicalResult TargetOp::verify() {

mlir/lib/Dialect/OpenMP/IR/Utils.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/OpenMP/Utils.h"
10+
#include "mlir/IR/BuiltinTypes.h"
11+
12+
namespace mlir::omp::utils {
13+
mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
14+
mlir::MLIRContext *context) {
15+
llvm::SmallVector<mlir::Attribute, 4> attrs;
16+
attrs.reserve(values.size());
17+
for (auto &v : values)
18+
attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
19+
mlir::APInt(64, v)));
20+
return mlir::ArrayAttr::get(context, attrs);
21+
}
22+
} // namespace mlir::omp::utils

0 commit comments

Comments
 (0)