Skip to content

Commit e398383

Browse files
authored
[flang][fir] add codegen for fir.load of assumed-rank fir.box (#93569)
- Update the LLVM type conversion of assumed-rank fir.box/class to generate the type of the maximum-rank descriptor. That way, allocas for assumed-rank descriptor copies are always big enough. This is needed in the fir.load case, which generates new storage for the value.
- Add a "computeBoxSize" helper to compute the dynamic size of a descriptor.
- Use that size to generate an llvm.memcpy intrinsic that copies the input descriptor into the new storage.

Looking at https://reviews.llvm.org/D108221?id=404635, it seems valid to add the TBAA node on the memcpy, which I did.

In a further patch, I think we should likely always use a memcpy, since LLVM seems to optimize it better than fir.load/fir.store patterns.
1 parent 6f2794a commit e398383

File tree

7 files changed

+130
-20
lines changed

7 files changed

+130
-20
lines changed

flang/include/flang/Optimizer/CodeGen/FIROpPatterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class ConvertFIRToLLVMPattern : public mlir::ConvertToLLVMPattern {
125125
mlir::ConversionPatternRewriter &rewriter,
126126
unsigned maskValue) const;
127127

128+
/// Compute the descriptor size in bytes. The result is not guaranteed to be a
129+
/// compile time constant if the box is for an assumed rank, in which case the
130+
/// box rank will be read.
131+
mlir::Value computeBoxSize(mlir::Location, TypePair boxTy, mlir::Value box,
132+
mlir::ConversionPatternRewriter &rewriter) const;
133+
128134
template <typename... ARGS>
129135
mlir::LLVM::GEPOp genGEP(mlir::Location loc, mlir::Type ty,
130136
mlir::ConversionPatternRewriter &rewriter,

flang/include/flang/Optimizer/CodeGen/TypeConverter.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,16 @@ class LLVMTypeConverter : public mlir::LLVMTypeConverter {
123123
mlir::Type baseFIRType, mlir::Type accessFIRType,
124124
mlir::LLVM::GEPOp gep) const;
125125

126+
const mlir::DataLayout &getDataLayout() const {
127+
assert(dataLayout && "must be set in ctor");
128+
return *dataLayout;
129+
}
130+
126131
private:
127132
KindMapping kindMapping;
128133
std::unique_ptr<CodeGenSpecifics> specifics;
129134
std::unique_ptr<TBAABuilder> tbaaBuilder;
135+
const mlir::DataLayout *dataLayout;
130136
};
131137

132138
} // namespace fir

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2863,23 +2863,32 @@ struct LoadOpConversion : public fir::FIROpConversion<fir::LoadOp> {
28632863
// descriptor value into a new descriptor temp.
28642864
auto inputBoxStorage = adaptor.getOperands()[0];
28652865
mlir::Location loc = load.getLoc();
2866-
fir::SequenceType seqTy = fir::unwrapUntilSeqType(boxTy);
2867-
// fir.box of assumed rank do not have a storage
2868-
// size that is know at compile time. The copy needs to be runtime driven
2869-
// depending on the actual dynamic rank or type.
2870-
if (seqTy && seqTy.hasUnknownShape())
2871-
TODO(loc, "loading or assumed rank fir.box");
2872-
auto boxValue =
2873-
rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy, inputBoxStorage);
2874-
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
2875-
boxValue.setTBAATags(*optionalTag);
2876-
else
2877-
attachTBAATag(boxValue, boxTy, boxTy, nullptr);
28782866
auto newBoxStorage =
28792867
genAllocaAndAddrCastWithType(loc, llvmLoadTy, defaultAlign, rewriter);
2880-
auto storeOp =
2881-
rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
2882-
attachTBAATag(storeOp, boxTy, boxTy, nullptr);
2868+
// TODO: always generate llvm.memcpy, LLVM is better at optimizing it than
2869+
// aggregate loads + stores.
2870+
if (boxTy.isAssumedRank()) {
2871+
2872+
TypePair boxTypePair{boxTy, llvmLoadTy};
2873+
mlir::Value boxSize =
2874+
computeBoxSize(loc, boxTypePair, inputBoxStorage, rewriter);
2875+
auto memcpy = rewriter.create<mlir::LLVM::MemcpyOp>(
2876+
loc, newBoxStorage, inputBoxStorage, boxSize, /*isVolatile=*/false);
2877+
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
2878+
memcpy.setTBAATags(*optionalTag);
2879+
else
2880+
attachTBAATag(memcpy, boxTy, boxTy, nullptr);
2881+
} else {
2882+
auto boxValue = rewriter.create<mlir::LLVM::LoadOp>(loc, llvmLoadTy,
2883+
inputBoxStorage);
2884+
if (std::optional<mlir::ArrayAttr> optionalTag = load.getTbaa())
2885+
boxValue.setTBAATags(*optionalTag);
2886+
else
2887+
attachTBAATag(boxValue, boxTy, boxTy, nullptr);
2888+
auto storeOp =
2889+
rewriter.create<mlir::LLVM::StoreOp>(loc, boxValue, newBoxStorage);
2890+
attachTBAATag(storeOp, boxTy, boxTy, nullptr);
2891+
}
28832892
rewriter.replaceOp(load, newBoxStorage);
28842893
} else {
28852894
auto loadOp = rewriter.create<mlir::LLVM::LoadOp>(

flang/lib/Optimizer/CodeGen/FIROpPatterns.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,37 @@ mlir::Value ConvertFIRToLLVMPattern::genBoxAttributeCheck(
240240
maskRes, c0);
241241
}
242242

243+
mlir::Value ConvertFIRToLLVMPattern::computeBoxSize(
244+
mlir::Location loc, TypePair boxTy, mlir::Value box,
245+
mlir::ConversionPatternRewriter &rewriter) const {
246+
auto firBoxType = mlir::dyn_cast<fir::BaseBoxType>(boxTy.fir);
247+
assert(firBoxType && "must be a BaseBoxType");
248+
const mlir::DataLayout &dl = lowerTy().getDataLayout();
249+
if (!firBoxType.isAssumedRank())
250+
return genConstantOffset(loc, rewriter, dl.getTypeSize(boxTy.llvm));
251+
fir::BaseBoxType firScalarBoxType = firBoxType.getBoxTypeWithNewShape(0);
252+
mlir::Type llvmScalarBoxType =
253+
lowerTy().convertBoxTypeAsStruct(firScalarBoxType);
254+
llvm::TypeSize scalarBoxSizeCst = dl.getTypeSize(llvmScalarBoxType);
255+
mlir::Value scalarBoxSize =
256+
genConstantOffset(loc, rewriter, scalarBoxSizeCst);
257+
mlir::Value rawRank = getRankFromBox(loc, boxTy, box, rewriter);
258+
mlir::Value rank =
259+
integerCast(loc, rewriter, scalarBoxSize.getType(), rawRank);
260+
mlir::Type llvmDimsType = getBoxEleTy(boxTy.llvm, {kDimsPosInBox, 1});
261+
llvm::TypeSize sizePerDimCst = dl.getTypeSize(llvmDimsType);
262+
assert((scalarBoxSizeCst + sizePerDimCst ==
263+
dl.getTypeSize(lowerTy().convertBoxTypeAsStruct(
264+
firBoxType.getBoxTypeWithNewShape(1)))) &&
265+
"descriptor layout requires adding padding for dim field");
266+
mlir::Value sizePerDim = genConstantOffset(loc, rewriter, sizePerDimCst);
267+
mlir::Value dimsSize = rewriter.create<mlir::LLVM::MulOp>(
268+
loc, sizePerDim.getType(), sizePerDim, rank);
269+
mlir::Value size = rewriter.create<mlir::LLVM::AddOp>(
270+
loc, scalarBoxSize.getType(), scalarBoxSize, dimsSize);
271+
return size;
272+
}
273+
243274
// Find the Block in which the alloca should be inserted.
244275
// The order to recursively find the proper block:
245276
// 1. An OpenMP Op that will be outlined.

flang/lib/Optimizer/CodeGen/TypeConverter.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include "flang/Optimizer/CodeGen/TypeConverter.h"
1616
#include "DescriptorModel.h"
17+
#include "flang/Common/Fortran.h"
1718
#include "flang/Optimizer/Builder/Todo.h" // remove when TODO's are done
1819
#include "flang/Optimizer/CodeGen/TBAABuilder.h"
1920
#include "flang/Optimizer/CodeGen/Target.h"
@@ -36,7 +37,8 @@ LLVMTypeConverter::LLVMTypeConverter(mlir::ModuleOp module, bool applyTBAA,
3637
module.getContext(), getTargetTriple(module), getKindMapping(module),
3738
getTargetCPU(module), getTargetFeatures(module), dl)),
3839
tbaaBuilder(std::make_unique<TBAABuilder>(module->getContext(), applyTBAA,
39-
forceUnifiedTBAATree)) {
40+
forceUnifiedTBAATree)),
41+
dataLayout{&dl} {
4042
LLVM_DEBUG(llvm::dbgs() << "FIR type converter\n");
4143

4244
// Each conversion should return a value of type mlir::Type.
@@ -243,7 +245,10 @@ mlir::Type LLVMTypeConverter::convertBoxTypeAsStruct(BaseBoxType box,
243245
// [dims]
244246
if (rank == unknownRank()) {
245247
if (auto seqTy = mlir::dyn_cast<SequenceType>(ele))
246-
rank = seqTy.getDimension();
248+
if (seqTy.hasUnknownShape())
249+
rank = Fortran::common::maxRank;
250+
else
251+
rank = seqTy.getDimension();
247252
else
248253
rank = 0;
249254
}

flang/test/Fir/convert-to-llvm.fir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,31 @@ func.func @test_load_box(%addr : !fir.ref<!fir.box<!fir.array<10xf32>>>) {
931931

932932
// -----
933933

934+
func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
935+
%0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
936+
fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
937+
return
938+
}
939+
func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
940+
941+
// CHECK-LABEL: llvm.func @test_assumed_rank_load(
942+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) {
943+
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
944+
// GENERIC: %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
945+
// AMDGPU: %[[VAL_2A:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr<5>
946+
// AMDGPU: %[[VAL_2:.*]] = llvm.addrspacecast %[[VAL_2A]] : !llvm.ptr<5> to !llvm.ptr
947+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
948+
// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
949+
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] : !llvm.ptr -> i8
950+
// CHECK: %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
951+
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
952+
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
953+
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
954+
// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
955+
// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()
956+
957+
// -----
958+
934959
// Test `fir.box_rank` conversion.
935960

936961
func.func @extract_rank(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {

flang/test/Fir/tbaa.fir

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i32 {
247247

248248
// CHECK-LABEL: llvm.func @tbaa(
249249
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i32 {
250-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
250+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
251251
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
252252
// CHECK: %[[VAL_3:.*]] = llvm.sext %[[VAL_2]] : i8 to i32
253253
// CHECK: llvm.return %[[VAL_3]] : i32
@@ -267,7 +267,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
267267

268268
// CHECK-LABEL: llvm.func @tbaa(
269269
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
270-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
270+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
271271
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
272272
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i8
273273
// CHECK: %[[VAL_4:.*]] = llvm.icmp "ne" %[[VAL_2]], %[[VAL_3]] : i8
@@ -307,7 +307,7 @@ func.func @tbaa(%arg0: !fir.box<!fir.array<*:f64>>) -> i1 {
307307

308308
// CHECK-LABEL: llvm.func @tbaa(
309309
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> i1 {
310-
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8)>
310+
// CHECK: %[[VAL_1:.*]] = llvm.getelementptr %[[VAL_0]][0, 5] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
311311
// CHECK: %[[VAL_2:.*]] = llvm.load %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i32
312312
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : i32) : i32
313313
// CHECK: %[[VAL_4:.*]] = llvm.and %[[VAL_2]], %[[VAL_3]] : i32
@@ -379,3 +379,31 @@ func.func @tbaa(%arg0: !fir.ref<!fir.array<2x!fir.type<_QMtypesTt{x:!fir.box<!fi
379379
// CHECK-LABEL: llvm.func @tbaa(
380380
// CHECK: llvm.load{{.*}}{tbaa = [#[[$ANYT]]]}
381381
// CHECK: llvm.store{{.*}}{tbaa = [#[[$ANYT]]]}
382+
383+
// -----
384+
385+
func.func @test_assumed_rank_load(%arg0: !fir.ref<!fir.box<!fir.array<*:f64>>>) -> () {
386+
%0 = fir.load %arg0 : !fir.ref<!fir.box<!fir.array<*:f64>>>
387+
fir.call @some_assumed_rank_func(%0) : (!fir.box<!fir.array<*:f64>>) -> ()
388+
return
389+
}
390+
func.func private @some_assumed_rank_func(!fir.box<!fir.array<*:f64>>) -> ()
391+
392+
// CHECK-DAG: #[[ROOT:.*]] = #llvm.tbaa_root<id = "Flang function root ">
393+
// CHECK-DAG: #[[ANYACC:.*]] = #llvm.tbaa_type_desc<id = "any access", members = {<#[[ROOT]], 0>}>
394+
// CHECK-DAG: #[[BOXMEM:.*]] = #llvm.tbaa_type_desc<id = "descriptor member", members = {<#[[ANYACC]], 0>}>
395+
// CHECK-DAG: #[[$BOXT:.*]] = #llvm.tbaa_tag<base_type = #[[BOXMEM]], access_type = #[[BOXMEM]], offset = 0>
396+
397+
// CHECK-LABEL: llvm.func @test_assumed_rank_load(
398+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) {
399+
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : i32
400+
// CHECK: %[[VAL_2:.*]] = llvm.alloca %[[VAL_1]] x !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)> {alignment = 8 : i64} : (i32) -> !llvm.ptr
401+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(24 : i32) : i32
402+
// CHECK: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]][0, 3] : (!llvm.ptr) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<15 x array<3 x i64>>)>
403+
// CHECK: %[[VAL_5:.*]] = llvm.load %[[VAL_4]] {tbaa = [#[[$BOXT]]]} : !llvm.ptr -> i8
404+
// CHECK: %[[VAL_6:.*]] = llvm.sext %[[VAL_5]] : i8 to i32
405+
// CHECK: %[[VAL_7:.*]] = llvm.mlir.constant(24 : i32) : i32
406+
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_7]], %[[VAL_6]] : i32
407+
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_3]], %[[VAL_8]] : i32
408+
// CHECK: "llvm.intr.memcpy"(%[[VAL_2]], %[[VAL_0]], %[[VAL_9]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
409+
// CHECK: llvm.call @some_assumed_rank_func(%[[VAL_2]]) : (!llvm.ptr) -> ()

0 commit comments

Comments
 (0)