Skip to content

Commit e3946a5

Browse files
[mlir][Transforms] Support 1:N mappings in ConversionValueMapping
1 parent 7025a8c commit e3946a5

File tree

12 files changed

+335
-285
lines changed

12 files changed

+335
-285
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,31 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
156169
// Helper function that checks if the given value range is a bare pointer.
157170
auto isBarePointer = [](ValueRange values) {
158171
return values.size() == 1 &&
159172
isa<LLVM::LLVMPointerType>(values.front().getType());
160173
};
161174

162-
// Argument materializations convert from the new block argument types
163-
// (multiple SSA values that make up a memref descriptor) back to the
164-
// original block argument type. The dialect conversion framework will then
165-
// insert a target materialization from the original block argument type to
166-
// a legal type.
167-
addArgumentMaterialization([&](OpBuilder &builder,
168-
UnrankedMemRefType resultType,
169-
ValueRange inputs, Location loc) {
175+
// Source materializations convert the MemrRef descriptor elements
176+
// (multiple SSA values that make up a MemrRef descriptor) back to the
177+
// original MemRef type.
178+
addSourceMaterialization([&](OpBuilder &builder,
179+
UnrankedMemRefType resultType, ValueRange inputs,
180+
Location loc) {
170181
// Note: Bare pointers are not supported for unranked memrefs because a
171182
// memref descriptor cannot be built just from a bare pointer.
172183
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
@@ -179,8 +190,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
179190
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180191
.getResult(0);
181192
});
182-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183-
ValueRange inputs, Location loc) {
193+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
194+
ValueRange inputs, Location loc) {
184195
Value desc;
185196
if (isBarePointer(inputs)) {
186197
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
@@ -200,23 +211,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
200211
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201212
.getResult(0);
202213
});
203-
// Add generic source and target materializations to handle cases where
204-
// non-LLVM types persist after an LLVM conversion.
205-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206-
ValueRange inputs, Location loc) {
207-
if (inputs.size() != 1)
208-
return Value();
214+
addTargetMaterialization([&](OpBuilder &builder,
215+
LLVM::LLVMStructType resultType,
216+
ValueRange inputs, Location loc,
217+
Type originalType) -> Value {
218+
if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
219+
if (isBarePointer(inputs)) {
220+
return MemRefDescriptor::fromStaticShape(builder, loc, *this,
221+
memrefType, inputs[0]);
222+
} else if (TypeRange(inputs) ==
223+
getMemRefDescriptorFields(memrefType,
224+
/*unpackAggregates=*/true)) {
225+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
226+
}
227+
}
209228

210-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211-
.getResult(0);
212-
});
213-
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs, Location loc) {
215-
if (inputs.size() != 1)
216-
return Value();
229+
if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
230+
// Note: Bare pointers are not supported for unranked memrefs because a
231+
// memref descriptor cannot be built just from a bare pointer.
232+
if (TypeRange(inputs) == getUnrankedMemRefDescriptorFields())
233+
return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
234+
inputs);
235+
}
217236

218-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219-
.getResult(0);
237+
return Value();
220238
});
221239

222240
// Integer memory spaces map to themselves.

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6161
addConversion([](UnrankedTensorType type) -> Type {
6262
return UnrankedMemRefType::get(type.getElementType(), 0);
6363
});
64-
addArgumentMaterialization(materializeToTensor);
6564
addSourceMaterialization(materializeToTensor);
6665
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
6766
ValueRange inputs, Location loc) -> Value {

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
3333

3434
converter.addSourceMaterialization(materializeAsUnrealizedCast);
3535
converter.addTargetMaterialization(materializeAsUnrealizedCast);
36-
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
3736
}
3837

3938
/// Get an unsigned integer or size data type corresponding to \p ty.

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
154154
});
155155

156156
addSourceMaterialization(sourceMaterializationCallback);
157-
addArgumentMaterialization(sourceMaterializationCallback);
158157
}
159158
};
160159

mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ class QuantizedTypeConverter : public TypeConverter {
5656
addConversion(convertQuantizedType);
5757
addConversion(convertTensorType);
5858

59-
addArgumentMaterialization(materializeConversion);
6059
addSourceMaterialization(materializeConversion);
6160
addTargetMaterialization(materializeConversion);
6261
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6969

7070
// Required by scf.for 1:N type conversion.
7171
addSourceMaterialization(materializeTuple);
72-
73-
// Required as a workaround until we have full 1:N support.
74-
addArgumentMaterialization(materializeTuple);
7572
}
7673

7774
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
481481

482482
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
483483
};
484-
typeConverter.addArgumentMaterialization(materializeCast);
485484
typeConverter.addSourceMaterialization(materializeCast);
486485
typeConverter.addTargetMaterialization(materializeCast);
487486
target.markUnknownOpDynamicallyLegal(

0 commit comments

Comments
 (0)