Skip to content

[mlir][OpenMP] Annotate private vars with map_idx when needed #116770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 28, 2024

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Nov 19, 2024

This PR extends the MLIR representation for omp.target ops by adding a map_idx to private vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the map_idx attribute is either not present at all or its value is -1.

This makes matching the private variable to its map info op easier (see #116576 for usage).

@llvmbot llvmbot added mlir flang Flang issues not falling into any other category mlir:openmp flang:fir-hlfir flang:openmp labels Nov 19, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-flang-openmp

Author: Kareem Ergawy (ergawy)

Changes

This PR extends the MLIR representation for omp.target ops by adding a map_idx to private vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the map_idx attribute is either not present at all or its value is -1.


Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff

8 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp (+13-2)
  • (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 (+6-6)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+5-1)
  • (added) mlir/include/mlir/Dialect/OpenMP/Utils.h (+19)
  • (modified) mlir/lib/Dialect/OpenMP/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68-32)
  • (added) mlir/lib/Dialect/OpenMP/IR/Utils.cpp (+22)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+24)
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
       if (targetOp.getPrivateVars().empty())
         return;
       OperandRange privVars = targetOp.getPrivateVars();
+      llvm::SmallVector<int64_t> privVarMapIdx;
+
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
         if (!privatizerNeedsMap(privatizer)) {
+          privVarMapIdx.push_back(-1);
           continue;
         }
+
+        privVarMapIdx.push_back(targetOp.getMapVars().size() +
+                                mapInfoOps.size());
+
         builder.setInsertionPoint(targetOp);
         Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         mapInfoOps.push_back(mapInfoOp);
+
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
         LLVM_DEBUG(mapInfoOp.dump());
       }
       if (!mapInfoOps.empty()) {
         mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+        targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+            privVarMapIdx, targetOp.getContext()));
       }
     });
     if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
 ! CHECK_SAME        %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
 ! CHECK-SAME        !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
 ! CHECK-SAME:     private(
-! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
 ! CHECK-SAME:       !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
 ! CHECK-NOT:      fir.alloca
 ! CHECK:          hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
     a device, if it is 0 then the target region is executed on the host device.
   }] # clausesDescription;
 
+  let arguments = !con(clausesArgs,
+                       (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
   let builders = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
     custom<InReductionMapPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+        $private_vars, type($private_vars), $private_syms, $private_maps)
+        attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIROpenMPDialect
   IR/OpenMPDialect.cpp
+  IR/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
-                   SmallVectorImpl<Type> &types, ArrayAttr &syms)
-      : vars(vars), types(types), syms(syms) {}
+                   SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   ArrayAttr *mapIndices=nullptr)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionParseArgs {
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
-    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+    ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+    DenseBoolArrayAttr *byref = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
+  SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
   unsigned regionArgOffset = regionPrivateArgs.size();
 
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
             parser.parseArgument(regionPrivateArgs.emplace_back()))
           return failure();
 
+        if (mapIndices) {
+          if (parser.parseOptionalLSquare().succeeded()) {
+            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+                parser.parseInteger(mapIndicesVec.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          } else
+            mapIndicesVec.push_back(-1);
+        }
+
         return success();
       }))
     return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
     *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
   }
 
+  if (!mapIndicesVec.empty())
+    *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
   if (byref)
     *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
 
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
 static ParseResult parseBlockArgClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
-    StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
-    if (!reductionArgs)
+    if (!privateArgs)
       return failure();
 
-    if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
-                                         reductionArgs->types, entryBlockArgs,
-                                         &reductionArgs->syms)))
+    if (failed(parseClauseWithRegionArgs(
+            parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+            &privateArgs->syms, privateArgs->mapIndices)))
       return failure();
   }
   return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, &reductionArgs->byref)))
+            &reductionArgs->syms, /*mapIndices=*/nullptr,
+            &reductionArgs->byref)))
       return failure();
   }
   return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
-  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
-      : vars(vars), types(types), syms(syms) {}
+  ArrayAttr mapIndices;
+  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+                   ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
                                       ValueRange argsSubrange,
                                       ValueRange operands, TypeRange types,
                                       ArrayAttr symbols = nullptr,
+                                      ArrayAttr mapIndices = nullptr,
                                       DenseBoolArrayAttr byref = nullptr) {
   if (argsSubrange.empty())
     return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
     symbols = ArrayAttr::get(ctx, values);
   }
 
