
Commit a664c14

nicolasvasilache authored and joker-eph committed
[mlir][LLVM] Revert bareptr calling convention handling as an argument materialization.
Type conversion and argument materialization are context-free: no information is available about which op or branch is currently being converted. As a consequence, the bare ptr convention cannot be handled as an argument materialization: it would apply irrespective of the parent op. This doesn't typecheck in the non-FuncOp case, and we would see cases where a memref descriptor was inserted in place of the pointer inside another memref descriptor.

For now, the proper behavior is to revert to a specific BarePtrFunc implementation and drop the blanket argument materialization logic. This reverts the relevant piece of the conversion to LLVM to what it was before https://reviews.llvm.org/D105880, and adds a relevant test and documentation to avoid the mistake by whoever attempts this again in the future.

Reviewed By: arpith-jacob

Differential Revision: https://reviews.llvm.org/D106495
1 parent c75a2bb commit a664c14
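For background, the bare pointer calling convention passes a statically shaped memref argument as a single pointer, rather than expanding it into the fields of a memref descriptor. A minimal sketch of the difference, illustrative only (the function name @callee is hypothetical, and the LLVM dialect types assume the typed-pointer form in use at this revision):

// Input: a function taking a statically shaped memref.
func @callee(%arg0: memref<64xf32>) {
  return
}

// Default convention: the memref descriptor is expanded into separate
// arguments (allocated ptr, aligned ptr, offset, one size, one stride).
llvm.func @callee(%arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>, %arg2: i64,
                  %arg3: i64, %arg4: i64) {
  llvm.return
}

// Bare pointer convention: only the pointer crosses the function boundary;
// the descriptor is rebuilt inside the entry block (see the
// BarePtrFuncOpConversion changes below).
llvm.func @callee(%arg0: !llvm.ptr<f32>) {
  llvm.return
}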

3 files changed: +84 -13 lines changed


mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 2 additions & 10 deletions
@@ -58,16 +58,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> Optional<Value> {
-    // Explicit "this" is necessary here because otherwise "options" resolves to
-    // the argument of the parent function (constructor), which is a reference
-    // and not a copy. This can lead to UB when the lambda is actually called.
-    if (this->options.useBarePtrCallConv) {
-      if (!resultType.hasStaticShape())
-        return llvm::None;
-      Value v = MemRefDescriptor::fromStaticShape(builder, loc, *this,
-                                                  resultType, inputs[0]);
-      return v;
-    }
+    // TODO: bare ptr conversion could be handled here but we would need a way
+    // to distinguish between FuncOp and other regions.
     if (inputs.size() == 1)
       return llvm::None;
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
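The lambda signature above is the crux: an argument materialization only receives a builder, the target type, the incoming values, and a location. A hedged sketch of how such a hook is registered; addArgumentMaterialization, MemRefDescriptor::pack, and the LLVMTypeConverter constructor mirror the diff, while the surrounding function is assumed scaffolding:

// Hedged sketch, not the actual registration site.
void setUpConverter(MLIRContext *ctx, const LowerToLLVMOptions &options) {
  LLVMTypeConverter converter(ctx, options);
  converter.addArgumentMaterialization(
      [&](OpBuilder &builder, MemRefType resultType, ValueRange inputs,
          Location loc) -> Optional<Value> {
        // No op, block, or region information reaches this callback, so a
        // bare-ptr policy here would fire on every block-argument
        // conversion, branch targets included -- the bug this commit reverts.
        if (inputs.size() == 1)
          return llvm::None;
        return MemRefDescriptor::pack(builder, loc, converter, resultType,
                                      inputs);
      });
}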

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 57 additions & 3 deletions
@@ -309,9 +309,62 @@ struct BarePtrFuncOpConversion : public FuncOpConversionBase {
   LogicalResult
   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
+
+    // TODO: bare ptr conversion could be handled by argument materialization
+    // and most of the code below would go away. But to do this, we would need a
+    // way to distinguish between FuncOp and other regions in the
+    // addArgumentMaterialization hook.
+
+    // Store the type of memref-typed arguments before the conversion so that we
+    // can promote them to MemRef descriptor at the beginning of the function.
+    SmallVector<Type, 8> oldArgTypes =
+        llvm::to_vector<8>(funcOp.getType().getInputs());
+
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
       return failure();
+    if (newFuncOp.getBody().empty()) {
+      rewriter.eraseOp(funcOp);
+      return success();
+    }
+
+    // Promote bare pointers from memref arguments to memref descriptors at the
+    // beginning of the function so that all the memrefs in the function have a
+    // uniform representation.
+    Block *entryBlock = &newFuncOp.getBody().front();
+    auto blockArgs = entryBlock->getArguments();
+    assert(blockArgs.size() == oldArgTypes.size() &&
+           "The number of arguments and types doesn't match");
+
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(entryBlock);
+    for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
+      BlockArgument arg = std::get<0>(it);
+      Type argTy = std::get<1>(it);
+
+      // Unranked memrefs are not supported in the bare pointer calling
+      // convention. We should have bailed out before in the presence of
+      // unranked memrefs.
+      assert(!argTy.isa<UnrankedMemRefType>() &&
+             "Unranked memref is not supported");
+      auto memrefTy = argTy.dyn_cast<MemRefType>();
+      if (!memrefTy)
+        continue;
+
+      // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
+      // or unranked memref descriptor and replace placeholder with the last
+      // instruction of the memref descriptor.
+      // TODO: The placeholder is needed to avoid replacing barePtr uses in the
+      // MemRef descriptor instructions. We may want to have a utility in the
+      // rewriter to properly handle this use case.
+      Location loc = funcOp.getLoc();
+      auto placeholder = rewriter.create<LLVM::UndefOp>(loc, memrefTy);
+      rewriter.replaceUsesOfBlockArgument(arg, placeholder);
+
+      Value desc = MemRefDescriptor::fromStaticShape(
+          rewriter, loc, *getTypeConverter(), memrefTy, arg);
+      rewriter.replaceOp(placeholder, {desc});
+    }
 
     rewriter.eraseOp(funcOp);
     return success();
@@ -330,7 +383,8 @@ using DivFOpLowering = VectorConvertToLLVMPattern<DivFOp, LLVM::FDivOp>;
 using FPExtOpLowering = VectorConvertToLLVMPattern<FPExtOp, LLVM::FPExtOp>;
 using FPToSIOpLowering = VectorConvertToLLVMPattern<FPToSIOp, LLVM::FPToSIOp>;
 using FPToUIOpLowering = VectorConvertToLLVMPattern<FPToUIOp, LLVM::FPToUIOp>;
-using FPTruncOpLowering = VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
+using FPTruncOpLowering =
+    VectorConvertToLLVMPattern<FPTruncOp, LLVM::FPTruncOp>;
 using FloorFOpLowering = VectorConvertToLLVMPattern<FloorFOp, LLVM::FFloorOp>;
 using FmaFOpLowering = VectorConvertToLLVMPattern<FmaFOp, LLVM::FMAOp>;
 using MulFOpLowering = VectorConvertToLLVMPattern<MulFOp, LLVM::FMulOp>;
@@ -352,7 +406,8 @@ using SignedShiftRightOpLowering =
     OneToOneConvertToLLVMPattern<SignedShiftRightOp, LLVM::AShrOp>;
 using SubFOpLowering = VectorConvertToLLVMPattern<SubFOp, LLVM::FSubOp>;
 using SubIOpLowering = VectorConvertToLLVMPattern<SubIOp, LLVM::SubOp>;
-using TruncateIOpLowering = VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
+using TruncateIOpLowering =
+    VectorConvertToLLVMPattern<TruncateIOp, LLVM::TruncOp>;
 using UIToFPOpLowering = VectorConvertToLLVMPattern<UIToFPOp, LLVM::UIToFPOp>;
 using UnsignedDivIOpLowering =
     VectorConvertToLLVMPattern<UnsignedDivIOp, LLVM::UDivOp>;
@@ -1196,4 +1251,3 @@ mlir::createLowerToLLVMPass(const LowerToLLVMOptions &options) {
       options.useBarePtrCallConv, options.emitCWrappers,
       options.getIndexBitwidth(), useAlignedAlloc, options.dataLayout);
 }
-
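For reference, the bare pointer path is opt-in: BarePtrFuncOpConversion only applies when the lowering options request it. A hedged sketch of enabling it through the createLowerToLLVMPass entry point visible in the last hunk; the useBarePtrCallConv field is taken from the diff, while the LowerToLLVMOptions constructor shape and the surrounding pipeline function are assumptions:

// Hedged sketch: opting into the bare pointer convention for the
// std-to-LLVM lowering pass.
void buildPipeline(PassManager &pm, MLIRContext *context) {
  LowerToLLVMOptions options(context); // constructor shape assumed
  options.useBarePtrCallConv = true;   // bare ptrs for static-shape memrefs
  pm.addPass(mlir::createLowerToLLVMPass(options));
}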

mlir/test/Conversion/StandardToLLVM/func-memref.mlir

Lines changed: 25 additions & 0 deletions
@@ -182,3 +182,28 @@ func @check_scalar_func_call(%in : f32) {
   %res = call @goo(%in) : (f32) -> (f32)
   return
 }
+
+// -----
+
+!base_type = type memref<64xi32, 201>
+
+// CHECK-LABEL: func @loop_carried
+// BAREPTR-LABEL: func @loop_carried
+func @loop_carried(%arg0 : index, %arg1 : index, %arg2 : index, %base0 : !base_type, %base1 : !base_type) -> (!base_type, !base_type) {
+  // This test checks that in the BAREPTR case, the branch arguments only forward the descriptor.
+  // This test was lowered from a simple scf.for that swaps 2 memref iter_args.
+  // BAREPTR: llvm.br ^bb1(%{{.*}}, %{{.*}}, %{{.*}} : i64, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>, !llvm.struct<(ptr<i32, 201>, ptr<i32, 201>, i64, array<1 x i64>, array<1 x i64>)>)
+  br ^bb1(%arg0, %base0, %base1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+
+  // BAREPTR-NEXT: ^bb1
+  // BAREPTR-NEXT: llvm.icmp
+  // BAREPTR-NEXT: llvm.cond_br %{{.*}}, ^bb2, ^bb3
+^bb1(%0: index, %1: memref<64xi32, 201>, %2: memref<64xi32, 201>):  // 2 preds: ^bb0, ^bb2
+  %3 = cmpi slt, %0, %arg1 : index
+  cond_br %3, ^bb2, ^bb3
+^bb2:  // pred: ^bb1
+  %4 = addi %0, %arg2 : index
+  br ^bb1(%4, %2, %1 : index, memref<64xi32, 201>, memref<64xi32, 201>)
+^bb3:  // pred: ^bb1
+  return %1, %2 : memref<64xi32, 201>, memref<64xi32, 201>
+}
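The BAREPTR check prefix above is driven by the test file's existing RUN lines, which are not part of this diff. From memory they look roughly like the following; the exact pass-option spelling use-bare-ptr-memref-call-conv is an assumption here, not something shown in the diff:

// RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
// RUN: mlir-opt -convert-std-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR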
