Skip to content

Commit dac6cf3

Browse files
[mlir][GPUToNVVM] Fix memref function args/results
The `gpu.func` op lowering accounts for memref arguments/results (both "normal" and bare-pointer supported), but the `gpu.return` op lowering did not. This resulting in invalid result IR that did not verify. This commit uses the same lowering strategy as for `func.return` in the `gpu.return` lowering.
1 parent 34d44eb commit dac6cf3

File tree

3 files changed

+82
-8
lines changed

3 files changed

+82
-8
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,62 @@ LogicalResult GPUDynamicSharedMemoryOpLowering::matchAndRewrite(
684684
return success();
685685
}
686686

687+
LogicalResult GPUReturnOpLowering::matchAndRewrite(
688+
gpu::ReturnOp op, OpAdaptor adaptor,
689+
ConversionPatternRewriter &rewriter) const {
690+
Location loc = op.getLoc();
691+
unsigned numArguments = op.getNumOperands();
692+
SmallVector<Value, 4> updatedOperands;
693+
694+
bool useBarePtrCallConv = getTypeConverter()->getOptions().useBarePtrCallConv;
695+
if (useBarePtrCallConv) {
696+
// For the bare-ptr calling convention, extract the aligned pointer to
697+
// be returned from the memref descriptor.
698+
for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
699+
Type oldTy = std::get<0>(it).getType();
700+
Value newOperand = std::get<1>(it);
701+
if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
702+
cast<BaseMemRefType>(oldTy))) {
703+
MemRefDescriptor memrefDesc(newOperand);
704+
newOperand = memrefDesc.allocatedPtr(rewriter, loc);
705+
} else if (isa<UnrankedMemRefType>(oldTy)) {
706+
// Unranked memref is not supported in the bare pointer calling
707+
// convention.
708+
return failure();
709+
}
710+
updatedOperands.push_back(newOperand);
711+
}
712+
} else {
713+
updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
714+
(void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
715+
updatedOperands,
716+
/*toDynamic=*/true);
717+
}
718+
719+
// If ReturnOp has 0 or 1 operand, create it and return immediately.
720+
if (numArguments <= 1) {
721+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
722+
op, TypeRange(), updatedOperands, op->getAttrs());
723+
return success();
724+
}
725+
726+
// Otherwise, we need to pack the arguments into an LLVM struct type before
727+
// returning.
728+
auto packedType = getTypeConverter()->packFunctionResults(
729+
op.getOperandTypes(), useBarePtrCallConv);
730+
if (!packedType) {
731+
return rewriter.notifyMatchFailure(op, "could not convert result types");
732+
}
733+
734+
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
735+
for (auto [idx, operand] : llvm::enumerate(updatedOperands)) {
736+
packed = rewriter.create<LLVM::InsertValueOp>(loc, packed, operand, idx);
737+
}
738+
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), packed,
739+
op->getAttrs());
740+
return success();
741+
}
742+
687743
void mlir::populateGpuMemorySpaceAttributeConversions(
688744
TypeConverter &typeConverter, const MemorySpaceMapping &mapping) {
689745
typeConverter.addTypeAttributeConversion(

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,7 @@ struct GPUReturnOpLowering : public ConvertOpToLLVMPattern<gpu::ReturnOp> {
112112

113113
LogicalResult
114114
matchAndRewrite(gpu::ReturnOp op, OpAdaptor adaptor,
115-
ConversionPatternRewriter &rewriter) const override {
116-
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, adaptor.getOperands());
117-
return success();
118-
}
115+
ConversionPatternRewriter &rewriter) const override;
119116
};
120117

121118
namespace impl {

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1' -split-input-file | FileCheck %s
2+
// RUN: mlir-opt %s -convert-gpu-to-nvvm='has-redux=1 use-bare-ptr-memref-call-conv=1' -split-input-file | FileCheck %s --check-prefix=CHECK-BARE
23
// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
34

45
gpu.module @test_module_0 {
@@ -670,7 +671,7 @@ gpu.module @test_module_32 {
670671
}
671672
}
672673

673-
gpu.module @gpumodule {
674+
gpu.module @test_module_33 {
674675
// CHECK-LABEL: func @kernel_with_block_size()
675676
// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 128, 1, 1>, nvvm.kernel, nvvm.maxntid = array<i32: 128, 1, 1>}
676677
gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 128, 1, 1>} {
@@ -679,6 +680,28 @@ gpu.module @gpumodule {
679680
}
680681

681682

683+
gpu.module @test_module_34 {
684+
// CHECK-LABEL: llvm.func @memref_signature(
685+
// CHECK-SAME: %{{.*}}: !llvm.ptr, %{{.*}}: !llvm.ptr, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: i64, %{{.*}}: f32) -> !llvm.struct<(struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, f32)>
686+
// CHECK: llvm.mlir.undef
687+
// CHECK: llvm.insertvalue
688+
// CHECK: llvm.insertvalue
689+
// CHECK: llvm.insertvalue
690+
// CHECK: llvm.insertvalue
691+
// CHECK: llvm.insertvalue
692+
// CHECK: llvm.mlir.undef
693+
// CHECK: llvm.insertvalue
694+
// CHECK: llvm.insertvalue
695+
// CHECK: llvm.return
696+
697+
// CHECK-BARE-LABEL: llvm.func @memref_signature(
698+
// CHECK-BARE-SAME: %{{.*}}: !llvm.ptr, %{{.*}}: f32) -> !llvm.struct<(ptr, f32)>
699+
gpu.func @memref_signature(%m: memref<2xf32>, %f: f32) -> (memref<2xf32>, f32) {
700+
gpu.return %m, %f : memref<2xf32>, f32
701+
}
702+
}
703+
704+
682705
module attributes {transform.with_named_sequence} {
683706
transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
684707
%gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
@@ -701,9 +724,7 @@ module attributes {transform.with_named_sequence} {
701724
} with type_converter {
702725
transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
703726
{index_bitwidth = 64,
704-
use_bare_ptr = true,
705-
use_bare_ptr_memref_call_conv = true,
706-
use_opaque_pointers = true}
727+
use_bare_ptr_call_conv = false}
707728
} {
708729
legal_dialects = ["llvm", "memref", "nvvm", "test"],
709730
legal_ops = ["func.func", "gpu.module", "gpu.module_end", "gpu.yield"],

0 commit comments

Comments
 (0)