+  if (!mapIndices) {
+    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+    mapIndices = ArrayAttr::get(ctx, values);
+  }
+
   if (!byref) {
     mlir::SmallVector<bool> values(operands.size(), false);
     byref = DenseBoolArrayAttr::get(ctx, values);
   }
 
-  llvm::interleaveComma(
-      llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
-      [&p](auto t) {
-        auto [op, arg, sym, isByRef] = t;
-        if (isByRef)
-          p << "byref ";
-        if (sym)
-          p << sym << " ";
-        p << op << " -> " << arg;
-      });
+  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+                                        mapIndices, byref.asArrayRef()),
+                        p, [&p](auto t) {
+                          auto [op, arg, sym, map, isByRef] = t;
+                          if (isByRef)
+                            p << "byref ";
+                          if (sym)
+                            p << sym << " ";
+
+                          p << op << " -> " << arg;
+
+                          if (map)
+                            p << " [map_idx="
+                              << llvm::cast<IntegerAttr>(map).getInt() << "]";
+                        });
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
   if (privateArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               privateArgs->vars, privateArgs->types,
-                              privateArgs->syms);
+                              privateArgs->syms, privateArgs->mapIndices);
 }
 
 static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
   if (reductionArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, reductionArgs->byref);
+                              reductionArgs->syms, nullptr,
+                              reductionArgs->byref);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
-    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+    ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
                                ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context) {
+  llvm::SmallVector<mlir::Attribute, 4> attrs;
+  attrs.reserve(values.size());
+  for (auto &v : values)
+    attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+                                           mlir::APInt(64, v)));
+  return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   return
 }
 
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-mlir

Author: Kareem Ergawy (ergawy)

Changes

This PR extends the MLIR representation for omp.target ops by adding a map_idx to private vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the map_idx attribute is either not present at all or its value is -1.


Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff

8 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp (+13-2)
  • (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 (+6-6)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+5-1)
  • (added) mlir/include/mlir/Dialect/OpenMP/Utils.h (+19)
  • (modified) mlir/lib/Dialect/OpenMP/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68-32)
  • (added) mlir/lib/Dialect/OpenMP/IR/Utils.cpp (+22)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+24)
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
       if (targetOp.getPrivateVars().empty())
         return;
       OperandRange privVars = targetOp.getPrivateVars();
+      llvm::SmallVector<int64_t> privVarMapIdx;
+
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
         if (!privatizerNeedsMap(privatizer)) {
+          privVarMapIdx.push_back(-1);
           continue;
         }
+
+        privVarMapIdx.push_back(targetOp.getMapVars().size() +
+                                mapInfoOps.size());
+
         builder.setInsertionPoint(targetOp);
         Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         mapInfoOps.push_back(mapInfoOp);
+
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
         LLVM_DEBUG(mapInfoOp.dump());
       }
       if (!mapInfoOps.empty()) {
         mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+        targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+            privVarMapIdx, targetOp.getContext()));
       }
     });
     if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
 ! CHECK_SAME        %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
 ! CHECK-SAME        !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
 ! CHECK-SAME:     private(
-! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
 ! CHECK-SAME:       !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
 ! CHECK-NOT:      fir.alloca
 ! CHECK:          hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
     a device, if it is 0 then the target region is executed on the host device.
   }] # clausesDescription;
 
+  let arguments = !con(clausesArgs,
+                       (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
   let builders = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
     custom<InReductionMapPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+        $private_vars, type($private_vars), $private_syms, $private_maps)
+        attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIROpenMPDialect
   IR/OpenMPDialect.cpp
+  IR/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
-                   SmallVectorImpl<Type> &types, ArrayAttr &syms)
-      : vars(vars), types(types), syms(syms) {}
+                   SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   ArrayAttr *mapIndices=nullptr)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionParseArgs {
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
-    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+    ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+    DenseBoolArrayAttr *byref = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
+  SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
   unsigned regionArgOffset = regionPrivateArgs.size();
 
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
             parser.parseArgument(regionPrivateArgs.emplace_back()))
           return failure();
 
+        if (mapIndices) {
+          if (parser.parseOptionalLSquare().succeeded()) {
+            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+                parser.parseInteger(mapIndicesVec.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          } else
+            mapIndicesVec.push_back(-1);
+        }
+
         return success();
       }))
     return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
     *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
   }
 
