Skip to content

Commit 78145a6

Browse files
authored
[flang][cuda] Lower attribute for procedure (#81336)
This PR adds a new attribute to represent the CUDA attribute attached to procedure. This attribute is attached to the func.func operation during lowering. Other procedures information such as `launch_bounds` and `cluster_dims` will be added separately.
1 parent 0144011 commit 78145a6

File tree

5 files changed

+106
-18
lines changed

5 files changed

+106
-18
lines changed

flang/include/flang/Optimizer/Dialect/FIRAttr.td

+37-17
Original file line numberDiff line numberDiff line change
@@ -58,19 +58,34 @@ def fir_FortranVariableFlagsAttr : fir_Attr<"FortranVariableFlags"> {
5858
"::fir::FortranVariableFlagsAttr::get($_builder.getContext(), $0)";
5959
}
6060

61-
def CUDAconstant : I32EnumAttrCase<"Constant", 0, "constant">;
62-
def CUDAdevice : I32EnumAttrCase<"Device", 1, "device">;
63-
def CUDAmanaged : I32EnumAttrCase<"Managed", 2, "managed">;
64-
def CUDApinned : I32EnumAttrCase<"Pinned", 3, "pinned">;
65-
def CUDAshared : I32EnumAttrCase<"Shared", 4, "shared">;
66-
def CUDAunified : I32EnumAttrCase<"Unified", 5, "unified">;
67-
// Texture is omitted since it is obsolete and rejected by semantic.
61+
def fir_BoxFieldAttr : I32EnumAttr<
62+
"BoxFieldAttr", "",
63+
[
64+
I32EnumAttrCase<"base_addr", 0>,
65+
I32EnumAttrCase<"derived_type", 1>
66+
]> {
67+
let cppNamespace = "fir";
68+
}
69+
70+
// mlir::SideEffects::Resource for modelling operations which add debugging information
71+
def DebuggingResource : Resource<"::fir::DebuggingResource">;
72+
73+
//===----------------------------------------------------------------------===//
74+
// CUDA Fortran specific attributes
75+
//===----------------------------------------------------------------------===//
6876

6977
def fir_CUDADataAttribute : I32EnumAttr<
7078
"CUDADataAttribute",
7179
"CUDA Fortran variable attributes",
72-
[CUDAconstant, CUDAdevice, CUDAmanaged, CUDApinned, CUDAshared,
73-
CUDAunified]> {
80+
[
81+
I32EnumAttrCase<"Constant", 0, "constant">,
82+
I32EnumAttrCase<"Device", 1, "device">,
83+
I32EnumAttrCase<"Managed", 2, "managed">,
84+
I32EnumAttrCase<"Pinned", 3, "pinned">,
85+
I32EnumAttrCase<"Shared", 4, "shared">,
86+
I32EnumAttrCase<"Unified", 5, "unified">,
87+
// Texture is omitted since it is obsolete and rejected by semantic.
88+
]> {
7489
let genSpecializedAttr = 0;
7590
let cppNamespace = "::fir";
7691
}
@@ -80,17 +95,22 @@ def fir_CUDADataAttributeAttr :
8095
let assemblyFormat = [{ ```<` $value `>` }];
8196
}
8297

83-
def fir_BoxFieldAttr : I32EnumAttr<
84-
"BoxFieldAttr", "",
98+
def fir_CUDAProcAttribute : I32EnumAttr<
99+
"CUDAProcAttribute", "CUDA Fortran procedure attributes",
85100
[
86-
I32EnumAttrCase<"base_addr", 0>,
87-
I32EnumAttrCase<"derived_type", 1>
101+
I32EnumAttrCase<"Host", 0, "host">,
102+
I32EnumAttrCase<"Device", 1, "device">,
103+
I32EnumAttrCase<"HostDevice", 2, "host_device">,
104+
I32EnumAttrCase<"Global", 3, "global">,
105+
I32EnumAttrCase<"GridGlobal", 4, "grid_global">,
88106
]> {
89-
let cppNamespace = "fir";
107+
let genSpecializedAttr = 0;
108+
let cppNamespace = "::fir";
90109
}
91110

92-
93-
// mlir::SideEffects::Resource for modelling operations which add debugging information
94-
def DebuggingResource : Resource<"::fir::DebuggingResource">;
111+
def fir_CUDAProcAttributeAttr :
112+
EnumAttr<fir_Dialect, fir_CUDAProcAttribute, "cuda_proc"> {
113+
let assemblyFormat = [{ ```<` $value `>` }];
114+
}
95115

96116
#endif // FIR_DIALECT_FIR_ATTRS

flang/include/flang/Optimizer/Support/Utils.h

+27
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,33 @@ getCUDADataAttribute(mlir::MLIRContext *mlirContext,
303303
return {};
304304
}
305305

306+
inline fir::CUDAProcAttributeAttr getCUDAProcAttribute(
307+
mlir::MLIRContext *mlirContext,
308+
std::optional<Fortran::common::CUDASubprogramAttrs> cudaAttr) {
309+
if (cudaAttr) {
310+
fir::CUDAProcAttribute attr;
311+
switch (*cudaAttr) {
312+
case Fortran::common::CUDASubprogramAttrs::Host:
313+
attr = fir::CUDAProcAttribute::Host;
314+
break;
315+
case Fortran::common::CUDASubprogramAttrs::Device:
316+
attr = fir::CUDAProcAttribute::Device;
317+
break;
318+
case Fortran::common::CUDASubprogramAttrs::HostDevice:
319+
attr = fir::CUDAProcAttribute::HostDevice;
320+
break;
321+
case Fortran::common::CUDASubprogramAttrs::Global:
322+
attr = fir::CUDAProcAttribute::Global;
323+
break;
324+
case Fortran::common::CUDASubprogramAttrs::Grid_Global:
325+
attr = fir::CUDAProcAttribute::GridGlobal;
326+
break;
327+
}
328+
return fir::CUDAProcAttributeAttr::get(mlirContext, attr);
329+
}
330+
return {};
331+
}
332+
306333
} // namespace fir
307334

