Skip to content

[MLIR][MemRefToLLVM] Remove typed pointer support #70909

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits on Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -745,10 +745,7 @@ def FinalizeMemRefToLLVMConversionPass :
"bool",
/*default=*/"false",
"Use generic allocation and deallocation functions instead of the "
"classic 'malloc', 'aligned_alloc' and 'free' functions">,
Option<"useOpaquePointers", "use-opaque-pointers", "bool",
/*default=*/"true", "Generate LLVM IR using opaque pointers "
"instead of typed pointers">
"classic 'malloc', 'aligned_alloc' and 'free' functions">
];
}

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/LLVMIR/FunctionCallUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ LLVM::LLVMFuncOp lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
Type indexType,
bool opaquePointers);
LLVM::LLVMFuncOp lookupOrCreateGenericFreeFn(ModuleOp moduleOp,
bool opaquePointers);
bool opaquePointers = true);
LLVM::LLVMFuncOp lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
Type unrankedDescriptorType);

Expand Down
129 changes: 32 additions & 97 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,9 @@ LLVM::LLVMFuncOp getFreeFn(const LLVMTypeConverter *typeConverter,
bool useGenericFn = typeConverter->getOptions().useGenericFunctions;

if (useGenericFn)
return LLVM::lookupOrCreateGenericFreeFn(
module, typeConverter->useOpaquePointers());
return LLVM::lookupOrCreateGenericFreeFn(module);

return LLVM::lookupOrCreateFreeFn(module, typeConverter->useOpaquePointers());
return LLVM::lookupOrCreateFreeFn(module);
}

struct AllocOpLowering : public AllocLikeOpLLVMLowering {
Expand Down Expand Up @@ -108,7 +107,7 @@ struct AllocaOpLowering : public AllocLikeOpLLVMLowering {
unsigned addrSpace =
*getTypeConverter()->getMemRefAddressSpace(allocaOp.getType());
auto elementPtrType =
getTypeConverter()->getPointerType(elementType, addrSpace);
LLVM::LLVMPointerType::get(rewriter.getContext(), addrSpace);

auto allocatedElementPtr =
rewriter.create<LLVM::AllocaOp>(loc, elementPtrType, elementType, size,
Expand Down Expand Up @@ -232,10 +231,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
Value allocatedPtr;
if (auto unrankedTy =
llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
Type elementType = unrankedTy.getElementType();
Type llvmElementTy = getTypeConverter()->convertType(elementType);
LLVM::LLVMPointerType elementPtrTy = getTypeConverter()->getPointerType(
llvmElementTy, unrankedTy.getMemorySpaceAsInt());
auto elementPtrTy = LLVM::LLVMPointerType::get(
rewriter.getContext(), unrankedTy.getMemorySpaceAsInt());
allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
rewriter, op.getLoc(),
UnrankedMemRefDescriptor(adaptor.getMemref())
Expand All @@ -245,10 +242,6 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
allocatedPtr = MemRefDescriptor(adaptor.getMemref())
.allocatedPtr(rewriter, op.getLoc());
}
if (!getTypeConverter()->useOpaquePointers())
allocatedPtr = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), getVoidPtrType(), allocatedPtr);

rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
return success();
}
Expand Down Expand Up @@ -306,19 +299,12 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
Value underlyingRankedDesc = unrankedDesc.memRefDescPtr(rewriter, loc);

Type elementType = typeConverter->convertType(scalarMemRefType);
Value scalarMemRefDescPtr;
if (getTypeConverter()->useOpaquePointers())
scalarMemRefDescPtr = underlyingRankedDesc;
else
scalarMemRefDescPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(elementType, addressSpace),
underlyingRankedDesc);

// Get pointer to offset field of memref<element_type> descriptor.
Type indexPtrTy = getTypeConverter()->getPointerType(
getTypeConverter()->getIndexType(), addressSpace);
auto indexPtrTy =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);
Value offsetPtr = rewriter.create<LLVM::GEPOp>(
loc, indexPtrTy, elementType, scalarMemRefDescPtr,
loc, indexPtrTy, elementType, underlyingRankedDesc,
ArrayRef<LLVM::GEPArg>{0, 2});

// The size value that we have to extract can be obtained using GEPop with
Expand Down Expand Up @@ -569,18 +555,14 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
unsigned memSpace = *maybeAddressSpace;