+  if (!mapIndicesVec.empty())
+    *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
   if (byref)
     *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
 
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
 static ParseResult parseBlockArgClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
-    StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
-    if (!reductionArgs)
+    if (!privateArgs)
       return failure();
 
-    if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
-                                         reductionArgs->types, entryBlockArgs,
-                                         &reductionArgs->syms)))
+    if (failed(parseClauseWithRegionArgs(
+            parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+            &privateArgs->syms, privateArgs->mapIndices)))
       return failure();
   }
   return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, &reductionArgs->byref)))
+            &reductionArgs->syms, /*mapIndices=*/nullptr,
+            &reductionArgs->byref)))
       return failure();
   }
   return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
-  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
-      : vars(vars), types(types), syms(syms) {}
+  ArrayAttr mapIndices;
+  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+                   ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
                                       ValueRange argsSubrange,
                                       ValueRange operands, TypeRange types,
                                       ArrayAttr symbols = nullptr,
+                                      ArrayAttr mapIndices = nullptr,
                                       DenseBoolArrayAttr byref = nullptr) {
   if (argsSubrange.empty())
     return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
     symbols = ArrayAttr::get(ctx, values);
   }
 
+  if (!mapIndices) {
+    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+    mapIndices = ArrayAttr::get(ctx, values);
+  }
+
   if (!byref) {
     mlir::SmallVector<bool> values(operands.size(), false);
     byref = DenseBoolArrayAttr::get(ctx, values);
   }
 
-  llvm::interleaveComma(
-      llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
-      [&p](auto t) {
-        auto [op, arg, sym, isByRef] = t;
-        if (isByRef)
-          p << "byref ";
-        if (sym)
-          p << sym << " ";
-        p << op << " -> " << arg;
-      });
+  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+                                        mapIndices, byref.asArrayRef()),
+                        p, [&p](auto t) {
+                          auto [op, arg, sym, map, isByRef] = t;
+                          if (isByRef)
+                            p << "byref ";
+                          if (sym)
+                            p << sym << " ";
+
+                          p << op << " -> " << arg;
+
+                          if (map)
+                            p << " [map_idx="
+                              << llvm::cast<IntegerAttr>(map).getInt() << "]";
+                        });
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
   if (privateArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               privateArgs->vars, privateArgs->types,
-                              privateArgs->syms);
+                              privateArgs->syms, privateArgs->mapIndices);
 }
 
 static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
   if (reductionArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, reductionArgs->byref);
+                              reductionArgs->syms, nullptr,
+                              reductionArgs->byref);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
-    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+    ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
                                ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context) {
+  llvm::SmallVector<mlir::Attribute, 4> attrs;
+  attrs.reserve(values.size());
+  for (auto &v : values)
+    attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+                                           mlir::APInt(64, v)));
+  return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   return
 }
 
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2024

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

This PR extends the MLIR representation for omp.target ops by adding a map_idx to private vars. This annotation stores the index of the map info operand corresponding to the private var. If the variable does not have a map operand, the map_idx attribute is either not present at all or its value is -1.


Patch is 20.97 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116770.diff

8 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp (+13-2)
  • (modified) flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 (+6-6)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+5-1)
  • (added) mlir/include/mlir/Dialect/OpenMP/Utils.h (+19)
  • (modified) mlir/lib/Dialect/OpenMP/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68-32)
  • (added) mlir/lib/Dialect/OpenMP/IR/Utils.cpp (+22)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+24)
diff --git a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
index 289e648eed8546..6e537300dfb7f1 100644
--- a/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
+++ b/flang/lib/Optimizer/OpenMP/MapsForPrivatizedSymbols.cpp
@@ -1,5 +1,4 @@
-//===- MapsForPrivatizedSymbols.cpp
-//-----------------------------------------===//
+//===- MapsForPrivatizedSymbols.cpp ---------------------------------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -28,8 +27,10 @@
 #include "flang/Optimizer/Dialect/Support/KindMapping.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/OpenMP/Passes.h"
+
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/Pass/Pass.h"
@@ -124,6 +125,8 @@ class MapsForPrivatizedSymbolsPass
       if (targetOp.getPrivateVars().empty())
         return;
       OperandRange privVars = targetOp.getPrivateVars();
