Skip to content

Commit dfc21ac

Browse files
authored
[flang][cuda] Convert global allocation for pinned variable (#106807)
ALLOCATE/DEALLOCATE statements for module allocatable variable with the pinned attribute can be lowered to the standard runtime call and do not need further action since these variables will have a unique descriptor that is on the host.
1 parent 334d123 commit dfc21ac

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

flang/lib/Optimizer/Transforms/CufOpConversion.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,25 @@ using namespace Fortran::runtime::cuda;
3333
namespace {
3434

3535
template <typename OpTy>
36-
static bool isBoxGlobal(OpTy op) {
36+
static bool needDoubleDescriptor(OpTy op) {
3737
if (auto declareOp =
3838
mlir::dyn_cast_or_null<fir::DeclareOp>(op.getBox().getDefiningOp())) {
3939
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
40-
declareOp.getMemref().getDefiningOp()))
40+
declareOp.getMemref().getDefiningOp())) {
41+
if (declareOp.getDataAttr() &&
42+
*declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
43+
return false;
4144
return true;
45+
}
4246
} else if (auto declareOp = mlir::dyn_cast_or_null<hlfir::DeclareOp>(
4347
op.getBox().getDefiningOp())) {
4448
if (mlir::isa_and_nonnull<fir::AddrOfOp>(
45-
declareOp.getMemref().getDefiningOp()))
49+
declareOp.getMemref().getDefiningOp())) {
50+
if (declareOp.getDataAttr() &&
51+
*declareOp.getDataAttr() == cuf::DataAttribute::Pinned)
52+
return false;
4653
return true;
54+
}
4755
}
4856
return false;
4957
}
@@ -100,7 +108,7 @@ struct CufAllocateOpConversion
100108

101109
// TODO: Allocation of module variable will need more work as the descriptor
102110
// will be duplicated and needs to be synced after allocation.
103-
if (isBoxGlobal(op))
111+
if (needDoubleDescriptor(op))
104112
return mlir::failure();
105113

106114
// Allocation for local descriptor falls back on the standard runtime
@@ -125,7 +133,7 @@ struct CufDeallocateOpConversion
125133
mlir::PatternRewriter &rewriter) const override {
126134
// TODO: Allocation of module variable will need more work as the descriptor
127135
// will be duplicated and needs to be synced after allocation.
128-
if (isBoxGlobal(op))
136+
if (needDoubleDescriptor(op))
129137
return mlir::failure();
130138

131139
// Deallocation for local descriptor falls back on the standard runtime
@@ -274,9 +282,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase<CufOpConversion> {
274282
return true;
275283
});
276284
target.addDynamicallyLegalOp<cuf::AllocateOp>(
277-
[](::cuf::AllocateOp op) { return isBoxGlobal(op); });
285+
[](::cuf::AllocateOp op) { return needDoubleDescriptor(op); });
278286
target.addDynamicallyLegalOp<cuf::DeallocateOp>(
279-
[](::cuf::DeallocateOp op) { return isBoxGlobal(op); });
287+
[](::cuf::DeallocateOp op) { return needDoubleDescriptor(op); });
280288
target.addLegalDialect<fir::FIROpsDialect>();
281289
patterns.insert<CufAllocOpConversion>(ctx, &*dl, &typeConverter);
282290
patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,

flang/test/Fir/CUDA/cuda-allocate.fir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,31 @@ func.func @_QPsub4() attributes {cuf.proc_attr = #cuf.cuda_proc<device>} {
6868
// CHECK: fir.alloca
6969
// CHECK-NOT: cuf.free
7070

71+
fir.global @_QMglobalsEa_pinned {data_attr = #cuf.cuda<pinned>} : !fir.box<!fir.heap<!fir.array<?xf32>>> {
72+
%0 = fir.zero_bits !fir.heap<!fir.array<?xf32>>
73+
%c0 = arith.constant 0 : index
74+
%1 = fir.shape %c0 : (index) -> !fir.shape<1>
75+
%2 = fir.embox %0(%1) {allocator_idx = 1 : i32} : (!fir.heap<!fir.array<?xf32>>, !fir.shape<1>) -> !fir.box<!fir.heap<!fir.array<?xf32>>>
76+
fir.has_value %2 : !fir.box<!fir.heap<!fir.array<?xf32>>>
7177
}
7278

79+
func.func @_QPsub5() {
80+
%4 = fir.address_of(@_QMglobalsEa_pinned) : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
81+
%5:2 = hlfir.declare %4 {data_attr = #cuf.cuda<pinned>, fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QMglobalsEa_pinned"} : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>, !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>)
82+
%c1 = arith.constant 1 : index
83+
%c10_i32 = arith.constant 10 : i32
84+
%c0_i32 = arith.constant 0 : i32
85+
%6 = fir.convert %5#1 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>) -> !fir.ref<!fir.box<none>>
86+
%7 = fir.convert %c1 : (index) -> i64
87+
%8 = fir.convert %c10_i32 : (i32) -> i64
88+
%9 = fir.call @_FortranAAllocatableSetBounds(%6, %c0_i32, %7, %8) fastmath<contract> : (!fir.ref<!fir.box<none>>, i32, i64, i64) -> none
89+
%10 = cuf.allocate %5#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<pinned>} -> i32
90+
%11 = cuf.deallocate %5#1 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>> {data_attr = #cuf.cuda<pinned>} -> i32
91+
return
92+
}
93+
94+
// CHECK-LABEL: func.func @_QPsub5()
95+
// CHECK: fir.call @_FortranAAllocatableAllocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
96+
// CHECK: fir.call @_FortranAAllocatableDeallocate({{.*}}) : (!fir.ref<!fir.box<none>>, i1, !fir.box<none>, !fir.ref<i8>, i32) -> i32
7397

98+
} // end of module

0 commit comments

Comments
 (0)