15
15
#include " mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
16
16
#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
17
17
#include " mlir/Conversion/LLVMCommon/Pattern.h"
18
+ #include " mlir/Dialect/Arith/IR/Arith.h"
18
19
#include " mlir/Dialect/DLTI/DLTI.h"
19
20
#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
21
+ #include " mlir/Dialect/LLVMIR/LLVMTypes.h"
20
22
#include " mlir/Dialect/MPI/IR/MPI.h"
21
23
#include " mlir/Transforms/DialectConversion.h"
22
24
#include < memory>
@@ -57,9 +59,14 @@ std::pair<Value, Value> getRawPtrAndSize(const Location loc,
57
59
loc, rewriter.getI64Type (), memRef, 2 );
58
60
Value resPtr =
59
61
rewriter.create <LLVM::GEPOp>(loc, ptrType, elType, dataPtr, offset);
60
- Value size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
61
- ArrayRef<int64_t >{3 , 0 });
62
- size = rewriter.create <LLVM::TruncOp>(loc, rewriter.getI32Type (), size);
62
+ Value size;
63
+ if (cast<LLVM::LLVMStructType>(memRef.getType ()).getBody ().size () > 3 ) {
64
+ size = rewriter.create <LLVM::ExtractValueOp>(loc, memRef,
65
+ ArrayRef<int64_t >{3 , 0 });
66
+ size = rewriter.create <LLVM::TruncOp>(loc, rewriter.getI32Type (), size);
67
+ } else {
68
+ size = rewriter.create <arith::ConstantIntOp>(loc, 1 , 32 );
69
+ }
63
70
return {resPtr, size};
64
71
}
65
72
@@ -97,6 +104,9 @@ class MPIImplTraits {
97
104
// / Get the MPI_STATUS_IGNORE value (typically a pointer type).
98
105
virtual intptr_t getStatusIgnore () = 0;
99
106
107
+ // / Get the MPI_IN_PLACE value (void *).
108
+ virtual void *getInPlace () = 0;
109
+
100
110
// / Gets or creates an MPI datatype as a value which corresponds to the given
101
111
// / type.
102
112
virtual Value getDataType (const Location loc,
@@ -158,6 +168,8 @@ class MPICHImplTraits : public MPIImplTraits {
158
168
159
169
intptr_t getStatusIgnore () override { return 1 ; }
160
170
171
+ void *getInPlace () override { return reinterpret_cast <void *>(-1 ); }
172
+
161
173
Value getDataType (const Location loc, ConversionPatternRewriter &rewriter,
162
174
Type type) override {
163
175
int32_t mtype = 0 ;
@@ -283,6 +295,8 @@ class OMPIImplTraits : public MPIImplTraits {
283
295
284
296
intptr_t getStatusIgnore () override { return 0 ; }
285
297
298
+ void *getInPlace () override { return reinterpret_cast <void *>(1 ); }
299
+
286
300
Value getDataType (const Location loc, ConversionPatternRewriter &rewriter,
287
301
Type type) override {
288
302
StringRef mtype;
@@ -516,7 +530,8 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
516
530
outPtr.getRes ()});
517
531
518
532
// load the communicator into a register
519
- auto res = rewriter.create <LLVM::LoadOp>(loc, i32, outPtr.getResult ());
533
+ Value res = rewriter.create <LLVM::LoadOp>(loc, i32, outPtr.getResult ());
534
+ res = rewriter.create <LLVM::SExtOp>(loc, rewriter.getI64Type (), res);
520
535
521
536
// if retval is checked, replace uses of retval with the results from the
522
537
// call op
@@ -525,7 +540,7 @@ struct CommSplitOpLowering : public ConvertOpToLLVMPattern<mpi::CommSplitOp> {
525
540
replacements.push_back (callOp.getResult ());
526
541
527
542
// replace op
528
- replacements.push_back (res. getRes () );
543
+ replacements.push_back (res);
529
544
rewriter.replaceOp (op, replacements);
530
545
531
546
return success ();
@@ -709,6 +724,7 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
709
724
Location loc = op.getLoc ();
710
725
MLIRContext *context = rewriter.getContext ();
711
726
Type i32 = rewriter.getI32Type ();
727
+ Type i64 = rewriter.getI64Type ();
712
728
Type elemType = op.getSendbuf ().getType ().getElementType ();
713
729
714
730
// ptrType `!llvm.ptr`
@@ -719,6 +735,14 @@ struct AllReduceOpLowering : public ConvertOpToLLVMPattern<mpi::AllReduceOp> {
719
735
getRawPtrAndSize (loc, rewriter, adaptor.getSendbuf (), elemType);
720
736
auto [recvPtr, recvSize] =
721
737
getRawPtrAndSize (loc, rewriter, adaptor.getRecvbuf (), elemType);
738
+
739
+ // If input and output are the same, request in-place operation.
740
+ if (adaptor.getSendbuf () == adaptor.getRecvbuf ()) {
741
+ sendPtr = rewriter.create <LLVM::ConstantOp>(
742
+ loc, i64, reinterpret_cast <int64_t >(mpiTraits->getInPlace ()));
743
+ sendPtr = rewriter.create <LLVM::IntToPtrOp>(loc, ptrType, sendPtr);
744
+ }
745
+
722
746
Value dataType = mpiTraits->getDataType (loc, rewriter, elemType);
723
747
Value mpiOp = mpiTraits->getMPIOp (loc, rewriter, op.getOp ());
724
748
Value commWorld = mpiTraits->castComm (loc, rewriter, adaptor.getComm ());
0 commit comments