@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153
153
type.isVarArg ());
154
154
});
155
155
156
+ // Add generic source and target materializations to handle cases where
157
+ // non-LLVM types persist after an LLVM conversion.
158
+ addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
159
+ ValueRange inputs, Location loc) {
160
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
161
+ .getResult (0 );
162
+ });
163
+ addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
164
+ ValueRange inputs, Location loc) {
165
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
166
+ .getResult (0 );
167
+ });
168
+
156
169
// Helper function that checks if the given value range is a bare pointer.
157
170
auto isBarePointer = [](ValueRange values) {
158
171
return values.size () == 1 &&
159
172
isa<LLVM::LLVMPointerType>(values.front ().getType ());
160
173
};
161
174
162
- // Argument materializations convert from the new block argument types
163
- // (multiple SSA values that make up a memref descriptor) back to the
164
- // original block argument type. The dialect conversion framework will then
165
- // insert a target materialization from the original block argument type to
166
- // a legal type.
167
- addArgumentMaterialization ([&](OpBuilder &builder,
168
- UnrankedMemRefType resultType,
169
- ValueRange inputs, Location loc) {
175
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176
+ // must be passed explicitly.
177
+ auto packUnrankedMemRefDesc =
178
+ [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179
+ Location loc, LLVMTypeConverter &converter) -> Value {
170
180
// Note: Bare pointers are not supported for unranked memrefs because a
171
181
// memref descriptor cannot be built just from a bare pointer.
172
- if (TypeRange (inputs) != getUnrankedMemRefDescriptorFields ())
182
+ if (TypeRange (inputs) != converter. getUnrankedMemRefDescriptorFields ())
173
183
return Value ();
174
- Value desc =
175
- UnrankedMemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
184
+ return UnrankedMemRefDescriptor::pack (builder, loc, converter, resultType,
185
+ inputs);
186
+ };
187
+
188
+ // MemRef descriptor elements -> UnrankedMemRefType
189
+ auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190
+ UnrankedMemRefType resultType,
191
+ ValueRange inputs, Location loc) {
176
192
// An argument materialization must return a value of type
177
193
// `resultType`, so insert a cast from the memref descriptor type
178
194
// (!llvm.struct) to the original memref type.
179
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
180
- .getResult (0 );
181
- });
182
- addArgumentMaterialization ([&](OpBuilder &builder, MemRefType resultType,
183
- ValueRange inputs, Location loc) {
184
- Value desc;
185
- if (isBarePointer (inputs)) {
186
- desc = MemRefDescriptor::fromStaticShape (builder, loc, *this , resultType,
187
- inputs[0 ]);
188
- } else if (TypeRange (inputs) ==
189
- getMemRefDescriptorFields (resultType,
190
- /* unpackAggregates=*/ true )) {
191
- desc = MemRefDescriptor::pack (builder, loc, *this , resultType, inputs);
192
- } else {
193
- // The inputs are neither a bare pointer nor an unpacked memref
194
- // descriptor. This materialization function cannot be used.
195
+ Value packed =
196
+ packUnrankedMemRefDesc (builder, resultType, inputs, loc, *this );
197
+ if (!packed)
195
198
return Value ();
196
- }
199
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
200
+ .getResult (0 );
201
+ };
202
+
203
+ // TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204
+ // must be passed explicitly.
205
+ auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206
+ ValueRange inputs, Location loc,
207
+ LLVMTypeConverter &converter) -> Value {
208
+ assert (resultType && " expected non-null result type" );
209
+ if (isBarePointer (inputs))
210
+ return MemRefDescriptor::fromStaticShape (builder, loc, converter,
211
+ resultType, inputs[0 ]);
212
+ if (TypeRange (inputs) ==
213
+ converter.getMemRefDescriptorFields (resultType,
214
+ /* unpackAggregates=*/ true ))
215
+ return MemRefDescriptor::pack (builder, loc, converter, resultType,
216
+ inputs);
217
+ // The inputs are neither a bare pointer nor an unpacked memref descriptor.
218
+ // This materialization function cannot be used.
219
+ return Value ();
220
+ };
221
+
222
+ // MemRef descriptor elements -> MemRefType
223
+ auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224
+ MemRefType resultType,
225
+ ValueRange inputs, Location loc) {
197
226
// An argument materialization must return a value of type `resultType`,
198
227
// so insert a cast from the memref descriptor type (!llvm.struct) to the
199
228
// original memref type.
200
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, desc)
201
- .getResult (0 );
202
- });
203
- // Add generic source and target materializations to handle cases where
204
- // non-LLVM types persist after an LLVM conversion.
205
- addSourceMaterialization ([&](OpBuilder &builder, Type resultType,
206
- ValueRange inputs, Location loc) {
207
- if (inputs.size () != 1 )
229
+ Value packed =
230
+ packRankedMemRefDesc (builder, resultType, inputs, loc, *this );
231
+ if (!packed)
208
232
return Value ();
209
-
210
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
233
+ return builder.create <UnrealizedConversionCastOp>(loc, resultType, packed)
211
234
.getResult (0 );
212
- });
235
+ };
236
+
237
+ // Argument materializations convert from the new block argument types
238
+ // (multiple SSA values that make up a memref descriptor) back to the
239
+ // original block argument type.
240
+ addArgumentMaterialization (unrakedMemRefMaterialization);
241
+ addArgumentMaterialization (rankedMemRefMaterialization);
242
+ addSourceMaterialization (unrakedMemRefMaterialization);
243
+ addSourceMaterialization (rankedMemRefMaterialization);
244
+
245
+ // Bare pointer -> Packed MemRef descriptor
213
246
addTargetMaterialization ([&](OpBuilder &builder, Type resultType,
214
- ValueRange inputs, Location loc) {
215
- if (inputs.size () != 1 )
247
+ ValueRange inputs, Location loc,
248
+ Type originalType) -> Value {
249
+ // The original MemRef type is required to build a MemRef descriptor
250
+ // because the sizes/strides of the MemRef cannot be inferred from just the
251
+ // bare pointer.
252
+ if (!originalType)
216
253
return Value ();
217
-
218
- return builder.create <UnrealizedConversionCastOp>(loc, resultType, inputs)
219
- .getResult (0 );
254
+ if (resultType != convertType (originalType))
255
+ return Value ();
256
+ if (auto memrefType = dyn_cast<MemRefType>(originalType))
257
+ return packRankedMemRefDesc (builder, memrefType, inputs, loc, *this );
258
+ if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
259
+ return packUnrankedMemRefDesc (builder, unrankedMemrefType, inputs, loc,
260
+ *this );
261
+ return Value ();
220
262
});
221
263
222
264
// Integer memory spaces map to themselves.
0 commit comments