Skip to content

Commit ef8d04d

Browse files
[mlir][Transforms] Support 1:N mappings in ConversionValueMapping
1 parent 5857c76 commit ef8d04d

File tree

12 files changed

+354
-298
lines changed

12 files changed

+354
-298
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,31 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
156169
// Helper function that checks if the given value range is a bare pointer.
157170
auto isBarePointer = [](ValueRange values) {
158171
return values.size() == 1 &&
159172
isa<LLVM::LLVMPointerType>(values.front().getType());
160173
};
161174

162-
// Argument materializations convert from the new block argument types
163-
// (multiple SSA values that make up a memref descriptor) back to the
164-
// original block argument type. The dialect conversion framework will then
165-
// insert a target materialization from the original block argument type to
166-
// a legal type.
167-
addArgumentMaterialization([&](OpBuilder &builder,
168-
UnrankedMemRefType resultType,
169-
ValueRange inputs, Location loc) {
175+
// Source materializations convert the MemrRef descriptor elements
176+
// (multiple SSA values that make up a MemrRef descriptor) back to the
177+
// original MemRef type.
178+
addSourceMaterialization([&](OpBuilder &builder,
179+
UnrankedMemRefType resultType, ValueRange inputs,
180+
Location loc) {
170181
// Note: Bare pointers are not supported for unranked memrefs because a
171182
// memref descriptor cannot be built just from a bare pointer.
172183
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
@@ -179,8 +190,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
179190
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180191
.getResult(0);
181192
});
182-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183-
ValueRange inputs, Location loc) {
193+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
194+
ValueRange inputs, Location loc) {
184195
Value desc;
185196
if (isBarePointer(inputs)) {
186197
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
@@ -200,23 +211,30 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
200211
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201212
.getResult(0);
202213
});
203-
// Add generic source and target materializations to handle cases where
204-
// non-LLVM types persist after an LLVM conversion.
205-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206-
ValueRange inputs, Location loc) {
207-
if (inputs.size() != 1)
208-
return Value();
214+
addTargetMaterialization([&](OpBuilder &builder,
215+
LLVM::LLVMStructType resultType,
216+
ValueRange inputs, Location loc,
217+
Type originalType) -> Value {
218+
if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
219+
if (isBarePointer(inputs)) {
220+
return MemRefDescriptor::fromStaticShape(builder, loc, *this,
221+
memrefType, inputs[0]);
222+
} else if (TypeRange(inputs) ==
223+
getMemRefDescriptorFields(memrefType,
224+
/*unpackAggregates=*/true)) {
225+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
226+
}
227+
}
209228

210-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
211-
.getResult(0);
212-
});
213-
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs, Location loc) {
215-
if (inputs.size() != 1)
216-
return Value();
229+
if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
230+
// Note: Bare pointers are not supported for unranked memrefs because a
231+
// memref descriptor cannot be built just from a bare pointer.
232+
if (TypeRange(inputs) == getUnrankedMemRefDescriptorFields())
233+
return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
234+
inputs);
235+
}
217236

218-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219-
.getResult(0);
237+
return Value();
220238
});
221239

222240
// Integer memory spaces map to themselves.

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6161
addConversion([](UnrankedTensorType type) -> Type {
6262
return UnrankedMemRefType::get(type.getElementType(), 0);
6363
});
64-
addArgumentMaterialization(materializeToTensor);
6564
addSourceMaterialization(materializeToTensor);
6665
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
6766
ValueRange inputs, Location loc) -> Value {

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
3333

3434
converter.addSourceMaterialization(materializeAsUnrealizedCast);
3535
converter.addTargetMaterialization(materializeAsUnrealizedCast);
36-
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
3736
}
3837

3938
/// Get an unsigned integer or size data type corresponding to \p ty.

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
154154
});
155155

156156
addSourceMaterialization(sourceMaterializationCallback);
157-
addArgumentMaterialization(sourceMaterializationCallback);
158157
}
159158
};
160159

mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
3636
static Type convertQuantizedType(QuantizedType quantizedType) {
3737
return quantizedType.getStorageType();
3838
}
39-
39+
4040
static Type convertTensorType(TensorType tensorType) {
41-
if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
41+
if (auto quantizedType =
42+
dyn_cast<QuantizedType>(tensorType.getElementType()))
4243
return tensorType.clone(convertQuantizedType(quantizedType));
4344
return tensorType;
4445
}
@@ -50,20 +51,19 @@ class QuantizedTypeConverter : public TypeConverter {
5051
}
5152

5253
public:
53-
5454
explicit QuantizedTypeConverter() {
5555
addConversion([](Type type) { return type; });
5656
addConversion(convertQuantizedType);
5757
addConversion(convertTensorType);
5858

59-
addArgumentMaterialization(materializeConversion);
6059
addSourceMaterialization(materializeConversion);
6160
addTargetMaterialization(materializeConversion);
6261
}
6362
};
6463

6564
// Conversion pass
66-
class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
65+
class StripFuncQuantTypes
66+
: public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
6767

6868
// Return whether a type is considered legal when occurring in the header of
6969
// a function or as an operand to a 'return' op.
@@ -74,11 +74,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
7474
}
7575

7676
public:
77-
7877
void runOnOperation() override {
79-
78+
8079
auto moduleOp = cast<ModuleOp>(getOperation());
81-
auto* context = &getContext();
80+
auto *context = &getContext();
8281

8382
QuantizedTypeConverter typeConverter;
8483
ConversionTarget target(*context);
@@ -111,4 +110,3 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
111110

112111
} // namespace quant
113112
} // namespace mlir
114-

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6969

7070
// Required by scf.for 1:N type conversion.
7171
addSourceMaterialization(materializeTuple);
72-
73-
// Required as a workaround until we have full 1:N support.
74-
addArgumentMaterialization(materializeTuple);
7572
}
7673

7774
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
481481

482482
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
483483
};
484-
typeConverter.addArgumentMaterialization(materializeCast);
485484
typeConverter.addSourceMaterialization(materializeCast);
486485
typeConverter.addTargetMaterialization(materializeCast);
487486
target.markUnknownOpDynamicallyLegal(

0 commit comments

Comments
 (0)