+      llvm::SmallVector<int64_t> privVarMapIdx;
+
       std::optional<ArrayAttr> privSyms = targetOp.getPrivateSyms();
       SmallVector<omp::MapInfoOp, 4> mapInfoOps;
       for (auto [privVar, privSym] : llvm::zip_equal(privVars, *privSyms)) {
@@ -133,17 +136,25 @@ class MapsForPrivatizedSymbolsPass
             SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
                 targetOp, privatizerName);
         if (!privatizerNeedsMap(privatizer)) {
+          privVarMapIdx.push_back(-1);
           continue;
         }
+
+        privVarMapIdx.push_back(targetOp.getMapVars().size() +
+                                mapInfoOps.size());
+
         builder.setInsertionPoint(targetOp);
         Location loc = targetOp.getLoc();
         omp::MapInfoOp mapInfoOp = createMapInfo(loc, privVar, builder);
         mapInfoOps.push_back(mapInfoOp);
+
         LLVM_DEBUG(llvm::dbgs() << "MapsForPrivatizedSymbolsPass created ->\n");
         LLVM_DEBUG(mapInfoOp.dump());
       }
       if (!mapInfoOps.empty()) {
         mapInfoOpsForTarget.insert({targetOp.getOperation(), mapInfoOps});
+        targetOp.setPrivateMapsAttr(mlir::omp::utils::makeI64ArrayAttr(
+            privVarMapIdx, targetOp.getContext()));
       }
     });
     if (!mapInfoOpsForTarget.empty()) {
diff --git a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90 b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
index b0c76ff3845f83..602e98975e9dc5 100644
--- a/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
+++ b/flang/test/Lower/OpenMP/DelayedPrivatization/target-private-multiple-variables.f90
@@ -171,12 +171,12 @@ end subroutine target_allocatable
 ! CHECK_SAME        %[[CHAR_VAR_DESC_MAP]] -> %[[MAPPED_ARG3:.[^,]+]] :
 ! CHECK-SAME        !fir.ref<i32>, !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.array<?xf32>>>, !fir.ref<!fir.boxchar<1>>)
 ! CHECK-SAME:     private(
-! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]],
-! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]],
-! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]],
-! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]],
-! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]],
-! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] :
+! CHECK-SAME:       @[[ALLOC_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ALLOC_ARG:[^,]+]] [map_idx=1],
+! CHECK-SAME:       @[[REAL_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[REAL_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[LB_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[LB_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[ARR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[ARR_ARG:[^,]+]] [map_idx=2],
+! CHECK-SAME:       @[[COMP_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[COMP_ARG:[^,]+]] [map_idx=-1],
+! CHECK-SAME:       @[[CHAR_PRIVATIZER_SYM]] %{{[^[:space:]]+}}#0 -> %[[CHAR_ARG:[^,]+]] [map_idx=3] :
 ! CHECK-SAME:       !fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<f32>, !fir.ref<i64>, !fir.box<!fir.array<?xf32>>, !fir.ref<complex<f32>>, !fir.boxchar<1>) {
 ! CHECK-NOT:      fir.alloca
 ! CHECK:          hlfir.declare %[[ALLOC_ARG]]
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 156e6eb371b85d..31ecbea8e0c211 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1227,6 +1227,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
     a device, if it is 0 then the target region is executed on the host device.
   }] # clausesDescription;
 
+  let arguments = !con(clausesArgs,
+                       (ins OptionalAttr<I64ArrayAttr>:$private_maps));
+
   let builders = [
     OpBuilder<(ins CArg<"const TargetOperands &">:$clauses)>
   ];
@@ -1239,7 +1242,8 @@ def TargetOp : OpenMP_Op<"target", traits = [
     custom<InReductionMapPrivateRegion>(
         $region, $in_reduction_vars, type($in_reduction_vars),
         $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+        $private_vars, type($private_vars), $private_syms, $private_maps)
+        attr-dict
   }];
 
   let hasVerifier = 1;
diff --git a/mlir/include/mlir/Dialect/OpenMP/Utils.h b/mlir/include/mlir/Dialect/OpenMP/Utils.h
new file mode 100644
index 00000000000000..f79e10b1e5ab38
--- /dev/null
+++ b/mlir/include/mlir/Dialect/OpenMP/Utils.h
@@ -0,0 +1,19 @@
+//===- Utils.h - Utils for the OpenMP MLIR Dialect --------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_OPENMP_UTILS_H_
+#define MLIR_DIALECT_OPENMP_UTILS_H_
+
+#include "mlir/IR/BuiltinAttributes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context);
+} // namespace mlir::omp::utils
+
+#endif // MLIR_DIALECT_OPENMP_UTILS_H_
diff --git a/mlir/lib/Dialect/OpenMP/CMakeLists.txt b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
index 57a6d3445c151c..809bd1306563bd 100644
--- a/mlir/lib/Dialect/OpenMP/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenMP/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_mlir_dialect_library(MLIROpenMPDialect
   IR/OpenMPDialect.cpp
+  IR/Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/OpenMP
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 94e71e089d4b18..4a13272b8f4a83 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
+#include "mlir/Dialect/OpenMP/Utils.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
   llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
   llvm::SmallVectorImpl<Type> &types;
   ArrayAttr &syms;
+  ArrayAttr *mapIndices;
   PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
-                   SmallVectorImpl<Type> &types, ArrayAttr &syms)
-      : vars(vars), types(types), syms(syms) {}
+                   SmallVectorImpl<Type> &types, ArrayAttr &syms,
+                   ArrayAttr *mapIndices=nullptr)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionParseArgs {
   SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
     SmallVectorImpl<Type> &types,
     SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
-    ArrayAttr *symbols = nullptr, DenseBoolArrayAttr *byref = nullptr) {
+    ArrayAttr *symbols = nullptr, ArrayAttr *mapIndices = nullptr,
+    DenseBoolArrayAttr *byref = nullptr) {
   SmallVector<SymbolRefAttr> symbolVec;
+  SmallVector<int64_t> mapIndicesVec;
   SmallVector<bool> isByRefVec;
   unsigned regionArgOffset = regionPrivateArgs.size();
 
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
             parser.parseArgument(regionPrivateArgs.emplace_back()))
           return failure();
 
+        if (mapIndices) {
+          if (parser.parseOptionalLSquare().succeeded()) {
+            if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
+                parser.parseInteger(mapIndicesVec.emplace_back()) ||
+                parser.parseRSquare())
+              return failure();
+          } else
+            mapIndicesVec.push_back(-1);
+        }
+
         return success();
       }))
     return failure();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
     *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
   }
 
