Skip to content

Commit ee4f78f

Browse files
matthias-springerAlexisPerry
authored andcommitted
[mlir][Conversion] FuncToLLVM: Simplify bare-pointer handling (llvm#96393)
Before this commit, there used to be a workaround in the `func.func`/`gpu.func` op lowering when the bare-pointer calling convention is enabled. This workaround "patched up" the argument materializations for memref arguments. This can be done directly in the argument materialization functions (as the TODOs in the code base indicate). This commit effectively reverts back to the old implementation (a664c14) and adds additional checks to make sure that bare pointers are used only for function entry block arguments.
1 parent b768f92 commit ee4f78f

File tree

3 files changed

+17
-87
lines changed

3 files changed

+17
-87
lines changed

mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
268268
}
269269
}
270270

271-
/// Modifies the body of the function to construct the `MemRefDescriptor` from
272-
/// the bare pointer calling convention lowering of `memref` types.
273-
static void modifyFuncOpToUseBarePtrCallingConv(
274-
ConversionPatternRewriter &rewriter, Location loc,
275-
const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
276-
TypeRange oldArgTypes) {
277-
if (funcOp.getBody().empty())
278-
return;
279-
280-
// Promote bare pointers from memref arguments to memref descriptors at the
281-
// beginning of the function so that all the memrefs in the function have a
282-
// uniform representation.
283-
Block *entryBlock = &funcOp.getBody().front();
284-
auto blockArgs = entryBlock->getArguments();
285-
assert(blockArgs.size() == oldArgTypes.size() &&
286-
"The number of arguments and types doesn't match");
287-
288-
OpBuilder::InsertionGuard guard(rewriter);
289-
rewriter.setInsertionPointToStart(entryBlock);
290-
for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
291-
BlockArgument arg = std::get<0>(it);
292-
Type argTy = std::get<1>(it);
293-
294-
// Unranked memrefs are not supported in the bare pointer calling
295-
// convention. We should have bailed out before in the presence of
296-
// unranked memrefs.
297-
assert(!isa<UnrankedMemRefType>(argTy) &&
298-
"Unranked memref is not supported");
299-
auto memrefTy = dyn_cast<MemRefType>(argTy);
300-
if (!memrefTy)
301-
continue;
302-
303-
// Replace barePtr with a placeholder (undef), promote barePtr to a ranked
304-
// or unranked memref descriptor and replace placeholder with the last
305-
// instruction of the memref descriptor.
306-
// TODO: The placeholder is needed to avoid replacing barePtr uses in the
307-
// MemRef descriptor instructions. We may want to have a utility in the
308-
// rewriter to properly handle this use case.
309-
Location loc = funcOp.getLoc();
310-
auto placeholder = rewriter.create<LLVM::UndefOp>(
311-
loc, typeConverter.convertType(memrefTy));
312-
rewriter.replaceUsesOfBlockArgument(arg, placeholder);
313-
314-
Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
315-
memrefTy, arg);
316-
rewriter.replaceOp(placeholder, {desc});
317-
}
318-
}
319-
320271
FailureOr<LLVM::LLVMFuncOp>
321272
mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
322273
ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
462413
wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
463414
newFuncOp);
464415
}
465-
} else {
466-
modifyFuncOpToUseBarePtrCallingConv(
467-
rewriter, funcOp->getLoc(), converter, newFuncOp,
468-
llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
469416
}
470417

471418
return newFuncOp;

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
182182
&signatureConversion)))
183183
return failure();
184184

185-
// If bare memref pointers are being used, remap them back to memref
186-
// descriptors This must be done after signature conversion to get rid of the
187-
// unrealized casts.
188-
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
189-
OpBuilder::InsertionGuard guard(rewriter);
190-
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
191-
for (const auto [idx, argTy] :
192-
llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
193-
auto memrefTy = dyn_cast<MemRefType>(argTy);
194-
if (!memrefTy)
195-
continue;
196-
assert(memrefTy.hasStaticShape() &&
197-
"Bare pointer convertion used with dynamically-shaped memrefs");
198-
// Use a placeholder when replacing uses of the memref argument to prevent
199-
// circular replacements.
200-
auto remapping = signatureConversion.getInputMapping(idx);
201-
assert(remapping && remapping->size == 1 &&
202-
"Type converter should produce 1-to-1 mapping for bare memrefs");
203-
BlockArgument newArg =
204-
llvmFuncOp.getBody().getArgument(remapping->inputNo);
205-
auto placeholder = rewriter.create<LLVM::UndefOp>(
206-
loc, getTypeConverter()->convertType(memrefTy));
207-
rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
208-
Value desc = MemRefDescriptor::fromStaticShape(
209-
rewriter, loc, *getTypeConverter(), memrefTy, newArg);
210-
rewriter.replaceOp(placeholder, {desc});
211-
}
212-
}
213-
214185
// Get memref type from function arguments and set the noalias to
215186
// pointer arguments.
216187
for (const auto [idx, argTy] :

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
159159
addArgumentMaterialization(
160160
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
161161
Location loc) -> std::optional<Value> {
162-
if (inputs.size() == 1)
162+
if (inputs.size() == 1) {
163+
// Bare pointers are not supported for unranked memrefs because a
164+
// memref descriptor cannot be built just from a bare pointer.
163165
return std::nullopt;
166+
}
164167
return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
165168
inputs);
166169
});
167170
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
168171
ValueRange inputs,
169172
Location loc) -> std::optional<Value> {
170-
// TODO: bare ptr conversion could be handled here but we would need a way
171-
// to distinguish between FuncOp and other regions.
172-
if (inputs.size() == 1)
173-
return std::nullopt;
173+
if (inputs.size() == 1) {
174+
// This is a bare pointer. We allow bare pointers only for function entry
175+
// blocks.
176+
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
177+
if (!barePtr)
178+
return std::nullopt;
179+
Block *block = barePtr.getOwner();
180+
if (!block->isEntryBlock() ||
181+
!isa<FunctionOpInterface>(block->getParentOp()))
182+
return std::nullopt;
183+
return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
184+
inputs[0]);
185+
}
174186
return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
175187
});
176188
// Add generic source and target materializations to handle cases where

0 commit comments

Comments
 (0)