@@ -153,31 +153,42 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153
153
type.isVarArg ());
154
154
});
155
155
156
- // Argument materializations convert from the new block argument types
157
- // (multiple SSA values that make up a memref descriptor) back to the
158
- // original block argument type. The dialect conversion framework will then
159
- // insert a target materialization from the original block argument type to
160
- // a legal type.
161
- addArgumentMaterialization ([&](OpBuilder &builder,
162
- UnrankedMemRefType resultType,
163
- ValueRange inputs, Location loc) {
156
+ // Add generic source and target materializations to handle cases where
157
+ // non-LLVM types persist after an LLVM conversion.
158
+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159
+ ValueRange inputs, Location loc) {
160
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161
+ .getResult (0 );
162
+ });
163
+
164
+ // Source materializations convert the MemrRef descriptor elements
165
+ // (multiple SSA values that make up a MemrRef descriptor) back to the
166
+ // original MemRef type.
167
+ addSourceMaterialization ([&](OpBuilder &builder,
168
+ UnrankedMemRefType resultType, ValueRange inputs,
169
+ Location loc) {
164
170
if (inputs.size () == 1 ) {
165
171
// Bare pointers are not supported for unranked memrefs because a
166
172
// memref descriptor cannot be built just from a bare pointer.
167
173
return Value ();
168
174
}
169
175
Value desc =
170
176
UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
171
- // An argument materialization must return a value of type
177
+ // A source materialization must return a value of type
172
178
// `resultType`, so insert a cast from the memref descriptor type
173
179
// (!llvm.struct) to the original memref type.
174
180
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
175
181
.getResult (0 );
176
182
});
177
- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
178
- ValueRange inputs, Location loc) {
183
+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
184
+ ValueRange inputs, Location loc) {
185
+ if (inputs.size () == 1 &&
186
+ isa<LLVM::LLVMStructType>(inputs.front ().getType ()))
187
+ return Value ();
188
+
179
189
Value desc;
180
- if (inputs.size () == 1 ) {
190
+ if (inputs.size () == 1 &&
191
+ isa<LLVM::LLVMPointerType>(inputs.front ().getType ())) {
181
192
// This is a bare pointer. We allow bare pointers only for function entry
182
193
// blocks.
183
194
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front ());
@@ -192,15 +203,13 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
192
203
} else {
193
204
desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
194
205
}
195
- // An argument materialization must return a value of type `resultType`,
206
+ // A source materialization must return a value of type `resultType`,
196
207
// so insert a cast from the memref descriptor type (!llvm.struct) to the
197
208
// original memref type.
198
209
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
199
210
.getResult (0 );
200
211
});
201
- // Add generic source and target materializations to handle cases where
202
- // non-LLVM types persist after an LLVM conversion.
203
- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
212
+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
204
213
ValueRange inputs, Location loc) {
205
214
if (inputs.size () != 1 )
206
215
return Value ();
@@ -209,12 +218,50 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
209
218
.getResult (0 );
210
219
});
211
220
addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
212
- ValueRange inputs, Location loc) {
213
- if (inputs.size () != 1 )
221
+ ValueRange inputs, Location loc,
222
+ Type originalType) -> Value {
223
+ llvm::errs () << " TARGET MAT: -> " << resultType << " \n " ;
224
+ if (!originalType) {
225
+ llvm::errs () << " -- no orig\n " ;
214
226
return Value ();
227
+ }
228
+ if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
229
+ assert (isa<LLVM::LLVMStructType>(resultType) && " expected struct type" );
230
+ if (inputs.size () == 1 ) {
231
+ Value input = inputs.front ();
232
+ if (auto castOp = input.getDefiningOp <UnrealizedConversionCastOp>()) {
233
+ if (castOp.getInputs ().size () == 1 &&
234
+ isa<LLVM::LLVMPointerType>(castOp.getInputs ()[0 ].getType ())) {
235
+ input = castOp.getInputs ()[0 ];
236
+ }
237
+ }
238
+ if (!isa<LLVM::LLVMPointerType>(input.getType ()))
239
+ return Value ();
240
+ BlockArgument barePtr = dyn_cast<BlockArgument>(input);
241
+ if (!barePtr)
242
+ return Value ();
243
+ Block *block = barePtr.getOwner ();
244
+ if (!block->isEntryBlock () ||
245
+ !isa<FunctionOpInterface>(block->getParentOp ()))
246
+ return Value ();
247
+ // Bare ptr
248
+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
249
+ memrefType, input);
250
+ }
251
+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
252
+ }
253
+ if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
254
+ assert (isa<LLVM::LLVMStructType>(resultType) && " expected struct type" );
255
+ if (inputs.size () == 1 ) {
256
+ // Bare pointers are not supported for unranked memrefs because a
257
+ // memref descriptor cannot be built just from a bare pointer.
258
+ return Value ();
259
+ }
260
+ return UnrankedMemRefDescriptor::pack (builder, loc, *this , memrefType,
261
+ inputs);
262
+ }
215
263
216
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
217
- .getResult (0 );
264
+ return Value ();
218
265
});
219
266
220
267
// Integer memory spaces map to themselves.
0 commit comments