+  if (!mapIndicesVec.empty())
+    *mapIndices = utils::makeI64ArrayAttr(mapIndicesVec, parser.getContext());
+
   if (byref)
     *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
 
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
 static ParseResult parseBlockArgClause(
     OpAsmParser &parser,
     llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
-    StringRef keyword, std::optional<PrivateParseArgs> reductionArgs) {
+    StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
   if (succeeded(parser.parseOptionalKeyword(keyword))) {
-    if (!reductionArgs)
+    if (!privateArgs)
       return failure();
 
-    if (failed(parseClauseWithRegionArgs(parser, reductionArgs->vars,
-                                         reductionArgs->types, entryBlockArgs,
-                                         &reductionArgs->syms)))
+    if (failed(parseClauseWithRegionArgs(
+            parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
+            &privateArgs->syms, privateArgs->mapIndices)))
       return failure();
   }
   return success();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
 
     if (failed(parseClauseWithRegionArgs(
             parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
-            &reductionArgs->syms, &reductionArgs->byref)))
+            &reductionArgs->syms, /*mapIndices=*/nullptr,
+            &reductionArgs->byref)))
       return failure();
   }
   return success();
@@ -674,12 +693,13 @@ static ParseResult parseInReductionMapPrivateRegion(
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
     SmallVectorImpl<Type> &mapTypes,
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
-    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
+    llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
+    ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, &privateMaps);
   return parseBlockArgRegion(parser, region, args);
 }
 
@@ -776,8 +796,10 @@ struct PrivatePrintArgs {
   ValueRange vars;
   TypeRange types;
   ArrayAttr syms;
-  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms)
-      : vars(vars), types(types), syms(syms) {}
+  ArrayAttr mapIndices;
+  PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
+                   ArrayAttr mapIndices)
+      : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
 };
 struct ReductionPrintArgs {
   ValueRange vars;
@@ -804,6 +826,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
                                       ValueRange argsSubrange,
                                       ValueRange operands, TypeRange types,
                                       ArrayAttr symbols = nullptr,
+                                      ArrayAttr mapIndices = nullptr,
                                       DenseBoolArrayAttr byref = nullptr) {
   if (argsSubrange.empty())
     return;
@@ -815,21 +838,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
     symbols = ArrayAttr::get(ctx, values);
   }
 
