@@ -153,20 +153,31 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153
153
type.isVarArg ());
154
154
});
155
155
156
+ // Add generic source and target materializations to handle cases where
157
+ // non-LLVM types persist after an LLVM conversion.
158
+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159
+ ValueRange inputs, Location loc) {
160
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161
+ .getResult (0 );
162
+ });
163
+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164
+ ValueRange inputs, Location loc) {
165
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166
+ .getResult (0 );
167
+ });
168
+
156
169
// Helper function that checks if the given value range is a bare pointer.
157
170
auto isBarePointer = [](ValueRange values) {
158
171
return values.size () == 1 &&
159
172
isa<LLVM::LLVMPointerType>(values.front ().getType ());
160
173
};
161
174
162
- // Argument materializations convert from the new block argument types
163
- // (multiple SSA values that make up a memref descriptor) back to the
164
- // original block argument type. The dialect conversion framework will then
165
- // insert a target materialization from the original block argument type to
166
- // a legal type.
167
- addArgumentMaterialization ([&](OpBuilder &builder,
168
- UnrankedMemRefType resultType,
169
- ValueRange inputs, Location loc) {
175
+ // Source materializations convert the MemrRef descriptor elements
176
+ // (multiple SSA values that make up a MemrRef descriptor) back to the
177
+ // original MemRef type.
178
+ addSourceMaterialization ([&](OpBuilder &builder,
179
+ UnrankedMemRefType resultType, ValueRange inputs,
180
+ Location loc) {
170
181
// Note: Bare pointers are not supported for unranked memrefs because a
171
182
// memref descriptor cannot be built just from a bare pointer.
172
183
if (TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
@@ -179,8 +190,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
179
190
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
180
191
.getResult (0 );
181
192
});
182
- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
183
- ValueRange inputs, Location loc) {
193
+ addSourceMaterialization ([&](OpBuilder &builder, MemRefType resultType,
194
+ ValueRange inputs, Location loc) {
184
195
Value desc;
185
196
if (isBarePointer (inputs)) {
186
197
desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
@@ -200,23 +211,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
200
211
return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
201
212
.getResult (0 );
202
213
});
203
- // Add generic source and target materializations to handle cases where
204
- // non-LLVM types persist after an LLVM conversion.
205
- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
206
- ValueRange inputs, Location loc) {
207
- if (inputs.size () != 1 )
208
- return Value ();
214
+ addTargetMaterialization ([&](OpBuilder &builder,
215
+ LLVM::LLVMStructType resultType,
216
+ ValueRange inputs, Location loc,
217
+ Type originalType) -> Value {
218
+ if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
219
+ if (isBarePointer (inputs)) {
220
+ return MemRefDescriptor::fromStaticShape (builder, loc, *this ,
221
+ memrefType, inputs[0 ]);
222
+ } else if (TypeRange (inputs) ==
223
+ getMemRefDescriptorFields (memrefType,
224
+ /* unpackAggregates=*/ true )) {
225
+ return MemRefDescriptor::pack (builder, loc, *this , memrefType, inputs);
226
+ }
227
+ }
209
228
210
- return builder. create <UnrealizedConversionCastOp>(loc, resultType, inputs)
211
- . getResult ( 0 );
212
- });
213
- addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
214
- ValueRange inputs, Location loc) {
215
- if (inputs. size () != 1 )
216
- return Value ();
229
+ if ( auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
230
+ // Note: Bare pointers are not supported for unranked memrefs because a
231
+ // memref descriptor cannot be built just from a bare pointer.
232
+ if ( TypeRange (inputs) == getUnrankedMemRefDescriptorFields ())
233
+ return UnrankedMemRefDescriptor::pack (builder, loc, * this , memrefType,
234
+ inputs);
235
+ }
217
236
218
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
219
- .getResult (0 );
237
+ return Value ();
220
238
});
221
239
222
240
// Integer memory spaces map to themselves.
0 commit comments