Skip to content

Commit 7b473df

Browse files
[flang][acc] Implement type categorization for FIR types (#126964)
The OpenACC type interfaces have been updated to require that a type self-identify which type category it belongs to. Ensure that FIR types are able to provide this self identification. In addition to implementing the new API, the PointerLikeType interface attachment was moved to FIROpenACCSupport library like MappableType to ensure all type interfaces and their implementation are now in the same spot.
1 parent 9456e7f commit 7b473df

File tree

10 files changed

+255
-31
lines changed

10 files changed

+255
-31
lines changed

flang/include/flang/Optimizer/OpenACC/FIROpenACCTypeInterfaces.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,19 @@
1818

1919
namespace fir::acc {
2020

21+
template <typename T>
22+
struct OpenACCPointerLikeModel
23+
: public mlir::acc::PointerLikeType::ExternalModel<
24+
OpenACCPointerLikeModel<T>, T> {
25+
mlir::Type getElementType(mlir::Type pointer) const {
26+
return mlir::cast<T>(pointer).getElementType();
27+
}
28+
mlir::acc::VariableTypeCategory
29+
getPointeeTypeCategory(mlir::Type pointer,
30+
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
31+
mlir::Type varType) const;
32+
};
33+
2134
template <typename T>
2235
struct OpenACCMappableModel
2336
: public mlir::acc::MappableType::ExternalModel<OpenACCMappableModel<T>,
@@ -36,6 +49,9 @@ struct OpenACCMappableModel
3649
llvm::SmallVector<mlir::Value>
3750
generateAccBounds(mlir::Type type, mlir::Value var,
3851
mlir::OpBuilder &builder) const;
52+
53+
mlir::acc::VariableTypeCategory getTypeCategory(mlir::Type type,
54+
mlir::Value var) const;
3955
};
4056

4157
} // namespace fir::acc

flang/include/flang/Tools/PointerModels.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#ifndef FORTRAN_TOOLS_POINTER_MODELS_H
1010
#define FORTRAN_TOOLS_POINTER_MODELS_H
1111

12-
#include "mlir/Dialect/OpenACC/OpenACC.h"
1312
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1413

1514
/// models for FIR pointer like types that already provide a `getElementType`
@@ -24,13 +23,4 @@ struct OpenMPPointerLikeModel
2423
}
2524
};
2625

27-
template <typename T>
28-
struct OpenACCPointerLikeModel
29-
: public mlir::acc::PointerLikeType::ExternalModel<
30-
OpenACCPointerLikeModel<T>, T> {
31-
mlir::Type getElementType(mlir::Type pointer) const {
32-
return mlir::cast<T>(pointer).getElementType();
33-
}
34-
};
35-
3626
#endif // FORTRAN_TOOLS_POINTER_MODELS_H

flang/lib/Frontend/FrontendActions.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,12 +261,12 @@ bool CodeGenAction::beginSourceFileAction() {
261261
}
262262

263263
// Load the MLIR dialects required by Flang
264-
mlir::DialectRegistry registry;
265-
mlirCtx = std::make_unique<mlir::MLIRContext>(registry);
266-
fir::support::registerNonCodegenDialects(registry);
267-
fir::support::loadNonCodegenDialects(*mlirCtx);
264+
mlirCtx = std::make_unique<mlir::MLIRContext>();
268265
fir::support::loadDialects(*mlirCtx);
269266
fir::support::registerLLVMTranslation(*mlirCtx);
267+
mlir::DialectRegistry registry;
268+
fir::acc::registerOpenACCExtensions(registry);
269+
mlirCtx->appendDialectRegistry(registry);
270270