+  if (!mapIndices) {
+    llvm::SmallVector<Attribute> values(operands.size(), nullptr);
+    mapIndices = ArrayAttr::get(ctx, values);
+  }
+
   if (!byref) {
     mlir::SmallVector<bool> values(operands.size(), false);
     byref = DenseBoolArrayAttr::get(ctx, values);
   }
 
-  llvm::interleaveComma(
-      llvm::zip_equal(operands, argsSubrange, symbols, byref.asArrayRef()), p,
-      [&p](auto t) {
-        auto [op, arg, sym, isByRef] = t;
-        if (isByRef)
-          p << "byref ";
-        if (sym)
-          p << sym << " ";
-        p << op << " -> " << arg;
-      });
+  llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
+                                        mapIndices, byref.asArrayRef()),
+                        p, [&p](auto t) {
+                          auto [op, arg, sym, map, isByRef] = t;
+                          if (isByRef)
+                            p << "byref ";
+                          if (sym)
+                            p << sym << " ";
+
+                          p << op << " -> " << arg;
+
+                          if (map)
+                            p << " [map_idx="
+                              << llvm::cast<IntegerAttr>(map).getInt() << "]";
+                        });
   p << " : ";
   llvm::interleaveComma(types, p);
   p << ") ";
@@ -849,7 +882,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
   if (privateArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               privateArgs->vars, privateArgs->types,
-                              privateArgs->syms);
+                              privateArgs->syms, privateArgs->mapIndices);
 }
 
 static void
@@ -859,7 +892,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
   if (reductionArgs)
     printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
                               reductionArgs->vars, reductionArgs->types,
-                              reductionArgs->syms, reductionArgs->byref);
+                              reductionArgs->syms, nullptr,
+                              reductionArgs->byref);
 }
 
 static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
@@ -891,12 +925,13 @@ static void printInReductionMapPrivateRegion(
     OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
-    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
+    ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
+    ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, privateMaps);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -908,7 +943,7 @@ static void printInReductionPrivateRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -921,7 +956,7 @@ static void printInReductionPrivateReductionRegion(
   AllRegionPrintArgs args;
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -931,7 +966,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region,
                                ValueRange privateVars, TypeRange privateTypes,
                                ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   printBlockArgRegion(p, op, region, args);
 }
 
@@ -941,7 +976,7 @@ static void printPrivateReductionRegion(
     TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
     ArrayAttr reductionSyms) {
   AllRegionPrintArgs args;
-  args.privateArgs.emplace(privateVars, privateTypes, privateSyms);
+  args.privateArgs.emplace(privateVars, privateTypes, privateSyms, nullptr);
   args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
                              reductionSyms);
   printBlockArgRegion(p, op, region, args);
@@ -1656,7 +1691,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
-                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit);
+                  makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
+                  /*private_maps=*/nullptr);
 }
 
 LogicalResult TargetOp::verify() {
diff --git a/mlir/lib/Dialect/OpenMP/IR/Utils.cpp b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
new file mode 100644
index 00000000000000..64e84817c818b2
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/IR/Utils.cpp
@@ -0,0 +1,22 @@
+//===- Utils.cpp - Utils for the OpenMP MLIR Dialect ------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenMP/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir::omp::utils {
+mlir::ArrayAttr makeI64ArrayAttr(llvm::ArrayRef<int64_t> values,
+                                 mlir::MLIRContext *context) {
+  llvm::SmallVector<mlir::Attribute, 4> attrs;
+  attrs.reserve(values.size());
+  for (auto &v : values)
+    attrs.push_back(mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64),
+                                           mlir::APInt(64, v)));
+  return mlir::ArrayAttr::get(context, attrs);
+}
+} // namespace mlir::omp::utils
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c25a6ef4b4849b..94c63dd8e9aa0e 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2750,6 +2750,30 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   return
 }
 