Type arrayTy = convertGlobalMemrefTypeToLLVM(type, *getTypeConverter());
Type resTy = getTypeConverter()->getPointerType(arrayTy, memSpace);
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), memSpace);
auto addressOf =
rewriter.create<LLVM::AddressOfOp>(loc, resTy, getGlobalOp.getName());
rewriter.create<LLVM::AddressOfOp>(loc, ptrTy, getGlobalOp.getName());

// Get the address of the first element in the array by creating a GEP with
// the address of the GV as the base, and (rank + 1) number of 0 indices.
Type elementType = typeConverter->convertType(type.getElementType());
Type elementPtrType =
getTypeConverter()->getPointerType(elementType, memSpace);

auto gep = rewriter.create<LLVM::GEPOp>(
loc, elementPtrType, arrayTy, addressOf,
loc, ptrTy, arrayTy, addressOf,
SmallVector<LLVM::GEPArg>(type.getRank() + 1, 0));

// We do not expect the memref obtained using `memref.get_global` to be
Expand All @@ -590,7 +572,7 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
Value deadBeefConst =
createIndexAttrConstant(rewriter, op->getLoc(), intPtrType, 0xdeadbeef);
auto deadBeefPtr =
rewriter.create<LLVM::IntToPtrOp>(loc, elementPtrType, deadBeefConst);
rewriter.create<LLVM::IntToPtrOp>(loc, ptrTy, deadBeefConst);

// Both allocated and aligned pointers are same. We could potentially stash
// a nullptr for the allocated pointer since we do not expect any dealloc.
Expand Down Expand Up @@ -734,13 +716,6 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
auto ptr = getTypeConverter()->promoteOneMemRefDescriptor(
loc, adaptor.getSource(), rewriter);

// voidptr = BitCastOp srcType* to void*
Value voidPtr;
if (getTypeConverter()->useOpaquePointers())
voidPtr = ptr;
else
voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);

// rank = ConstantOp srcRank
auto rankVal = rewriter.create<LLVM::ConstantOp>(
loc, getIndexType(), rewriter.getIndexAttr(rank));
Expand All @@ -749,8 +724,8 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
UnrankedMemRefDescriptor::undef(rewriter, loc, targetStructType);
// d1 = InsertValueOp undef, rank, 0
memRefDesc.setRank(rewriter, loc, rankVal);
// d2 = InsertValueOp d1, voidptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, voidPtr);
// d2 = InsertValueOp d1, ptr, 1
memRefDesc.setMemRefDescPtr(rewriter, loc, ptr);
rewriter.replaceOp(memRefCastOp, (Value)memRefDesc);

} else if (isa<UnrankedMemRefType>(srcType) && isa<MemRefType>(dstType)) {
Expand All @@ -760,17 +735,9 @@ struct MemRefCastOpLowering : public ConvertOpToLLVMPattern<memref::CastOp> {
UnrankedMemRefDescriptor memRefDesc(adaptor.getSource());
// ptr = ExtractValueOp src, 1
auto ptr = memRefDesc.memRefDescPtr(rewriter, loc);
// castPtr = BitCastOp i8* to structTy*
Value castPtr;
if (getTypeConverter()->useOpaquePointers())
castPtr = ptr;
else
castPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(targetStructType), ptr);

// struct = LoadOp castPtr
auto loadOp =
rewriter.create<LLVM::LoadOp>(loc, targetStructType, castPtr);
// struct = LoadOp ptr
auto loadOp = rewriter.create<LLVM::LoadOp>(loc, targetStructType, ptr);
rewriter.replaceOp(memRefCastOp, loadOp.getResult());
} else {
llvm_unreachable("Unsupported unranked memref to unranked memref cast");
Expand Down Expand Up @@ -841,17 +808,10 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto ptr =
typeConverter->promoteOneMemRefDescriptor(loc, ranked, rewriter);

Value voidPtr;
if (getTypeConverter()->useOpaquePointers())
voidPtr = ptr;
else
voidPtr = rewriter.create<LLVM::BitcastOp>(loc, getVoidPtrType(), ptr);

auto unrankedType =
UnrankedMemRefType::get(type.getElementType(), type.getMemorySpace());
return UnrankedMemRefDescriptor::pack(rewriter, loc, *typeConverter,
unrankedType,
ValueRange{rank, voidPtr});
return UnrankedMemRefDescriptor::pack(
rewriter, loc, *typeConverter, unrankedType, ValueRange{rank, ptr});
};