271271
const llvm::TargetMachine &targetMachine = ci.getTargetMachine();
272272

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,23 +1370,12 @@ void FIROpsDialect::registerTypes() {
13701370
TypeDescType, fir::VectorType, fir::DummyScopeType>();
13711371
fir::ReferenceType::attachInterface<
13721372
OpenMPPointerLikeModel<fir::ReferenceType>>(*getContext());
1373-
fir::ReferenceType::attachInterface<
1374-
OpenACCPointerLikeModel<fir::ReferenceType>>(*getContext());
1375-
13761373
fir::PointerType::attachInterface<OpenMPPointerLikeModel<fir::PointerType>>(
13771374
*getContext());
1378-
fir::PointerType::attachInterface<OpenACCPointerLikeModel<fir::PointerType>>(
1379-
*getContext());
1380-
13811375
fir::HeapType::attachInterface<OpenMPPointerLikeModel<fir::HeapType>>(
13821376
*getContext());
1383-
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
1384-
*getContext());
1385-
13861377
fir::LLVMPointerType::attachInterface<
13871378
OpenMPPointerLikeModel<fir::LLVMPointerType>>(*getContext());
1388-
fir::LLVMPointerType::attachInterface<
1389-
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*getContext());
13901379
}
13911380

13921381
std::optional<std::pair<uint64_t, unsigned short>>

