-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][OpenMP] Annotate private
vars with map_idx
when needed
#116770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-openmp @llvm/pr-subscribers-flang-openmp Author: Kareem Ergawy (ergawy) ChangesThis PR extends the MLIR representation for Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff 8 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
if (targetOp.getPrivateVars().empty())
return;
OperandRange privVars = targetOp.getPrivateVars();
+ llvm::SmallVector<int64_t> privVarMapIdx;
+
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
targetOp, privatizerName);
if (!privatizerNeedsMap(privatizer)) {
+ privVarMapIdx.push_back(-1);
continue;
}
+
+ privVarMapIdx.push_back(targetOp.getMapVars().size() +
+ mapInfoOps.size());
+
builder.setInsertionPoint(targetOp);
Location loc = targetOp.getLoc();
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
mapInfoOps.push_back(mapInfoOp);
+
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
LLVM_DEBUG(mapInfoOp.dump());
}
if (!mapInfoOps.empty()) {
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+ targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+ privVarMapIdx, targetOp.getContext()));
}
});
if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
! CHECK-SAME !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
! CHECK-SAME: private(
-! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
! CHECK-SAME: !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
! CHECK-NOT: fir.alloca
! CHECK: hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
a device, if it is 0 then the target region is executed on the host device.
}] # clausesDescription;
+ let arguments = !con(clausesArgs,
+ (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
let builders = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
- $private_vars, type($private_vars), $private_syms) attr-dict
+ $private_vars, type($private_vars), $private_syms, $private_maps)
+ attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
+ IR/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
llvm::SmallVectorImpl<Type> &types;
ArrayAttr &syms;
+ ArrayAttr *mapIndices;
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
- SmallVectorImpl<Type> &types, ArrayAttr &syms)
- : vars(vars), types(types), syms(syms) {}
+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
+ ArrayAttr *mapIndices=nullptr)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionParseArgs {
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
- ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+ ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+ DenseBoolArrayAttr *byref = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
+ SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
parser.parseArgument(regionPrivateArgs.emplace_back()))
return failure();
+ if (mapIndices) {
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+ parser.parseInteger(mapIndicesVec.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ } else
+ mapIndicesVec.push_back(-1);
+ }
+
return success();
}))
return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
*symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
}
+ if (!mapIndicesVec.empty())
+ *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
if (byref)
*byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
static ParseResult parseBlockArgClause(
OpAsmParser &parser,
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
if (succeeded(parser.parseOptionalKeyword(keyword))) {
- if (!reductionArgs)
+ if (!privateArgs)
return failure();
- if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
- reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms)))
+ if (failed(parseClauseWithRegionArgs(
+ parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+ &privateArgs->syms, privateArgs->mapIndices)))
return failure();
}
return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, &reductionArgs->byref)))
+ &reductionArgs->syms, /*mapIndices=*/nullptr,
+ &reductionArgs->byref)))
return failure();
}
return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
SmallVectorImpl<Type> &mapTypes,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
return parseBlockArgRegion(parser, region, args);
}
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
ValueRange vars;
TypeRange types;
ArrayAttr syms;
- PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
- : vars(vars), types(types), syms(syms) {}
+ ArrayAttr mapIndices;
+ PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+ ArrayAttr mapIndices)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
ValueRange argsSubrange,
ValueRange operands, TypeRange types,
ArrayAttr symbols = nullptr,
+ ArrayAttr mapIndices = nullptr,
DenseBoolArrayAttr byref = nullptr) {
if (argsSubrange.empty())
return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
symbols = ArrayAttr::get(ctx, values);
}
+ if (!mapIndices) {
+ llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+ mapIndices = ArrayAttr::get(ctx, values);
+ }
+
if (!byref) {
mlir::SmallVector<bool> values(operands.size(), false);
byref = DenseBoolArrayAttr::get(ctx, values);
}
- llvm::interleaveComma(
- llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
- [&p](auto t) {
- auto [op, arg, sym, isByRef] = t;
- if (isByRef)
- p << "byref ";
- if (sym)
- p << sym << " ";
- p << op << " -> " << arg;
- });
+ llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+ mapIndices, byref.asArrayRef()),
+ p, [&p](auto t) {
+ auto [op, arg, sym, map, isByRef] = t;
+ if (isByRef)
+ p << "byref ";
+ if (sym)
+ p << sym << " ";
+
+ p << op << " -> " << arg;
+
+ if (map)
+ p << " [map_idx="
+ << llvm::cast<IntegerAttr>(map).getInt() << "]";
+ });
p << " : ";
llvm::interleaveComma(types, p);
p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
if (privateArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
privateArgs->vars, privateArgs->types,
- privateArgs->syms);
+ privateArgs->syms, privateArgs->mapIndices);
}
static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
if (reductionArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, reductionArgs->byref);
+ reductionArgs->syms, nullptr,
+ reductionArgs->byref);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+ ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
printBlockArgRegion(p, op, region, args);
}
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+ /*private_maps=*/nullptr);
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context) {
+ llvm::SmallVector<mlir::Attribute, 4> attrs;
+ attrs.reserve(values.size());
+ for (auto &v : values)
+ attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+ mlir::APInt(64, v)));
+ return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
return
}
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Kareem Ergawy (ergawy) ChangesThis PR extends the MLIR representation for Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff 8 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
if (targetOp.getPrivateVars().empty())
return;
OperandRange privVars = targetOp.getPrivateVars();
+ llvm::SmallVector<int64_t> privVarMapIdx;
+
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
targetOp, privatizerName);
if (!privatizerNeedsMap(privatizer)) {
+ privVarMapIdx.push_back(-1);
continue;
}
+
+ privVarMapIdx.push_back(targetOp.getMapVars().size() +
+ mapInfoOps.size());
+
builder.setInsertionPoint(targetOp);
Location loc = targetOp.getLoc();
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
mapInfoOps.push_back(mapInfoOp);
+
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
LLVM_DEBUG(mapInfoOp.dump());
}
if (!mapInfoOps.empty()) {
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+ targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+ privVarMapIdx, targetOp.getContext()));
}
});
if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
! CHECK-SAME !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
! CHECK-SAME: private(
-! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
! CHECK-SAME: !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
! CHECK-NOT: fir.alloca
! CHECK: hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
a device, if it is 0 then the target region is executed on the host device.
}] # clausesDescription;
+ let arguments = !con(clausesArgs,
+ (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
let builders = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
- $private_vars, type($private_vars), $private_syms) attr-dict
+ $private_vars, type($private_vars), $private_syms, $private_maps)
+ attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
+ IR/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
llvm::SmallVectorImpl<Type> &types;
ArrayAttr &syms;
+ ArrayAttr *mapIndices;
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
- SmallVectorImpl<Type> &types, ArrayAttr &syms)
- : vars(vars), types(types), syms(syms) {}
+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
+ ArrayAttr *mapIndices=nullptr)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionParseArgs {
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
- ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+ ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+ DenseBoolArrayAttr *byref = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
+ SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
parser.parseArgument(regionPrivateArgs.emplace_back()))
return failure();
+ if (mapIndices) {
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+ parser.parseInteger(mapIndicesVec.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ } else
+ mapIndicesVec.push_back(-1);
+ }
+
return success();
}))
return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
*symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
}
+ if (!mapIndicesVec.empty())
+ *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
if (byref)
*byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
static ParseResult parseBlockArgClause(
OpAsmParser &parser,
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
if (succeeded(parser.parseOptionalKeyword(keyword))) {
- if (!reductionArgs)
+ if (!privateArgs)
return failure();
- if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
- reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms)))
+ if (failed(parseClauseWithRegionArgs(
+ parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+ &privateArgs->syms, privateArgs->mapIndices)))
return failure();
}
return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, &reductionArgs->byref)))
+ &reductionArgs->syms, /*mapIndices=*/nullptr,
+ &reductionArgs->byref)))
return failure();
}
return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
SmallVectorImpl<Type> &mapTypes,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
return parseBlockArgRegion(parser, region, args);
}
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
ValueRange vars;
TypeRange types;
ArrayAttr syms;
- PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
- : vars(vars), types(types), syms(syms) {}
+ ArrayAttr mapIndices;
+ PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+ ArrayAttr mapIndices)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
ValueRange argsSubrange,
ValueRange operands, TypeRange types,
ArrayAttr symbols = nullptr,
+ ArrayAttr mapIndices = nullptr,
DenseBoolArrayAttr byref = nullptr) {
if (argsSubrange.empty())
return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
symbols = ArrayAttr::get(ctx, values);
}
+ if (!mapIndices) {
+ llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+ mapIndices = ArrayAttr::get(ctx, values);
+ }
+
if (!byref) {
mlir::SmallVector<bool> values(operands.size(), false);
byref = DenseBoolArrayAttr::get(ctx, values);
}
- llvm::interleaveComma(
- llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
- [&p](auto t) {
- auto [op, arg, sym, isByRef] = t;
- if (isByRef)
- p << "byref ";
- if (sym)
- p << sym << " ";
- p << op << " -> " << arg;
- });
+ llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+ mapIndices, byref.asArrayRef()),
+ p, [&p](auto t) {
+ auto [op, arg, sym, map, isByRef] = t;
+ if (isByRef)
+ p << "byref ";
+ if (sym)
+ p << sym << " ";
+
+ p << op << " -> " << arg;
+
+ if (map)
+ p << " [map_idx="
+ << llvm::cast<IntegerAttr>(map).getInt() << "]";
+ });
p << " : ";
llvm::interleaveComma(types, p);
p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
if (privateArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
privateArgs->vars, privateArgs->types,
- privateArgs->syms);
+ privateArgs->syms, privateArgs->mapIndices);
}
static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
if (reductionArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, reductionArgs->byref);
+ reductionArgs->syms, nullptr,
+ reductionArgs->byref);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+ ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
printBlockArgRegion(p, op, region, args);
}
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+ /*private_maps=*/nullptr);
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context) {
+ llvm::SmallVector<mlir::Attribute, 4> attrs;
+ attrs.reserve(values.size());
+ for (auto &v : values)
+ attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+ mlir::APInt(64, v)));
+ return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
return
}
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]
|
@llvm/pr-subscribers-flang-fir-hlfir Author: Kareem Ergawy (ergawy) ChangesThis PR extends the MLIR representation for Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff 8 Files Affected:
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
#include "flang/Optimizer/Dialect/Support/KindMapping.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/OpenMP/Passes.h"
+
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
if (targetOp.getPrivateVars().empty())
return;
OperandRange privVars = targetOp.getPrivateVars();
+ llvm::SmallVector<int64_t> privVarMapIdx;
+
std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
SmallVector<omp::MapInfoOp, 4> mapInfoOps;
for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
targetOp, privatizerName);
if (!privatizerNeedsMap(privatizer)) {
+ privVarMapIdx.push_back(-1);
continue;
}
+
+ privVarMapIdx.push_back(targetOp.getMapVars().size() +
+ mapInfoOps.size());
+
builder.setInsertionPoint(targetOp);
Location loc = targetOp.getLoc();
omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
mapInfoOps.push_back(mapInfoOp);
+
LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
LLVM_DEBUG(mapInfoOp.dump());
}
if (!mapInfoOps.empty()) {
mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+ targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+ privVarMapIdx, targetOp.getContext()));
}
});
if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
! CHECK_SAME %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
! CHECK-SAME !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
! CHECK-SAME: private(
-! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME: @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME: @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME: @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME: @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
! CHECK-SAME: !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
! CHECK-NOT: fir.alloca
! CHECK: hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
a device, if it is 0 then the target region is executed on the host device.
}] # clausesDescription;
+ let arguments = !con(clausesArgs,
+ (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
let builders = [
OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
- $private_vars, type($private_vars), $private_syms) attr-dict
+ $private_vars, type($private_vars), $private_syms, $private_maps)
+ attr-dict
}];
let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIROpenMPDialect
IR/OpenMPDialect.cpp
+ IR/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
llvm::SmallVectorImpl<Type> &types;
ArrayAttr &syms;
+ ArrayAttr *mapIndices;
PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
- SmallVectorImpl<Type> &types, ArrayAttr &syms)
- : vars(vars), types(types), syms(syms) {}
+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
+ ArrayAttr *mapIndices=nullptr)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionParseArgs {
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types,
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
- ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+ ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+ DenseBoolArrayAttr *byref = nullptr) {
SmallVector<SymbolRefAttr> symbolVec;
+ SmallVector<int64_t> mapIndicesVec;
SmallVector<bool> isByRefVec;
unsigned regionArgOffset = regionPrivateArgs.size();
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
parser.parseArgument(regionPrivateArgs.emplace_back()))
return failure();
+ if (mapIndices) {
+ if (parser.parseOptionalLSquare().succeeded()) {
+ if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+ parser.parseInteger(mapIndicesVec.emplace_back()) ||
+ parser.parseRSquare())
+ return failure();
+ } else
+ mapIndicesVec.push_back(-1);
+ }
+
return success();
}))
return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
*symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
}
+ if (!mapIndicesVec.empty())
+ *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
if (byref)
*byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
static ParseResult parseBlockArgClause(
OpAsmParser &parser,
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
if (succeeded(parser.parseOptionalKeyword(keyword))) {
- if (!reductionArgs)
+ if (!privateArgs)
return failure();
- if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
- reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms)))
+ if (failed(parseClauseWithRegionArgs(
+ parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+ &privateArgs->syms, privateArgs->mapIndices)))
return failure();
}
return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
if (failed(parseClauseWithRegionArgs(
parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
- &reductionArgs->syms, &reductionArgs->byref)))
+ &reductionArgs->syms, /*mapIndices=*/nullptr,
+ &reductionArgs->byref)))
return failure();
}
return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
SmallVectorImpl<Type> &mapTypes,
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+ ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
return parseBlockArgRegion(parser, region, args);
}
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
ValueRange vars;
TypeRange types;
ArrayAttr syms;
- PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
- : vars(vars), types(types), syms(syms) {}
+ ArrayAttr mapIndices;
+ PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+ ArrayAttr mapIndices)
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
};
struct ReductionPrintArgs {
ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
ValueRange argsSubrange,
ValueRange operands, TypeRange types,
ArrayAttr symbols = nullptr,
+ ArrayAttr mapIndices = nullptr,
DenseBoolArrayAttr byref = nullptr) {
if (argsSubrange.empty())
return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
symbols = ArrayAttr::get(ctx, values);
}
+ if (!mapIndices) {
+ llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+ mapIndices = ArrayAttr::get(ctx, values);
+ }
+
if (!byref) {
mlir::SmallVector<bool> values(operands.size(), false);
byref = DenseBoolArrayAttr::get(ctx, values);
}
- llvm::interleaveComma(
- llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
- [&p](auto t) {
- auto [op, arg, sym, isByRef] = t;
- if (isByRef)
- p << "byref ";
- if (sym)
- p << sym << " ";
- p << op << " -> " << arg;
- });
+ llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+ mapIndices, byref.asArrayRef()),
+ p, [&p](auto t) {
+ auto [op, arg, sym, map, isByRef] = t;
+ if (isByRef)
+ p << "byref ";
+ if (sym)
+ p << sym << " ";
+
+ p << op << " -> " << arg;
+
+ if (map)
+ p << " [map_idx="
+ << llvm::cast<IntegerAttr>(map).getInt() << "]";
+ });
p << " : ";
llvm::interleaveComma(types, p);
p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
if (privateArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
privateArgs->vars, privateArgs->types,
- privateArgs->syms);
+ privateArgs->syms, privateArgs->mapIndices);
}
static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
if (reductionArgs)
printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
reductionArgs->vars, reductionArgs->types,
- reductionArgs->syms, reductionArgs->byref);
+ reductionArgs->syms, nullptr,
+ reductionArgs->byref);
}
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+ ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
printBlockArgRegion(p, op, region, args);
}
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
AllRegionPrintArgs args;
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
ValueRange privateVars, TypeRange privateTypes,
ArrayAttr privateSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
printBlockArgRegion(p, op, region, args);
}
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
ArrayAttr reductionSyms) {
AllRegionPrintArgs args;
- args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+ args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
reductionSyms);
printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+ /*private_maps=*/nullptr);
}
LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+ mlir::MLIRContext *context) {
+ llvm::SmallVector<mlir::Attribute, 4> attrs;
+ attrs.reserve(values.size());
+ for (auto &v : values)
+ attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+ mlir::APInt(64, v)));
+ return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
return
}
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+ %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
e218871
to
88e1520
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have to be an index? What if this was modeled as another reference to the same block argument? Everything is de-duplicated in the mlir context anyway. That would remove any risk of getting out of sync between indexes and arguments. So instead of having an array of integer attributes I think you could have a Varadic<OpenMP_PointerLikeType>
- this is just a suggestion.
I also left some comments on the interface for this on the other PR, because I think it is easier to discuss where you can see how it is used.
Thanks for the suggestion. Looking into it ... 👀 |
88e1520
to
074e4d2
Compare
@tblah I gave this a try. But the problem is that not all |
This PR extends the MLIR representation for `omp.target` ops by adding a `map_idx` to `private` vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the `map_idx` attribute is either not present at all or its value is `-1`.
074e4d2
to
8d770fd
Compare
How stable are the objects that you refer too? If they are not deallocated/reallocated and thus go out of existence and/or change their position in memory, you could keep a reference to the object that you'd like to refer to. So, basically, instead of using a 0-based index, you use the object's address in memory as the index. Would that work? |
For the |
Ping 🔔! Please take a look when you have a chance. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This LGTM, I think the approach is reasonable, especially as you've tried the other suggestions.
However, it might be worth checking if the additional map arguments affect the block argument interface API at all that can be used to retrieve the correct block argument offset vs map arguments, we utilise this to know the correct grouping of map containing clause operations <-> block arguments for TargetOp/TargetDataOp. I don't think it will, but even if private variables don't need block arguments there may be a small chance it disturbs the ordering, @skatrak will more than likely be able to say one way or another! :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for considering other approaches. I'm happy so long as Michael doesn't have any further suggestions.
This PR extends the MLIR representation for
omp.target
ops by adding amap_idx
toprivate
vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, themap_idx
attribute is either not present at all or its value is-1
.This makes matching the private variable to its map info op easier (see #116576 for usage).