// Save stack position before promoting descriptors
Expand All @@ -871,7 +831,7 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
auto one = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
rewriter.getIndexAttr(1));
auto promote = [&](Value desc) {
Type ptrType = getTypeConverter()->getPointerType(desc.getType());
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext());
auto allocated =
rewriter.create<LLVM::AllocaOp>(loc, ptrType, desc.getType(), one);
rewriter.create<LLVM::StoreOp>(loc, desc, allocated);
Expand Down Expand Up @@ -983,12 +943,10 @@ struct MemorySpaceCastOpLowering
result.setMemRefDescPtr(rewriter, loc, resultUnderlyingDesc);

// Copy pointers, performing address space casts.
Type llvmElementType =
typeConverter->convertType(sourceType.getElementType());
LLVM::LLVMPointerType sourceElemPtrType =
getTypeConverter()->getPointerType(llvmElementType, sourceAddrSpace);
auto sourceElemPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
auto resultElemPtrType =
getTypeConverter()->getPointerType(llvmElementType, resultAddrSpace);
LLVM::LLVMPointerType::get(rewriter.getContext(), resultAddrSpace);

Value allocatedPtr = sourceDesc.allocatedPtr(
rewriter, loc, sourceUnderlyingDesc, sourceElemPtrType);
Expand Down Expand Up @@ -1053,10 +1011,8 @@ static void extractPointersAndOffset(Location loc,
// These will all cause assert()s on unconvertible types.
unsigned memorySpace = *typeConverter.getMemRefAddressSpace(
cast<UnrankedMemRefType>(operandType));
Type elementType = cast<UnrankedMemRefType>(operandType).getElementType();
Type llvmElementType = typeConverter.convertType(elementType);
LLVM::LLVMPointerType elementPtrType =
typeConverter.getPointerType(llvmElementType, memorySpace);
auto elementPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), memorySpace);

// Extract pointer to the underlying ranked memref descriptor and cast it to
// ElemType**.
Expand Down Expand Up @@ -1254,7 +1210,6 @@ struct MemRefReshapeOpLowering
auto targetType = cast<UnrankedMemRefType>(reshapeOp.getResult().getType());
unsigned addressSpace =
*getTypeConverter()->getMemRefAddressSpace(targetType);
Type elementType = targetType.getElementType();

// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on stack.
Expand All @@ -1276,9 +1231,8 @@ struct MemRefReshapeOpLowering
&allocatedPtr, &alignedPtr, &offset);

// Set pointers and offset.
Type llvmElementType = typeConverter->convertType(elementType);
LLVM::LLVMPointerType elementPtrType =
getTypeConverter()->getPointerType(llvmElementType, addressSpace);
auto elementPtrType =
LLVM::LLVMPointerType::get(rewriter.getContext(), addressSpace);

UnrankedMemRefDescriptor::setAllocatedPtr(rewriter, loc, underlyingDescPtr,
elementPtrType, allocatedPtr);
Expand Down Expand Up @@ -1328,7 +1282,7 @@ struct MemRefReshapeOpLowering
rewriter.setInsertionPointToStart(bodyBlock);

// Copy size from shape to descriptor.
Type llvmIndexPtrType = getTypeConverter()->getPointerType(indexType);
auto llvmIndexPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value sizeLoadGep = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexPtrType,
typeConverter->convertType(shapeMemRefType.getElementType()),
Expand Down Expand Up @@ -1430,9 +1384,9 @@ class TransposeOpLowering : public ConvertOpToLLVMPattern<memref::TransposeOp> {
targetMemRef.setOffset(rewriter, loc, viewMemRef.offset(rewriter, loc));

// Iterate over the dimensions and apply size/stride permutation:
// When enumerating the results of the permutation map, the enumeration index
// is the index into the target dimensions and the DimExpr points to the
// dimension of the source memref.
// When enumerating the results of the permutation map, the enumeration
// index is the index into the target dimensions and the DimExpr points to
// the dimension of the source memref.
for (const auto &en :
llvm::enumerate(transposeOp.getPermutation().getResults())) {
int targetPos = en.index();
Expand Down Expand Up @@ -1523,17 +1477,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
// Field 1: Copy the allocated pointer, used for malloc/free.
Value allocatedPtr = sourceMemRef.allocatedPtr(rewriter, loc);
auto srcMemRefType = cast<MemRefType>(viewOp.getSource().getType());
unsigned sourceMemorySpace =
*getTypeConverter()->getMemRefAddressSpace(srcMemRefType);
Value bitcastPtr;
if (getTypeConverter()->useOpaquePointers())
bitcastPtr = allocatedPtr;
else
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
allocatedPtr);

targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr);
targetMemRef.setAllocatedPtr(rewriter, loc, allocatedPtr);

// Field 2: Copy the actual aligned pointer to payload.
Value alignedPtr = sourceMemRef.alignedPtr(rewriter, loc);
Expand All @@ -1542,15 +1486,7 @@ struct ViewOpLowering : public ConvertOpToLLVMPattern<memref::ViewOp> {
typeConverter->convertType(srcMemRefType.getElementType()), alignedPtr,
adaptor.getByteShift());

if (getTypeConverter()->useOpaquePointers()) {
bitcastPtr = alignedPtr;
} else {
bitcastPtr = rewriter.create<LLVM::BitcastOp>(
loc, LLVM::LLVMPointerType::get(targetElementTy, sourceMemorySpace),
alignedPtr);
}

targetMemRef.setAlignedPtr(rewriter, loc, bitcastPtr);
targetMemRef.setAlignedPtr(rewriter, loc, alignedPtr);

Type indexType = getIndexType();
// Field 3: The offset in the resulting type must be 0. This is
Expand Down Expand Up @@ -1766,7 +1702,6 @@ struct FinalizeMemRefToLLVMConversionPass
: LowerToLLVMOptions::AllocLowering::Malloc);

options.useGenericFunctions = useGenericFunctions;
options.useOpaquePointers = useOpaquePointers;

if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s | FileCheck %s
// RUN: mlir-opt -finalize-memref-to-llvm %s | FileCheck %s

// CHECK-LABEL: @empty
func.func @empty() {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm='use-opaque-pointers=1' %s | FileCheck %s
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm='use-aligned-alloc=1 use-opaque-pointers=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s | FileCheck --check-prefix=CHECK32 %s
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm %s | FileCheck %s
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm='use-aligned-alloc=1' %s | FileCheck %s --check-prefix=ALIGNED-ALLOC
// RUN: mlir-opt -split-input-file -finalize-memref-to-llvm='index-bitwidth=32' %s | FileCheck --check-prefix=CHECK32 %s

// CHECK-LABEL: func @mixed_alloc(
// CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' -split-input-file %s | FileCheck %s
// RUN: mlir-opt -finalize-memref-to-llvm -split-input-file %s | FileCheck %s

// CHECK-LABEL: func @zero_d_alloc()
func.func @zero_d_alloc() -> memref<f32> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm='use-opaque-pointers=1' -lower-affine -convert-arith-to-llvm -cse %s -split-input-file | FileCheck %s
// RUN: mlir-opt -expand-strided-metadata -finalize-memref-to-llvm -lower-affine -convert-arith-to-llvm -cse %s -split-input-file | FileCheck %s
//
// This test demonstrates a full "memref to llvm" pipeline where
// we first expand some of the memref operations (using affine,
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/MemRefToLLVM/generic-functions.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt -pass-pipeline="builtin.module(finalize-memref-to-llvm{use-generic-functions=1 use-opaque-pointers=1})" -split-input-file %s \
// RUN: mlir-opt -pass-pipeline="builtin.module(finalize-memref-to-llvm{use-generic-functions=1})" -split-input-file %s \
// RUN: | FileCheck %s --check-prefix="CHECK-NOTALIGNED"

// RUN: mlir-opt -pass-pipeline="builtin.module(finalize-memref-to-llvm{use-generic-functions=1 use-aligned-alloc=1 use-opaque-pointers=1})" -split-input-file %s \
// RUN: mlir-opt -pass-pipeline="builtin.module(finalize-memref-to-llvm{use-generic-functions=1 use-aligned-alloc=1})" -split-input-file %s \
// RUN: | FileCheck %s --check-prefix="CHECK-ALIGNED"

// CHECK-LABEL: func @alloc()
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt -finalize-memref-to-llvm='use-opaque-pointers=1' %s -split-input-file | FileCheck %s
// RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32 use-opaque-pointers=1' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
// RUN: mlir-opt -finalize-memref-to-llvm %s -split-input-file | FileCheck %s
// RUN: mlir-opt -finalize-memref-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s

// Same below, but using the `ConvertToLLVMPatternInterface` entry point
// and the generic `convert-to-llvm` pass. This produces slightly different IR
Expand Down
Loading