flang/lib/Optimizer/OpenACC/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ add_flang_library(FIROpenACCSupport
66

77
DEPENDS
88
FIRBuilder
9+
FIRCodeGen
910
FIRDialect
1011
FIRDialectSupport
1112
FIRSupport
@@ -14,6 +15,7 @@ add_flang_library(FIROpenACCSupport
1415

1516
LINK_LIBS
1617
FIRBuilder
18+
FIRCodeGen
1719
FIRDialect
1820
FIRDialectSupport
1921
FIRSupport

flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "flang/Optimizer/Builder/DirectivesCommon.h"
1616
#include "flang/Optimizer/Builder/FIRBuilder.h"
1717
#include "flang/Optimizer/Builder/HLFIRTools.h"
18+
#include "flang/Optimizer/CodeGen/CGOps.h"
1819
#include "flang/Optimizer/Dialect/FIROps.h"
1920
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
2021
#include "flang/Optimizer/Dialect/FIRType.h"
@@ -24,6 +25,7 @@
2425
#include "mlir/Dialect/OpenACC/OpenACC.h"
2526
#include "mlir/IR/BuiltinOps.h"
2627
#include "mlir/Support/LLVM.h"
28+
#include "llvm/ADT/TypeSwitch.h"
2729

2830
namespace fir::acc {
2931

@@ -224,4 +226,145 @@ OpenACCMappableModel<fir::BaseBoxType>::generateAccBounds(
224226
return {};
225227
}
226228

229+
static bool isScalarLike(mlir::Type type) {
230+
return fir::isa_trivial(type) || fir::isa_ref_type(type);
231+
}
232+
233+
static bool isArrayLike(mlir::Type type) {
234+
return mlir::isa<fir::SequenceType>(type);
235+
}
236+
237+
static bool isCompositeLike(mlir::Type type) {
238+
return mlir::isa<fir::RecordType, fir::ClassType, mlir::TupleType>(type);
239+
}
240+
241+
template <>
242+
mlir::acc::VariableTypeCategory
243+
OpenACCMappableModel<fir::SequenceType>::getTypeCategory(
244+
mlir::Type type, mlir::Value var) const {
245+
return mlir::acc::VariableTypeCategory::array;
246+
}
247+
248+
template <>
249+
mlir::acc::VariableTypeCategory
250+
OpenACCMappableModel<fir::BaseBoxType>::getTypeCategory(mlir::Type type,
251+
mlir::Value var) const {
252+
253+
mlir::Type eleTy = fir::dyn_cast_ptrOrBoxEleTy(type);
254+
255+
// If the type enclosed by the box is a mappable type, then have it
256+
// provide the type category.
257+
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
258+
return mappableTy.getTypeCategory(var);
259+
260+
// For all arrays, despite whether they are allocatable, pointer, assumed,
261+
// etc, we'd like to categorize them as "array".
262+
if (isArrayLike(eleTy))
263+
return mlir::acc::VariableTypeCategory::array;
264+
265+
// We got here because we don't have an array nor a mappable type. At this
266+
// point, we know we have a type that fits the "aggregate" definition since it
267+
// is a type with a descriptor. Try to refine it by checking if it matches the
268+
// "composite" definition.
269+
if (isCompositeLike(eleTy))
270+
return mlir::acc::VariableTypeCategory::composite;
271+
272+
// Even if we have a scalar type - simply because it is wrapped in a box
273+
// we want to categorize it as "nonscalar". Anything else would've been
274+
// non-scalar anyway.
275+
return mlir::acc::VariableTypeCategory::nonscalar;
276+
}
277+
278+
static mlir::TypedValue<mlir::acc::PointerLikeType>
279+
getBaseRef(mlir::TypedValue<mlir::acc::PointerLikeType> varPtr) {
280+
// If there is no defining op - the unwrapped reference is the base one.
281+
mlir::Operation *op = varPtr.getDefiningOp();
282+
if (!op)
283+
return varPtr;
284+
285+
// Look to find if this value originates from an interior pointer
286+
// calculation op.
287+
mlir::Value baseRef =
288+
llvm::TypeSwitch<mlir::Operation *, mlir::Value>(op)
289+
.Case<hlfir::DesignateOp>([&](auto op) {
290+
// Get the base object.
291+
return op.getMemref();
292+
})
293+
.Case<fir::ArrayCoorOp, fir::cg::XArrayCoorOp>([&](auto op) {
294+
// Get the base array on which the coordinate is being applied.
295+
return op.getMemref();
296+
})
297+
.Case<fir::CoordinateOp>([&](auto op) {
298+
// For coordinate operation which is applied on derived type
299+
// object, get the base object.
300+
return op.getRef();
301+
})
302+
.Default([&](mlir::Operation *) { return varPtr; });
303+
304+
return mlir::cast<mlir::TypedValue<mlir::acc::PointerLikeType>>(baseRef);
305+
}
306+
307+
static mlir::acc::VariableTypeCategory
308+
categorizePointee(mlir::Type pointer,
309+
mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
310+
mlir::Type varType) {
311+
// FIR uses operations to compute interior pointers.
312+
// So for example, an array element or composite field access to a float
313+
// value would both be represented as !fir.ref<f32>. We do not want to treat
314+
// such a reference as a scalar. Thus unwrap interior pointer calculations.
315+
auto baseRef = getBaseRef(varPtr);
316+
mlir::Type eleTy = baseRef.getType().getElementType();
317+
318+
if (auto mappableTy = mlir::dyn_cast<mlir::acc::MappableType>(eleTy))
319+
return mappableTy.getTypeCategory(varPtr);
320+
321+
if (isScalarLike(eleTy))
322+
return mlir::acc::VariableTypeCategory::scalar;
323+
if (isArrayLike(eleTy))
324+
return mlir::acc::VariableTypeCategory::array;
325+
if (isCompositeLike(eleTy))
326+
return mlir::acc::VariableTypeCategory::composite;
327+
if (mlir::isa<fir::CharacterType, mlir::FunctionType>(eleTy))
328+
return mlir::acc::VariableTypeCategory::nonscalar;
329+
// "pointers" - in the sense of raw address point-of-view, are considered
330+
// scalars. However
331+
if (mlir::isa<fir::LLVMPointerType>(eleTy))
332+
return mlir::acc::VariableTypeCategory::scalar;
333+
334+
// Without further checking, this type cannot be categorized.
335+
return mlir::acc::VariableTypeCategory::uncategorized;
336+
}
337+
338+
template <>
339+
mlir::acc::VariableTypeCategory
340+
OpenACCPointerLikeModel<fir::ReferenceType>::getPointeeTypeCategory(
341+
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
342+
mlir::Type varType) const {
343+
return categorizePointee(pointer, varPtr, varType);
344+
}
345+
346+
template <>
347+
mlir::acc::VariableTypeCategory
348+
OpenACCPointerLikeModel<fir::PointerType>::getPointeeTypeCategory(
349+
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
350+
mlir::Type varType) const {
351+
return categorizePointee(pointer, varPtr, varType);
352+
}
353+
354+
template <>
355+
mlir::acc::VariableTypeCategory
356+
OpenACCPointerLikeModel<fir::HeapType>::getPointeeTypeCategory(
357+
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
358+
mlir::Type varType) const {
359+
return categorizePointee(pointer, varPtr, varType);
360+
}
361+
362+
template <>
363+
mlir::acc::VariableTypeCategory
364+
OpenACCPointerLikeModel<fir::LLVMPointerType>::getPointeeTypeCategory(
365+
mlir::Type pointer, mlir::TypedValue<mlir::acc::PointerLikeType> varPtr,
366+
mlir::Type varType) const {
367+
return categorizePointee(pointer, varPtr, varType);
368+
}
369+
227370
} // namespace fir::acc

flang/lib/Optimizer/OpenACC/RegisterOpenACCExtensions.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ void registerOpenACCExtensions(mlir::DialectRegistry &registry) {
2222
fir::SequenceType::attachInterface<OpenACCMappableModel<fir::SequenceType>>(
2323
*ctx);
2424
fir::BoxType::attachInterface<OpenACCMappableModel<fir::BaseBoxType>>(*ctx);
25+
26+
fir::ReferenceType::attachInterface<
27+
OpenACCPointerLikeModel<fir::ReferenceType>>(*ctx);
28+
fir::PointerType::attachInterface<
29+
OpenACCPointerLikeModel<fir::PointerType>>(*ctx);
30+
fir::HeapType::attachInterface<OpenACCPointerLikeModel<fir::HeapType>>(
31+
*ctx);
32+
fir::LLVMPointerType::attachInterface<
33+
OpenACCPointerLikeModel<fir::LLVMPointerType>>(*ctx);
2534
});
2635
}
2736

