15
15
#include " mlir/Dialect/Func/IR/FuncOps.h"
16
16
#include " mlir/Dialect/LLVMIR/LLVMTypes.h"
17
17
#include " mlir/Dialect/OpenACCMPCommon/Interfaces/AtomicInterfaces.h"
18
+ #include " mlir/Dialect/OpenMP/Utils.h"
18
19
#include " mlir/IR/Attributes.h"
19
20
#include " mlir/IR/BuiltinAttributes.h"
20
21
#include " mlir/IR/DialectImplementation.h"
@@ -487,9 +488,11 @@ struct PrivateParseArgs {
487
488
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
488
489
llvm::SmallVectorImpl<Type> &types;
489
490
ArrayAttr &syms;
491
+ ArrayAttr *mapIndices;
490
492
PrivateParseArgs (SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
491
- SmallVectorImpl<Type> &types, ArrayAttr &syms)
492
- : vars(vars), types(types), syms(syms) {}
493
+ SmallVectorImpl<Type> &types, ArrayAttr &syms,
494
+ ArrayAttr *mapIndices = nullptr )
495
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
493
496
};
494
497
struct ReductionParseArgs {
495
498
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
@@ -517,8 +520,10 @@ static ParseResult parseClauseWithRegionArgs(
517
520
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
518
521
SmallVectorImpl<Type> &types,
519
522
SmallVectorImpl<OpAsmParser::Argument> ®ionPrivateArgs,
520
- ArrayAttr *symbols = nullptr , DenseBoolArrayAttr *byref = nullptr ) {
523
+ ArrayAttr *symbols = nullptr , ArrayAttr *mapIndices = nullptr ,
524
+ DenseBoolArrayAttr *byref = nullptr ) {
521
525
SmallVector<SymbolRefAttr> symbolVec;
526
+ SmallVector<int64_t > mapIndicesVec;
522
527
SmallVector<bool > isByRefVec;
523
528
unsigned regionArgOffset = regionPrivateArgs.size ();
524
529
@@ -538,6 +543,16 @@ static ParseResult parseClauseWithRegionArgs(
538
543
parser.parseArgument (regionPrivateArgs.emplace_back ()))
539
544
return failure ();
540
545
546
+ if (mapIndices) {
547
+ if (parser.parseOptionalLSquare ().succeeded ()) {
548
+ if (parser.parseKeyword (" map_idx" ) || parser.parseEqual () ||
549
+ parser.parseInteger (mapIndicesVec.emplace_back ()) ||
550
+ parser.parseRSquare ())
551
+ return failure ();
552
+ } else
553
+ mapIndicesVec.push_back (-1 );
554
+ }
555
+
541
556
return success ();
542
557
}))
543
558
return failure ();
@@ -571,6 +586,9 @@ static ParseResult parseClauseWithRegionArgs(
571
586
*symbols = ArrayAttr::get (parser.getContext (), symbolAttrs);
572
587
}
573
588
589
+ if (!mapIndicesVec.empty ())
590
+ *mapIndices = utils::makeI64ArrayAttr (mapIndicesVec, parser.getContext ());
591
+
574
592
if (byref)
575
593
*byref = makeDenseBoolArrayAttr (parser.getContext (), isByRefVec);
576
594
@@ -595,14 +613,14 @@ static ParseResult parseBlockArgClause(
595
613
static ParseResult parseBlockArgClause (
596
614
OpAsmParser &parser,
597
615
llvm::SmallVectorImpl<OpAsmParser::Argument> &entryBlockArgs,
598
- StringRef keyword, std::optional<PrivateParseArgs> reductionArgs ) {
616
+ StringRef keyword, std::optional<PrivateParseArgs> privateArgs ) {
599
617
if (succeeded (parser.parseOptionalKeyword (keyword))) {
600
- if (!reductionArgs )
618
+ if (!privateArgs )
601
619
return failure ();
602
620
603
- if (failed (parseClauseWithRegionArgs (parser, reductionArgs-> vars ,
604
- reductionArgs ->types , entryBlockArgs,
605
- &reductionArgs ->syms )))
621
+ if (failed (parseClauseWithRegionArgs (
622
+ parser, privateArgs-> vars , privateArgs ->types , entryBlockArgs,
623
+ &privateArgs ->syms , privateArgs-> mapIndices )))
606
624
return failure ();
607
625
}
608
626
return success ();
@@ -618,7 +636,8 @@ static ParseResult parseBlockArgClause(
618
636
619
637
if (failed (parseClauseWithRegionArgs (
620
638
parser, reductionArgs->vars , reductionArgs->types , entryBlockArgs,
621
- &reductionArgs->syms , &reductionArgs->byref )))
639
+ &reductionArgs->syms , /* mapIndices=*/ nullptr ,
640
+ &reductionArgs->byref )))
622
641
return failure ();
623
642
}
624
643
return success ();
@@ -674,12 +693,14 @@ static ParseResult parseInReductionMapPrivateRegion(
674
693
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &mapVars,
675
694
SmallVectorImpl<Type> &mapTypes,
676
695
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
677
- llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
696
+ llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
697
+ ArrayAttr &privateMaps) {
678
698
AllRegionParseArgs args;
679
699
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
680
700
inReductionByref, inReductionSyms);
681
701
args.mapArgs .emplace (mapVars, mapTypes);
682
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
702
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms,
703
+ &privateMaps);
683
704
return parseBlockArgRegion (parser, region, args);
684
705
}
685
706
@@ -776,8 +797,10 @@ struct PrivatePrintArgs {
776
797
ValueRange vars;
777
798
TypeRange types;
778
799
ArrayAttr syms;
779
- PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms)
780
- : vars(vars), types(types), syms(syms) {}
800
+ ArrayAttr mapIndices;
801
+ PrivatePrintArgs (ValueRange vars, TypeRange types, ArrayAttr syms,
802
+ ArrayAttr mapIndices)
803
+ : vars(vars), types(types), syms(syms), mapIndices(mapIndices) {}
781
804
};
782
805
struct ReductionPrintArgs {
783
806
ValueRange vars;
@@ -804,6 +827,7 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
804
827
ValueRange argsSubrange,
805
828
ValueRange operands, TypeRange types,
806
829
ArrayAttr symbols = nullptr ,
830
+ ArrayAttr mapIndices = nullptr ,
807
831
DenseBoolArrayAttr byref = nullptr ) {
808
832
if (argsSubrange.empty ())
809
833
return ;
@@ -815,21 +839,31 @@ static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx,
815
839
symbols = ArrayAttr::get (ctx, values);
816
840
}
817
841
842
+ if (!mapIndices) {
843
+ llvm::SmallVector<Attribute> values (operands.size (), nullptr );
844
+ mapIndices = ArrayAttr::get (ctx, values);
845
+ }
846
+
818
847
if (!byref) {
819
848
mlir::SmallVector<bool > values (operands.size (), false );
820
849
byref = DenseBoolArrayAttr::get (ctx, values);
821
850
}
822
851
823
- llvm::interleaveComma (
824
- llvm::zip_equal (operands, argsSubrange, symbols, byref.asArrayRef ()), p,
825
- [&p](auto t) {
826
- auto [op, arg, sym, isByRef] = t;
827
- if (isByRef)
828
- p << " byref " ;
829
- if (sym)
830
- p << sym << " " ;
831
- p << op << " -> " << arg;
832
- });
852
+ llvm::interleaveComma (llvm::zip_equal (operands, argsSubrange, symbols,
853
+ mapIndices, byref.asArrayRef ()),
854
+ p, [&p](auto t) {
855
+ auto [op, arg, sym, map, isByRef] = t;
856
+ if (isByRef)
857
+ p << " byref " ;
858
+ if (sym)
859
+ p << sym << " " ;
860
+
861
+ p << op << " -> " << arg;
862
+
863
+ if (map)
864
+ p << " [map_idx="
865
+ << llvm::cast<IntegerAttr>(map).getInt () << " ]" ;
866
+ });
833
867
p << " : " ;
834
868
llvm::interleaveComma (types, p);
835
869
p << " ) " ;
@@ -849,7 +883,7 @@ static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx,
849
883
if (privateArgs)
850
884
printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
851
885
privateArgs->vars , privateArgs->types ,
852
- privateArgs->syms );
886
+ privateArgs->syms , privateArgs-> mapIndices );
853
887
}
854
888
855
889
static void
@@ -859,7 +893,8 @@ printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
859
893
if (reductionArgs)
860
894
printClauseWithRegionArgs (p, ctx, clauseName, argsSubrange,
861
895
reductionArgs->vars , reductionArgs->types ,
862
- reductionArgs->syms , reductionArgs->byref );
896
+ reductionArgs->syms , nullptr ,
897
+ reductionArgs->byref );
863
898
}
864
899
865
900
static void printBlockArgRegion (OpAsmPrinter &p, Operation *op, Region ®ion,
@@ -891,12 +926,13 @@ static void printInReductionMapPrivateRegion(
891
926
OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
892
927
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
893
928
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
894
- ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
929
+ ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
930
+ ArrayAttr privateMaps) {
895
931
AllRegionPrintArgs args;
896
932
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
897
933
inReductionByref, inReductionSyms);
898
934
args.mapArgs .emplace (mapVars, mapTypes);
899
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
935
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, privateMaps );
900
936
printBlockArgRegion (p, op, region, args);
901
937
}
902
938
@@ -908,7 +944,7 @@ static void printInReductionPrivateRegion(
908
944
AllRegionPrintArgs args;
909
945
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
910
946
inReductionByref, inReductionSyms);
911
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
947
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
912
948
printBlockArgRegion (p, op, region, args);
913
949
}
914
950
@@ -921,7 +957,7 @@ static void printInReductionPrivateReductionRegion(
921
957
AllRegionPrintArgs args;
922
958
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
923
959
inReductionByref, inReductionSyms);
924
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
960
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
925
961
args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
926
962
reductionSyms);
927
963
printBlockArgRegion (p, op, region, args);
@@ -931,7 +967,7 @@ static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931
967
ValueRange privateVars, TypeRange privateTypes,
932
968
ArrayAttr privateSyms) {
933
969
AllRegionPrintArgs args;
934
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
970
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
935
971
printBlockArgRegion (p, op, region, args);
936
972
}
937
973
@@ -941,7 +977,7 @@ static void printPrivateReductionRegion(
941
977
TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
942
978
ArrayAttr reductionSyms) {
943
979
AllRegionPrintArgs args;
944
- args.privateArgs .emplace (privateVars, privateTypes, privateSyms);
980
+ args.privateArgs .emplace (privateVars, privateTypes, privateSyms, nullptr );
945
981
args.reductionArgs .emplace (reductionVars, reductionTypes, reductionByref,
946
982
reductionSyms);
947
983
printBlockArgRegion (p, op, region, args);
@@ -1656,7 +1692,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
1656
1692
/* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
1657
1693
/* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
1658
1694
clauses.mapVars , clauses.nowait , clauses.privateVars ,
1659
- makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit );
1695
+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1696
+ /* private_maps=*/ nullptr );
1660
1697
}
1661
1698
1662
1699
LogicalResult TargetOp::verify () {
0 commit comments