
[mlir][Conversion] FuncToLLVM: Simplify bare-pointer handling #96393


Merged
1 commit merged into main from users/matthias-springer/bare_ptr_conv on Jun 24, 2024

Conversation

matthias-springer
Member

Before this commit, the func.func/gpu.func op lowerings contained a workaround for the case where the bare-pointer calling convention is enabled. This workaround "patched up" the argument materializations for memref arguments. That work can be done directly in the argument materialization functions, as the TODOs in the code base indicate.

This commit effectively reverts to the old implementation (a664c14) and adds checks to ensure that bare pointers are used only for function entry block arguments.
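
For context, here is a minimal sketch of how a lowering pipeline opts into the bare-pointer calling convention that this patch affects. The helper name and pass setup below are illustrative assumptions, not part of this PR:

#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Hypothetical helper: lower func ops to llvm.func, passing statically
// shaped memref arguments as bare pointers instead of full descriptors.
static LogicalResult lowerFuncsWithBarePointers(ModuleOp module) {
  MLIRContext *ctx = module.getContext();

  // Opt into the bare-pointer calling convention via the lowering options.
  LowerToLLVMOptions options(ctx);
  options.useBarePtrCallConv = true;
  LLVMTypeConverter converter(ctx, options);

  // With this patch, the argument materialization registered by the type
  // converter rebuilds the memref descriptor from the bare pointer for
  // function entry block arguments, so no extra patch-up is needed in the
  // func/gpu op lowerings.
  RewritePatternSet patterns(ctx);
  populateFuncToLLVMConversionPatterns(converter, patterns);

  LLVMConversionTarget target(*ctx);
  return applyPartialConversion(module, target, std::move(patterns));
}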

@llvmbot
Member

llvmbot commented Jun 22, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Matthias Springer (matthias-springer)

Changes

Before this commit, the func.func/gpu.func op lowerings contained a workaround for the case where the bare-pointer calling convention is enabled. This workaround "patched up" the argument materializations for memref arguments. That work can be done directly in the argument materialization functions, as the TODOs in the code base indicate.

This commit effectively reverts to the old implementation (a664c14) and adds checks to ensure that bare pointers are used only for function entry block arguments.


Full diff: https://github.com/llvm/llvm-project/pull/96393.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (-53)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+49-79)
  • (modified) mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp (+17-5)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 744236692fbb6..efb80467369a2 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -268,55 +268,6 @@ static void wrapExternalFunction(OpBuilder &builder, Location loc,
   }
 }
 
-/// Modifies the body of the function to construct the `MemRefDescriptor` from
-/// the bare pointer calling convention lowering of `memref` types.
-static void modifyFuncOpToUseBarePtrCallingConv(
-    ConversionPatternRewriter &rewriter, Location loc,
-    const LLVMTypeConverter &typeConverter, LLVM::LLVMFuncOp funcOp,
-    TypeRange oldArgTypes) {
-  if (funcOp.getBody().empty())
-    return;
-
-  // Promote bare pointers from memref arguments to memref descriptors at the
-  // beginning of the function so that all the memrefs in the function have a
-  // uniform representation.
-  Block *entryBlock = &funcOp.getBody().front();
-  auto blockArgs = entryBlock->getArguments();
-  assert(blockArgs.size() == oldArgTypes.size() &&
-         "The number of arguments and types doesn't match");
-
-  OpBuilder::InsertionGuard guard(rewriter);
-  rewriter.setInsertionPointToStart(entryBlock);
-  for (auto it : llvm::zip(blockArgs, oldArgTypes)) {
-    BlockArgument arg = std::get<0>(it);
-    Type argTy = std::get<1>(it);
-
-    // Unranked memrefs are not supported in the bare pointer calling
-    // convention. We should have bailed out before in the presence of
-    // unranked memrefs.
-    assert(!isa<UnrankedMemRefType>(argTy) &&
-           "Unranked memref is not supported");
-    auto memrefTy = dyn_cast<MemRefType>(argTy);
-    if (!memrefTy)
-      continue;
-
-    // Replace barePtr with a placeholder (undef), promote barePtr to a ranked
-    // or unranked memref descriptor and replace placeholder with the last
-    // instruction of the memref descriptor.
-    // TODO: The placeholder is needed to avoid replacing barePtr uses in the
-    // MemRef descriptor instructions. We may want to have a utility in the
-    // rewriter to properly handle this use case.
-    Location loc = funcOp.getLoc();
-    auto placeholder = rewriter.create<LLVM::UndefOp>(
-        loc, typeConverter.convertType(memrefTy));
-    rewriter.replaceUsesOfBlockArgument(arg, placeholder);
-
-    Value desc = MemRefDescriptor::fromStaticShape(rewriter, loc, typeConverter,
-                                                   memrefTy, arg);
-    rewriter.replaceOp(placeholder, {desc});
-  }
-}
-
 FailureOr<LLVM::LLVMFuncOp>
 mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
                                 ConversionPatternRewriter &rewriter,
@@ -462,10 +413,6 @@ mlir::convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
         wrapForExternalCallers(rewriter, funcOp->getLoc(), converter, funcOp,
                                newFuncOp);
     }
-  } else {
-    modifyFuncOpToUseBarePtrCallingConv(
-        rewriter, funcOp->getLoc(), converter, newFuncOp,
-        llvm::cast<FunctionType>(funcOp.getFunctionType()).getInputs());
   }
 
   return newFuncOp;
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 7ea05b7e7f6c1..6053e34f30a41 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -182,35 +182,6 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
                                          &signatureConversion)))
     return failure();
 
-  // If bare memref pointers are being used, remap them back to memref
-  // descriptors This must be done after signature conversion to get rid of the
-  // unrealized casts.
-  if (getTypeConverter()->getOptions().useBarePtrCallConv) {
-    OpBuilder::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
-    for (const auto [idx, argTy] :
-         llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
-      auto memrefTy = dyn_cast<MemRefType>(argTy);
-      if (!memrefTy)
-        continue;
-      assert(memrefTy.hasStaticShape() &&
-             "Bare pointer convertion used with dynamically-shaped memrefs");
-      // Use a placeholder when replacing uses of the memref argument to prevent
-      // circular replacements.
-      auto remapping = signatureConversion.getInputMapping(idx);
-      assert(remapping && remapping->size == 1 &&
-             "Type converter should produce 1-to-1 mapping for bare memrefs");
-      BlockArgument newArg =
-          llvmFuncOp.getBody().getArgument(remapping->inputNo);
-      auto placeholder = rewriter.create<LLVM::UndefOp>(
-          loc, getTypeConverter()->convertType(memrefTy));
-      rewriter.replaceUsesOfBlockArgument(newArg, placeholder);
-      Value desc = MemRefDescriptor::fromStaticShape(
-          rewriter, loc, *getTypeConverter(), memrefTy, newArg);
-      rewriter.replaceOp(placeholder, {desc});
-    }
-  }
-
   // Get memref type from function arguments and set the noalias to
   // pointer arguments.
   for (const auto [idx, argTy] :
@@ -684,62 +655,61 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
   return success();
 }
 
-  LogicalResult
-  GPUReturnOpLowering::matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const {
-    Location loc = op.getLoc();
-    unsigned numArguments = op.getNumOperands();
-    SmallVector<Value, 4> updatedOperands;
-
-    bool useBarePtrCallConv =
-        getTypeConverter()->getOptions().useBarePtrCallConv;
-    if (useBarePtrCallConv) {
-      // For the bare-ptr calling convention, extract the aligned pointer to
-      // be returned from the memref descriptor.
-      for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
-        Type oldTy = std::get<0>(it).getType();
-        Value newOperand = std::get<1>(it);
-        if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
-                                          cast<BaseMemRefType>(oldTy))) {
-          MemRefDescriptor memrefDesc(newOperand);
-          newOperand = memrefDesc.allocatedPtr(rewriter, loc);
-        } else if (isa<UnrankedMemRefType>(oldTy)) {
-          // Unranked memref is not supported in the bare pointer calling
-          // convention.
-          return failure();
-        }
-        updatedOperands.push_back(newOperand);
+LogicalResult GPUReturnOpLowering::matchAndRewrite(
+    gpu::ReturnOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  unsigned numArguments = op.getNumOperands();
+  SmallVector<Value, 4> updatedOperands;
+
+  bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
+  if (useBarePtrCallConv) {
+    // For the bare-ptr calling convention, extract the aligned pointer to
+    // be returned from the memref descriptor.
+    for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
+      Type oldTy = std::get<0>(it).getType();
+      Value newOperand = std::get<1>(it);
+      if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
+                                        cast<BaseMemRefType>(oldTy))) {
+        MemRefDescriptor memrefDesc(newOperand);
+        newOperand = memrefDesc.allocatedPtr(rewriter, loc);
+      } else if (isa<UnrankedMemRefType>(oldTy)) {
+        // Unranked memref is not supported in the bare pointer calling
+        // convention.
+        return failure();
       }
-    } else {
-      updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
-      (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
-                                    updatedOperands,
-                                    /*toDynamic=*/true);
+      updatedOperands.push_back(newOperand);
     }
+  } else {
+    updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
+    (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
+                                  updatedOperands,
+                                  /*toDynamic=*/true);
+  }
 
-    // If ReturnOp has 0 or 1 operand, create it and return immediately.
-    if (numArguments <= 1) {
-      rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
-          op, TypeRange(), updatedOperands, op->getAttrs());
-      return success();
-    }
+  // If ReturnOp has 0 or 1 operand, create it and return immediately.
+  if (numArguments <= 1) {
+    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
+        op, TypeRange(), updatedOperands, op->getAttrs());
+    return success();
+  }
 
-    // Otherwise, we need to pack the arguments into an LLVM struct type before
-    // returning.
-    auto packedType = getTypeConverter()->packFunctionResults(
-        op.getOperandTypes(), useBarePtrCallConv);
-    if (!packedType) {
-      return rewriter.notifyMatchFailure(op, "could not convert result types");
-    }
+  // Otherwise, we need to pack the arguments into an LLVM struct type before
+  // returning.
+  auto packedType = getTypeConverter()->packFunctionResults(
+      op.getOperandTypes(), useBarePtrCallConv);
+  if (!packedType) {
+    return rewriter.notifyMatchFailure(op, "could not convert result types");
+  }
 
-    Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
-    for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
-      packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
-    }
-    rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
-                                                op->getAttrs());
-    return success();
+  Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
+  for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
+    packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
   }
+  rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
+                                              op->getAttrs());
+  return success();
+}
 
 void mlir::populateGpuMemorySpaceAttributeConversions(
     TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 3a01795ce3f53..f5620a6a7cd91 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -159,18 +159,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
-        if (inputs.size() == 1)
+        if (inputs.size() == 1) {
+          // Bare pointers are not supported for unranked memrefs because a
+          // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
+        }
         return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
                                               inputs);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
-    // TODO: bare ptr conversion could be handled here but we would need a way
-    // to distinguish between FuncOp and other regions.
-    if (inputs.size() == 1)
-      return std::nullopt;
+    if (inputs.size() == 1) {
+      // This is a bare pointer. We allow bare pointers only for function entry
+      // blocks.
+      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+      if (!barePtr)
+        return std::nullopt;
+      Block *block = barePtr.getOwner();
+      if (!block->isEntryBlock() ||
+          !isa<FunctionOpInterface>(block->getParentOp()))
+        return std::nullopt;
+      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+                                               inputs[0]);
+    }
     return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
   });
   // Add generic source and target materializations to handle cases where

@llvmbot
Member

llvmbot commented Jun 22, 2024

@llvm/pr-subscribers-mlir-llvm

matthias-springer force-pushed the users/matthias-springer/gpu_memref_lowering branch from 9eff486 to dac6cf3 on June 22, 2024 13:55
matthias-springer force-pushed the users/matthias-springer/bare_ptr_conv branch from f65911a to 2d83858 on June 22, 2024 13:56
Base automatically changed from users/matthias-springer/gpu_memref_lowering to main June 23, 2024 07:51
Before this commit, there used to be a workaround in the `func.func`/`gpu.func` op lowering when the bare-pointer calling convention was enabled. This workaround "patched up" the argument materializations for memref arguments. This can be done directly in the argument materialization functions (as the TODOs in the code base indicate).
matthias-springer force-pushed the users/matthias-springer/bare_ptr_conv branch from 2d83858 to 0ae7616 on June 23, 2024 09:13
Member

zero9178 left a comment


Has anything changed between now and the patch this is reverting that makes this now feasible? The implementation looks to me as though it would have been feasible back then as well

@matthias-springer
Member Author

matthias-springer commented Jun 23, 2024

That's right, this implementation would have been feasible back then. My best guess is that the author was confused about addArgumentMaterialization and did not know that the given ValueRange is always a list of bbargs that belong to a newly added block.

I'm thinking of adding a different MaterializationCallbackFn type for argument materializations to make that clear:

using ArgumentMaterializationCallbackFn = std::function<std::optional<Value>(
    OpBuilder &, Type, BlockArgListType, Location)>;

But that could be a larger change and I'd rather do that in a separate PR.
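
As a rough usage sketch under that hypothetical signature, assuming converter is an LLVMTypeConverter (neither this overload nor the BlockArgListType parameter exists in the current API):

// Hypothetical sketch: the materialization receives the new block's
// arguments directly, so a single input is unambiguously a bare pointer
// coming from a function entry block argument.
converter.addArgumentMaterialization(
    [&](OpBuilder &builder, MemRefType resultType,
        Block::BlockArgListType inputs, Location loc) -> std::optional<Value> {
      if (inputs.size() != 1)
        return std::nullopt;
      // Rebuild the memref descriptor from the bare pointer.
      return MemRefDescriptor::fromStaticShape(builder, loc, converter,
                                               resultType, inputs.front());
    });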

@zero9178
Member

> That's right, this implementation would have been feasible back then. My best guess is that the author was confused about addArgumentMaterialization and did not know that the given ValueRange is always a list of bbargs that belong to a newly added block.
>
> I'm thinking of adding a different MaterializationCallbackFn type for argument materializations to make that clear:
>
> using ArgumentMaterializationCallbackFn = std::function<std::optional<Value>(
>     OpBuilder &, Type, BlockArgListType, Location)>;
>
> But that could be a larger change and I'd rather do that in a separate PR.

That'd be great!

matthias-springer merged commit 9e8ccf6 into main Jun 24, 2024
7 checks passed
matthias-springer deleted the users/matthias-springer/bare_ptr_conv branch June 24, 2024 06:38
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…#96393)

Before this commit, there used to be a workaround in the
`func.func`/`gpu.func` op lowering when the bare-pointer calling
convention is enabled. This workaround "patched up" the argument
materializations for memref arguments. This can be done directly in the
argument materialization functions (as the TODOs in the code base
indicate).

This commit effectively reverts back to the old implementation
(a664c14) and adds additional checks to
make sure that bare pointers are used only for function entry block
arguments.