flang/test/Fir/OpenACC/openacc-mappable.fir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,
1919

2020
// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
2121
// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
22+
// CHECK: Type category: array
2223
// CHECK: Size: 40
2324
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
2425
// CHECK: Mappable: !fir.array<10xf32>
26+
// CHECK: Type category: array
2527
// CHECK: Size: 40
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
! RUN: bbc -fopenacc -emit-hlfir %s -o - | fir-opt -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' --mlir-disable-threading 2>&1 | FileCheck %s
2+
3+
program main
4+
real :: scalar
5+
real, allocatable :: scalaralloc
6+
type tt
7+
real :: field
8+
real :: fieldarray(10)
9+
end type tt
10+
type(tt) :: ttvar
11+
real :: arrayconstsize(10)
12+
real, allocatable :: arrayalloc(:)
13+
complex :: complexvar
14+
character*1 :: charvar
15+
16+
!$acc enter data copyin(scalar, scalaralloc, ttvar, arrayconstsize, arrayalloc)
17+
!$acc enter data copyin(complexvar, charvar, ttvar%field, ttvar%fieldarray, arrayconstsize(1))
18+
end program
19+
20+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalar", structured = false}
21+
! CHECK: Pointer-like: !fir.ref<f32>
22+
! CHECK: Type category: scalar
23+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "scalaralloc", structured = false}
24+
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<f32>>>
25+
! CHECK: Type category: nonscalar
26+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar", structured = false}
27+
! CHECK: Pointer-like: !fir.ref<!fir.type<_QFTtt{field:f32,fieldarray:!fir.array<10xf32>}>>
28+
! CHECK: Type category: composite
29+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize", structured = false}
30+
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
31+
! CHECK: Type category: array
32+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayalloc", structured = false}
33+
! CHECK: Pointer-like: !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
34+
! CHECK: Type category: array
35+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "complexvar", structured = false}
36+
! CHECK: Pointer-like: !fir.ref<complex<f32>>
37+
! CHECK: Type category: scalar
38+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "charvar", structured = false}
39+
! CHECK: Pointer-like: !fir.ref<!fir.char<1>>
40+
! CHECK: Type category: nonscalar
41+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%field", structured = false}
42+
! CHECK: Pointer-like: !fir.ref<f32>
43+
! CHECK: Type category: composite
44+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "ttvar%fieldarray", structured = false}
45+
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
46+
! CHECK: Type category: array
47+
! CHECK: Visiting: {{.*}} acc.copyin {{.*}} {name = "arrayconstsize(1)", structured = false}
48+
! CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
49+
! CHECK: Type category: array

0 commit comments

Comments
 (0)