+// CHECK-LABEL: omp_target_private_with_map_idx
+func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_var: !llvm.ptr) -> () {
+  %mapv1 = omp.map.info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>...
[truncated]

Copy link

github-actions bot commented Nov 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have to be an index? What if this was modeled as another reference to the same block argument? Everything is de-duplicated in the mlir context anyway. That would remove any risk of getting out of sync between indexes and arguments. So instead of having an array of integer attributes I think you could have a Varadic<OpenMP_PointerLikeType> - this is just a suggestion.

I also left some comments on the interface for this on the other PR, because I think it is easier to discuss where you can see how it is used.

@ergawy
Copy link
Member Author

ergawy commented Nov 21, 2024

Does this have to be an index? What if this was modeled as another reference to the same block argument?

Thanks for the suggestion. Looking into it ... 👀

@ergawy ergawy force-pushed the add_map_idx_to_private_vars branch from 88e1520 to 074e4d2 Compare November 21, 2024 08:00
@ergawy
Copy link
Member Author

ergawy commented Nov 21, 2024

Does this have to be an index? What if this was modeled as another reference to the same block argument?

@tblah I gave this a try. But the problem is that not all private operands required a map. So, for the private operand range, we have to construct a corresponding range where we can represent that some of the elements might be missing. This is not possible to with another operand range (i.e. Varadic<OpenMP_PointerLikeType> as you suggest) if we represent the private input data as individual arrays like we do now (i.e. Variadic<AnyType>:$private_vars, OptionalAttr<SymbolRefArrayAttr>:$private_syms). This is because some of the values of that new range will have to filled as empty values. The index array attribute gives that by using an index value that does not make sense to be an index like -1.

This PR extends the MLIR representation for `omp.target` ops by adding a
`map_idx` to `private` vars. This annotation stores the index of the map
info operand corresponding to the private var. If the variable does not
have a map operand, the `map_idx` attribute is either not present at all
or its value is `-1`.
@ergawy ergawy force-pushed the add_map_idx_to_private_vars branch from 074e4d2 to 8d770fd Compare November 21, 2024 09:32
@mjklemm
Copy link
Contributor

mjklemm commented Nov 21, 2024

How stable are the objects that you refer too? If they are not deallocated/reallocated and thus go out of existence and/or change their position in memory, you could keep a reference to the object that you'd like to refer to. So, basically, instead of using a 0-based index, you use the object's address in memory as the index. Would that work?

@ergawy
Copy link
Member Author

ergawy commented Nov 22, 2024

So, basically, instead of using a 0-based index, you use the object's address in memory as the index. Would that work?

For the private values that need a map, we can definitely store a refernce to the SSA value of the corresponding map operand to the operation. However, the problem happens when you have a mix of private values where some need a map and some don't. For the private values that don't need a map, we need to store a placeholder value since we need to keep 2 ranges of values of the same length (the first range being the private operands and the second range being the corresponding optional map operands to these private operands). If we choose to store a reference to the map value, we will have to store a "null" SSA value for the private values that don't need a map which invalidates the IR. Therefore, the attribute solves the problem by using index values and a special non-sensical value like -1 to represent the absence of a corresponding map operand.

@ergawy
Copy link
Member Author

ergawy commented Nov 25, 2024

Ping 🔔! Please take a look when you have a chance.

Copy link
Contributor

@agozillon agozillon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM, I think the approach is reasonable, especially as you've tried the other suggestions.

However, it might be worth checking if the additional map arguments affect the block argument interface API at all that can be used to retrieve the correct block argument offset vs map arguments, we utilise this to know the correct grouping of map containing clause operations <-> block arguments for TargetOp/TargetDataOp. I don't think it will, but even if private variables don't need block arguments there may be a small chance it disturbs the ordering, @skatrak will more than likely be able to say one way or another! :-)

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for considering other approaches. I'm happy so long as Michael doesn't have any further suggestions.

@ergawy ergawy merged commit 2918a47 into llvm:main Nov 28, 2024
8 checks passed
ergawy added a commit that referenced this pull request Dec 12, 2024
…tization of allocatables in `omp.target` ops (#116576)

This PR adds support to translate the `private` clause from MLIR to
LLVMIR when used on allocatables in the context of an `omp.target` op.

This replaces #113208.

Parent PR: #116770. Only the
latest commit is relevant to the PR.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang Flang issues not falling into any other category mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants