@@ -53,8 +53,7 @@ MemRefDescriptor MemRefDescriptor::fromStaticShape(
53
53
54
54
// Extract all strides and offsets and verify they are static.
55
55
auto [strides, offset] = getStridesAndOffset (type);
56
- assert (!ShapedType::isDynamic (offset) &&
57
- " expected static offset" );
56
+ assert (!ShapedType::isDynamic (offset) && " expected static offset" );
58
57
assert (!llvm::any_of (strides, ShapedType::isDynamic) &&
59
58
" expected static strides" );
60
59
@@ -134,27 +133,19 @@ Value MemRefDescriptor::size(OpBuilder &builder, Location loc, Value pos,
134
133
int64_t rank) {
135
134
auto arrayTy = LLVM::LLVMArrayType::get (indexType, rank);
136
135
137
- LLVM::LLVMPointerType indexPtrTy;
138
- LLVM::LLVMPointerType arrayPtrTy;
139
-
140
- if (useOpaquePointers ()) {
141
- arrayPtrTy = indexPtrTy = LLVM::LLVMPointerType::get (builder.getContext ());
142
- } else {
143
- indexPtrTy = LLVM::LLVMPointerType::get (indexType);
144
- arrayPtrTy = LLVM::LLVMPointerType::get (arrayTy);
145
- }
136
+ auto ptrTy = LLVM::LLVMPointerType::get (builder.getContext ());
146
137
147
138
// Copy size values to stack-allocated memory.
148
139
auto one = createIndexAttrConstant (builder, loc, indexType, 1 );
149
140
auto sizes = builder.create <LLVM::ExtractValueOp>(
150
141
loc, value, llvm::ArrayRef<int64_t >({kSizePosInMemRefDescriptor }));
151
- auto sizesPtr = builder.create <LLVM::AllocaOp>(loc, arrayPtrTy , arrayTy, one,
142
+ auto sizesPtr = builder.create <LLVM::AllocaOp>(loc, ptrTy , arrayTy, one,
152
143
/* alignment=*/ 0 );
153
144
builder.create <LLVM::StoreOp>(loc, sizes, sizesPtr);
154
145
155
146
// Load an return size value of interest.
156
- auto resultPtr = builder.create <LLVM::GEPOp>(
157
- loc, indexPtrTy, arrayTy, sizesPtr, ArrayRef<LLVM::GEPArg>{0 , pos});
147
+ auto resultPtr = builder.create <LLVM::GEPOp>(loc, ptrTy, arrayTy, sizesPtr,
148
+ ArrayRef<LLVM::GEPArg>{0 , pos});
158
149
return builder.create <LLVM::LoadOp>(loc, indexType, resultPtr);
159
150
}
160
151
@@ -273,10 +264,6 @@ unsigned MemRefDescriptor::getNumUnpackedValues(MemRefType type) {
273
264
return 3 + 2 * type.getRank ();
274
265
}
275
266
276
- bool MemRefDescriptor::useOpaquePointers () {
277
- return getElementPtrType ().isOpaque ();
278
- }
279
-
280
267
// ===----------------------------------------------------------------------===//
281
268
// MemRefDescriptorView implementation.
282
269
// ===----------------------------------------------------------------------===//
@@ -413,44 +400,20 @@ void UnrankedMemRefDescriptor::computeSizes(
413
400
Value UnrankedMemRefDescriptor::allocatedPtr (
414
401
OpBuilder &builder, Location loc, Value memRefDescPtr,
415
402
LLVM::LLVMPointerType elemPtrType) {
416
-
417
- Value elementPtrPtr;
418
- if (elemPtrType.isOpaque ())
419
- elementPtrPtr = memRefDescPtr;
420
- else
421
- elementPtrPtr = builder.create <LLVM::BitcastOp>(
422
- loc, LLVM::LLVMPointerType::get (elemPtrType), memRefDescPtr);
423
-
424
- return builder.create <LLVM::LoadOp>(loc, elemPtrType, elementPtrPtr);
403
+ return builder.create <LLVM::LoadOp>(loc, elemPtrType, memRefDescPtr);
425
404
}
426
405
427
406
void UnrankedMemRefDescriptor::setAllocatedPtr (
428
407
OpBuilder &builder, Location loc, Value memRefDescPtr,
429
408
LLVM::LLVMPointerType elemPtrType, Value allocatedPtr) {
430
- Value elementPtrPtr;
431
- if (elemPtrType.isOpaque ())
432
- elementPtrPtr = memRefDescPtr;
433
- else
434
- elementPtrPtr = builder.create <LLVM::BitcastOp>(
435
- loc, LLVM::LLVMPointerType::get (elemPtrType), memRefDescPtr);
436
-
437
- builder.create <LLVM::StoreOp>(loc, allocatedPtr, elementPtrPtr);
409
+ builder.create <LLVM::StoreOp>(loc, allocatedPtr, memRefDescPtr);
438
410
}
439
411
440
412
static std::pair<Value, Type>
441
413
castToElemPtrPtr (OpBuilder &builder, Location loc, Value memRefDescPtr,
442
414
LLVM::LLVMPointerType elemPtrType) {
443
- Value elementPtrPtr;
444
- Type elemPtrPtrType;
445
- if (elemPtrType.isOpaque ()) {
446
- elementPtrPtr = memRefDescPtr;
447
- elemPtrPtrType = LLVM::LLVMPointerType::get (builder.getContext ());
448
- } else {
449
- elemPtrPtrType = LLVM::LLVMPointerType::get (elemPtrType);
450
- elementPtrPtr =
451
- builder.create <LLVM::BitcastOp>(loc, elemPtrPtrType, memRefDescPtr);
452
- }
453
- return {elementPtrPtr, elemPtrPtrType};
415
+ auto elemPtrPtrType = LLVM::LLVMPointerType::get (builder.getContext ());
416
+ return {memRefDescPtr, elemPtrPtrType};
454
417
}
455
418
456
419
Value UnrankedMemRefDescriptor::alignedPtr (
@@ -483,16 +446,8 @@ Value UnrankedMemRefDescriptor::offsetBasePtr(
483
446
auto [elementPtrPtr, elemPtrPtrType] =
484
447
castToElemPtrPtr (builder, loc, memRefDescPtr, elemPtrType);
485
448
486
- Value offsetGep =
487
- builder.create <LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
488
- elementPtrPtr, ArrayRef<LLVM::GEPArg>{2 });
489
-
490
- if (!elemPtrType.isOpaque ()) {
491
- offsetGep = builder.create <LLVM::BitcastOp>(
492
- loc, LLVM::LLVMPointerType::get (typeConverter.getIndexType ()),
493
- offsetGep);
494
- }
495
- return offsetGep;
449
+ return builder.create <LLVM::GEPOp>(loc, elemPtrPtrType, elemPtrType,
450
+ elementPtrPtr, ArrayRef<LLVM::GEPArg>{2 });
496
451
}
497
452
498
453
Value UnrankedMemRefDescriptor::offset (OpBuilder &builder, Location loc,
@@ -521,19 +476,8 @@ Value UnrankedMemRefDescriptor::sizeBasePtr(
521
476
Type indexTy = typeConverter.getIndexType ();
522
477
Type structTy = LLVM::LLVMStructType::getLiteral (
523
478
indexTy.getContext (), {elemPtrType, elemPtrType, indexTy, indexTy});
524
- Value structPtr;
525
- if (elemPtrType.isOpaque ()) {
526
- structPtr = memRefDescPtr;
527
- } else {
528
- Type structPtrTy = LLVM::LLVMPointerType::get (structTy);
529
- structPtr =
530
- builder.create <LLVM::BitcastOp>(loc, structPtrTy, memRefDescPtr);
531
- }
532
-
533
- auto resultType = elemPtrType.isOpaque ()
534
- ? LLVM::LLVMPointerType::get (indexTy.getContext ())
535
- : LLVM::LLVMPointerType::get (indexTy);
536
- return builder.create <LLVM::GEPOp>(loc, resultType, structTy, structPtr,
479
+ auto resultType = LLVM::LLVMPointerType::get (builder.getContext ());
480
+ return builder.create <LLVM::GEPOp>(loc, resultType, structTy, memRefDescPtr,
537
481
ArrayRef<LLVM::GEPArg>{0 , 3 });
538
482
}
539
483
0 commit comments