308335
#endif // FORTRAN_OPTIMIZER_SUPPORT_UTILS_H

flang/lib/Lower/CallInterface.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "flang/Lower/CallInterface.h"
10+
#include "flang/Common/Fortran.h"
1011
#include "flang/Evaluate/fold.h"
1112
#include "flang/Lower/Bridge.h"
1213
#include "flang/Lower/Mangler.h"
@@ -559,6 +560,12 @@ void Fortran::lower::CallInterface<T>::declare() {
559560
func.setArgAttrs(placeHolder.index(), placeHolder.value().attributes);
560561
side().setFuncAttrs(func);
561562
}
563+
if (characteristic && characteristic->cudaSubprogramAttrs) {
564+
func.getOperation()->setAttr(
565+
fir::getCUDAAttrName(),
566+
fir::getCUDAProcAttribute(func.getContext(),
567+
*characteristic->cudaSubprogramAttrs));
568+
}
562569
}
563570
}
564571

flang/lib/Optimizer/Dialect/FIRAttr.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -298,5 +298,5 @@ void fir::printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
298298
void FIROpsDialect::registerAttributes() {
299299
addAttributes<ClosedIntervalAttr, ExactTypeAttr, FortranVariableFlagsAttr,
300300
LowerBoundAttr, PointIntervalAttr, RealAttr, SubclassAttr,
301-
UpperBoundAttr, CUDADataAttributeAttr>();
301+
UpperBoundAttr, CUDADataAttributeAttr, CUDAProcAttributeAttr>();
302302
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
! RUN: bbc -emit-hlfir -fcuda %s -o - | fir-opt -convert-hlfir-to-fir | FileCheck %s
3+
4+
! Test lowering of CUDA attribute on procedures.
5+
6+
attributes(host) subroutine sub_host(); end
7+
! CHECK: func.func @_QPsub_host() attributes {fir.cuda_attr = #fir.cuda_proc<host>}
8+
9+
attributes(device) subroutine sub_device(); end
10+
! CHECK: func.func @_QPsub_device() attributes {fir.cuda_attr = #fir.cuda_proc<device>}
11+
12+
attributes(host) attributes(device) subroutine sub_host_device; end
13+
! CHECK: func.func @_QPsub_host_device() attributes {fir.cuda_attr = #fir.cuda_proc<host_device>}
14+
15+
attributes(device) attributes(host) subroutine sub_device_host; end
16+
! CHECK: func.func @_QPsub_device_host() attributes {fir.cuda_attr = #fir.cuda_proc<host_device>}
17+
18+
attributes(global) subroutine sub_global(); end
19+
! CHECK: func.func @_QPsub_global() attributes {fir.cuda_attr = #fir.cuda_proc<global>}
20+
21+
attributes(grid_global) subroutine sub_grid_global(); end
22+
! CHECK: func.func @_QPsub_grid_global() attributes {fir.cuda_attr = #fir.cuda_proc<grid_global>}
23+
24+
attributes(host) integer function fct_host(); end
25+
! CHECK: func.func @_QPfct_host() -> i32 attributes {fir.cuda_attr = #fir.cuda_proc<host>}
26+
27+
attributes(device) integer function fct_device(); end
28+
! CHECK: func.func @_QPfct_device() -> i32 attributes {fir.cuda_attr = #fir.cuda_proc<device>}
29+
30+
attributes(host) attributes(device) integer function fct_host_device; end
31+
! CHECK: func.func @_QPfct_host_device() -> i32 attributes {fir.cuda_attr = #fir.cuda_proc<host_device>}
32+
33+
attributes(device) attributes(host) integer function fct_device_host; end
34+
! CHECK: func.func @_QPfct_device_host() -> i32 attributes {fir.cuda_attr = #fir.cuda_proc<host_device>}

0 commit comments

Comments
 (0)