@@ -153,68 +153,106 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153
153
type.isVarArg ());
154
154
});
155
155
156
- // Argument materializations convert from the new block argument types
157
- // (multiple SSA values that make up a memref descriptor) back to the
158
- // original block argument type. The dialect conversion framework will then
159
- // insert a target materialization from the original block argument type to
160
- // a legal type.
161
- addArgumentMaterialization ([&](OpBuilder &builder,
162
- UnrankedMemRefType resultType,
163
- ValueRange inputs, Location loc) {
156
+ // Add generic source and target materializations to handle cases where
157
+ // non-LLVM types persist after an LLVM conversion.
158
+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159
+ ValueRange inputs, Location loc) {
160
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161
+ .getResult (0 );
162
+ });
163
+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164
+ ValueRange inputs, Location loc) {
165
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166
+ .getResult (0 );
167
+ });
168
+
169
+ // Source materializations convert the MemrRef descriptor elements
170
+ // (multiple SSA values that make up a MemrRef descriptor) back to the
171
+ // original MemRef type.
172
+ addSourceMaterialization ([&](OpBuilder &builder,
173
+ UnrankedMemRefType resultType, ValueRange inputs,
174
+ Location loc) {
164
175
if (inputs.size () == 1 ) {
165
176
// Bare pointers are not supported for unranked memrefs because a
166
177
// memref descriptor cannot be built just from a bare pointer.
167
178
return Value ();
168
179
}
169
180
Value desc =
170
181
UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171
- // An argument materialization must return a value of type
182
+ // A source materialization must return a value of type
172
183
// `resultType`, so insert a cast from the memref descriptor type
173
184
// (!llvm.struct) to the original memref type.
174
185
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
175
186
.getResult (0 );
176
187
});
177
- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178
- ValueRange inputs, Location loc) {
188
+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
189
+ ValueRange inputs, Location loc) {
190
+ if (inputs.size () == 1 &&
191
+ isa<LLVM::LLVMStructType>(inputs.front ().getType ()))
192
+ return Value ();
193
+
179
194
Value desc;
180
- if (inputs.size () == 1 ) {
195
+ if (inputs.size () == 1 &&
196
+ isa<LLVM::LLVMPointerType>(inputs.front ().getType ())) {
181
197
// This is a bare pointer. We allow bare pointers only for function entry
182
198
// blocks.
183
199
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
184
200
if (!barePtr)
185
201
return Value ();
186
- Block *block = barePtr.getOwner ();
187
- if (!block->isEntryBlock () ||
188
- !isa<FunctionOpInterface>(block->getParentOp ()))
189
- return Value ();
202
+ // Block *block = barePtr.getOwner();
203
+ // if (!block->isEntryBlock() ||
204
+ // !isa<FunctionOpInterface>(block->getParentOp()))
205
+ // return Value();
190
206
desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
191
207
inputs[0 ]);
192
208
} else {
193
209
desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
194
210
}
195
- // An argument materialization must return a value of type `resultType`,
211
+ // A source materialization must return a value of type `resultType`,
196
212
// so insert a cast from the memref descriptor type (!llvm.struct) to the
197
213
// original memref type.
198
214
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
199
215
.getResult (0 );
200
216
});
201
- // Add generic source and target materializations to handle cases where
202
- // non-LLVM types persist after an LLVM conversion.
203
- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
204
- ValueRange inputs, Location loc) {
205
- if (inputs.size () != 1 )
206
- return Value ();
217
+ addTargetMaterialization ([&](OpBuilder &builder, LLVM::LLVMStructType resultType,
218
+ ValueRange inputs, Location loc,
219
+ Type originalType) -> Value {
220
+ if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
221
+ if (inputs.size () == 1 ) {
222
+ Value input = inputs.front ();
223
+ // if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
224
+ // if (castOp.getInputs().size() == 1 &&
225
+ // isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
226
+ // input = castOp.getInputs()[0];
227
+ // }
228
+ // }
229
+ if (!isa<LLVM::LLVMPointerType>(input.getType ()))
230
+ return Value ();
231
+ BlockArgument barePtr = dyn_cast<BlockArgument>(input);
232
+ if (!barePtr)
233
+ return Value ();
234
+ // Block *block = barePtr.getOwner();
235
+ // if (!block->isEntryBlock() ||
236
+ // !isa<FunctionOpInterface>(block->getParentOp()))
237
+ // return Value();
238
+ // Bare ptr
239
+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
240
+ memrefType, input);
241
+ }
242
+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
243
+ }
207
244
208
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
209
- .getResult (0 );
210
- });
211
- addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
212
- ValueRange inputs, Location loc) {
213
- if (inputs.size () != 1 )
214
- return Value ();
245
+ if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
246
+ if (inputs.size () == 1 ) {
247
+ // Bare pointers are not supported for unranked memrefs because a
248
+ // memref descriptor cannot be built just from a bare pointer.
249
+ return Value ();
250
+ }
251
+ return UnrankedMemRefDescriptor::pack (builder, loc, *this , memrefType,
252
+ inputs);
253
+ }
215
254
216
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
217
- .getResult (0 );
255
+ return Value ();
218
256
});
219
257
220
258
// Integer memory spaces map to themselves.
0 commit comments