Skip to content

Commit 70b9440

Browse files
authored
Revert "[flang][cuda] Specialize entry point for scalar to desc data transfer" (#116458)
Reverts #116457
1 parent 43cb424 commit 70b9440

File tree

6 files changed

+16
-56
lines changed

6 files changed

+16
-56
lines changed

flang/include/flang/Runtime/CUDA/memory.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ void RTDECL(CUFDataTransferPtrDesc)(void *dst, Descriptor *src,
4444
void RTDECL(CUFDataTransferDescDesc)(Descriptor *dst, Descriptor *src,
4545
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
4646

47-
/// Data transfer from a scalar descriptor to a descriptor.
48-
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dst, Descriptor *src,
49-
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);
50-
5147
/// Data transfer from a descriptor to a descriptor.
5248
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dst, Descriptor *src,
5349
unsigned mode, const char *sourceFile = nullptr, int sourceLine = 0);

flang/lib/Optimizer/Transforms/CUFOpConversion.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,8 @@ struct CUFDataTransferOpConversion
563563
// until we have more infrastructure.
564564
mlir::Value src = emboxSrc(rewriter, op, symtab);
565565
mlir::Value dst = emboxDst(rewriter, op, symtab);
566-
mlir::func::FuncOp func =
567-
fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
568-
loc, builder);
566+
mlir::func::FuncOp func = fir::runtime::getRuntimeFunc<mkRTKey(
567+
CUFDataTransferDescDescNoRealloc)>(loc, builder);
569568
auto fTy = func.getFunctionType();
570569
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
571570
mlir::Value sourceLine =
@@ -649,9 +648,6 @@ struct CUFDataTransferOpConversion
649648
mlir::Value src = op.getSrc();
650649
if (!mlir::isa<fir::BaseBoxType>(srcTy)) {
651650
src = emboxSrc(rewriter, op, symtab);
652-
if (fir::isa_trivial(srcTy))
653-
func = fir::runtime::getRuntimeFunc<mkRTKey(CUFDataTransferCstDesc)>(
654-
loc, builder);
655651
}
656652
auto materializeBoxIfNeeded = [&](mlir::Value val) -> mlir::Value {
657653
if (mlir::isa<fir::EmboxOp>(val.getDefiningOp())) {

flang/runtime/CUDA/memory.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Runtime/CUDA/memory.h"
10-
#include "../assign-impl.h"
1110
#include "../terminator.h"
1211
#include "flang/Runtime/CUDA/common.h"
1312
#include "flang/Runtime/CUDA/descriptor.h"
@@ -121,24 +120,6 @@ void RTDECL(CUFDataTransferDescDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
121120
*dstDesc, *srcDesc, terminator, MaybeReallocate, memmoveFct);
122121
}
123122

124-
void RTDECL(CUFDataTransferCstDesc)(Descriptor *dstDesc, Descriptor *srcDesc,
125-
unsigned mode, const char *sourceFile, int sourceLine) {
126-
MemmoveFct memmoveFct;
127-
Terminator terminator{sourceFile, sourceLine};
128-
if (mode == kHostToDevice) {
129-
memmoveFct = &MemmoveHostToDevice;
130-
} else if (mode == kDeviceToHost) {
131-
memmoveFct = &MemmoveDeviceToHost;
132-
} else if (mode == kDeviceToDevice) {
133-
memmoveFct = &MemmoveDeviceToDevice;
134-
} else {
135-
terminator.Crash("host to host copy not supported");
136-
}
137-
138-
Fortran::runtime::DoFromSourceAssign(
139-
*dstDesc, *srcDesc, terminator, memmoveFct);
140-
}
141-
142123
void RTDECL(CUFDataTransferDescDescNoRealloc)(Descriptor *dstDesc,
143124
Descriptor *srcDesc, unsigned mode, const char *sourceFile,
144125
int sourceLine) {

flang/runtime/assign-impl.h

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,16 @@
99
#ifndef FORTRAN_RUNTIME_ASSIGN_IMPL_H_
1010
#define FORTRAN_RUNTIME_ASSIGN_IMPL_H_
1111

12-
#include "flang/Runtime/freestanding-tools.h"
13-
1412
namespace Fortran::runtime {
1513
class Descriptor;
1614
class Terminator;
1715

18-
using MemmoveFct = void *(*)(void *, const void *, std::size_t);
19-
2016
// Assign one object to another via allocate statement from source specifier.
2117
// Note that if allocate object and source expression have the same rank, the
2218
// value of the allocate object becomes the value provided; otherwise the value
2319
// of each element of allocate object becomes the value provided (9.7.1.2(7)).
24-
#ifdef RT_DEVICE_COMPILATION
25-
static RT_API_ATTRS void *MemmoveWrapper(
26-
void *dest, const void *src, std::size_t count) {
27-
return Fortran::runtime::memmove(dest, src, count);
28-
}
29-
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
30-
Terminator &, MemmoveFct memmoveFct = &MemmoveWrapper);
31-
#else
32-
RT_API_ATTRS void DoFromSourceAssign(Descriptor &, const Descriptor &,
33-
Terminator &, MemmoveFct memmoveFct = &Fortran::runtime::memmove);
34-
#endif
20+
RT_API_ATTRS void DoFromSourceAssign(
21+
Descriptor &, const Descriptor &, Terminator &);
3522

3623
} // namespace Fortran::runtime
3724
#endif // FORTRAN_RUNTIME_ASSIGN_IMPL_H_

flang/runtime/assign.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,8 @@ RT_API_ATTRS void Assign(Descriptor &to, const Descriptor &from,
509509

510510
RT_OFFLOAD_API_GROUP_BEGIN
511511

512-
RT_API_ATTRS void DoFromSourceAssign(Descriptor &alloc,
513-
const Descriptor &source, Terminator &terminator, MemmoveFct memmoveFct) {
512+
RT_API_ATTRS void DoFromSourceAssign(
513+
Descriptor &alloc, const Descriptor &source, Terminator &terminator) {
514514
if (alloc.rank() > 0 && source.rank() == 0) {
515515
// The value of each element of allocate object becomes the value of source.
516516
DescriptorAddendum *allocAddendum{alloc.Addendum()};
@@ -523,17 +523,17 @@ RT_API_ATTRS void DoFromSourceAssign(Descriptor &alloc,
523523
alloc.IncrementSubscripts(allocAt)) {
524524
Descriptor allocElement{*Descriptor::Create(*allocDerived,
525525
reinterpret_cast<void *>(alloc.Element<char>(allocAt)), 0)};
526-
Assign(allocElement, source, terminator, NoAssignFlags, memmoveFct);
526+
Assign(allocElement, source, terminator, NoAssignFlags);
527527
}
528528
} else { // intrinsic type
529529
for (std::size_t n{alloc.Elements()}; n-- > 0;
530530
alloc.IncrementSubscripts(allocAt)) {
531-
memmoveFct(alloc.Element<char>(allocAt), source.raw().base_addr,
532-
alloc.ElementBytes());
531+
Fortran::runtime::memmove(alloc.Element<char>(allocAt),
532+
source.raw().base_addr, alloc.ElementBytes());
533533
}
534534
}
535535
} else {
536-
Assign(alloc, source, terminator, NoAssignFlags, memmoveFct);
536+
Assign(alloc, source, terminator, NoAssignFlags);
537537
}
538538
}
539539

flang/test/Fir/CUDA/cuda-data-transfer.fir

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ func.func @_QPsub2() {
3838
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
3939
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
4040
// CHECK: %[[TEMP_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
41-
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
41+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[TEMP_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
4242

4343
func.func @_QPsub3() {
4444
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub3Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -58,7 +58,7 @@ func.func @_QPsub3() {
5858
// CHECK: fir.store %[[EMBOX]] to %[[TEMP_BOX]] : !fir.ref<!fir.box<i32>>
5959
// CHECK: %[[ADEV_BOX:.*]] = fir.convert %[[ADEV]]#0 : (!fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>) -> !fir.ref<!fir.box<none>>
6060
// CHECK: %[[V_CONV:.*]] = fir.convert %[[TEMP_BOX]] : (!fir.ref<!fir.box<i32>>) -> !fir.ref<!fir.box<none>>
61-
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
61+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%[[ADEV_BOX]], %[[V_CONV]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
6262

6363
func.func @_QPsub4() {
6464
%0 = cuf.alloc !fir.box<!fir.heap<!fir.array<?xi32>>> {bindc_name = "adev", data_attr = #cuf.cuda<device>, uniq_name = "_QFsub4Eadev"} -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
@@ -297,7 +297,7 @@ func.func @_QPscalar_to_array() {
297297
}
298298

299299
// CHECK-LABEL: func.func @_QPscalar_to_array()
300-
// CHECK: _FortranACUFDataTransferCstDesc
300+
// CHECK: _FortranACUFDataTransferDescDescNoRealloc
301301

302302
func.func @_QPtest_type() {
303303
%0 = cuf.alloc !fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}> {bindc_name = "a", data_attr = #cuf.cuda<device>, uniq_name = "_QFtest_typeEa"} -> !fir.ref<!fir.type<_QMbarTcmplx{id:i32,c:complex<f32>}>>
@@ -344,7 +344,7 @@ func.func @_QPshape_shift() {
344344
}
345345

346346
// CHECK-LABEL: func.func @_QPshape_shift()
347-
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
347+
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
348348

349349
func.func @_QPshape_shift2() {
350350
%c11 = arith.constant 11 : index
@@ -383,7 +383,7 @@ func.func @_QPdevice_addr_conv() {
383383
// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
384384
// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
385385
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
386-
// CHECK: fir.call @_FortranACUFDataTransferCstDesc
386+
// CHECK: fir.call @_FortranACUFDataTransferDescDescNoRealloc
387387

388388
func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
389389
%c1 = arith.constant 1 : index
@@ -464,6 +464,6 @@ func.func @_QPlogical_cst() {
464464
// CHECK: %[[EMBOX:.*]] = fir.embox %[[CONST]] : (!fir.ref<!fir.logical<4>>) -> !fir.box<!fir.logical<4>>
465465
// CHECK: fir.store %[[EMBOX]] to %[[DESC]] : !fir.ref<!fir.box<!fir.logical<4>>>
466466
// CHECK: %[[BOX_NONE:.*]] = fir.convert %[[DESC]] : (!fir.ref<!fir.box<!fir.logical<4>>>) -> !fir.ref<!fir.box<none>>
467-
// CHECK: fir.call @_FortranACUFDataTransferCstDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
467+
// CHECK: fir.call @_FortranACUFDataTransferDescDesc(%{{.*}}, %[[BOX_NONE]], %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.ref<!fir.box<none>>, i32, !fir.ref<i8>, i32) -> none
468468

469469
} // end of module

0 commit comments

Comments
 (0)