Skip to content

Commit f027116

Browse files
[mlir][Transforms] Support 1:N mappings in ConversionValueMapping
1 parent a496ab4 commit f027116

File tree

12 files changed

+387
-300
lines changed

12 files changed

+387
-300
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 70 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -153,68 +153,106 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156-
// Argument materializations convert from the new block argument types
157-
// (multiple SSA values that make up a memref descriptor) back to the
158-
// original block argument type. The dialect conversion framework will then
159-
// insert a target materialization from the original block argument type to
160-
// a legal type.
161-
addArgumentMaterialization([&](OpBuilder &builder,
162-
UnrankedMemRefType resultType,
163-
ValueRange inputs, Location loc) {
156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
169+
// Source materializations convert the MemrRef descriptor elements
170+
// (multiple SSA values that make up a MemrRef descriptor) back to the
171+
// original MemRef type.
172+
addSourceMaterialization([&](OpBuilder &builder,
173+
UnrankedMemRefType resultType, ValueRange inputs,
174+
Location loc) {
164175
if (inputs.size() == 1) {
165176
// Bare pointers are not supported for unranked memrefs because a
166177
// memref descriptor cannot be built just from a bare pointer.
167178
return Value();
168179
}
169180
Value desc =
170181
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171-
// An argument materialization must return a value of type
182+
// A source materialization must return a value of type
172183
// `resultType`, so insert a cast from the memref descriptor type
173184
// (!llvm.struct) to the original memref type.
174185
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175186
.getResult(0);
176187
});
177-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178-
ValueRange inputs, Location loc) {
188+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
189+
ValueRange inputs, Location loc) {
190+
if (inputs.size() == 1 &&
191+
isa<LLVM::LLVMStructType>(inputs.front().getType()))
192+
return Value();
193+
179194
Value desc;
180-
if (inputs.size() == 1) {
195+
if (inputs.size() == 1 &&
196+
isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
181197
// This is a bare pointer. We allow bare pointers only for function entry
182198
// blocks.
183199
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
184200
if (!barePtr)
185201
return Value();
186-
Block *block = barePtr.getOwner();
187-
if (!block->isEntryBlock() ||
188-
!isa<FunctionOpInterface>(block->getParentOp()))
189-
return Value();
202+
//Block *block = barePtr.getOwner();
203+
//if (!block->isEntryBlock() ||
204+
// !isa<FunctionOpInterface>(block->getParentOp()))
205+
// return Value();
190206
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
191207
inputs[0]);
192208
} else {
193209
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
194210
}
195-
// An argument materialization must return a value of type `resultType`,
211+
// A source materialization must return a value of type `resultType`,
196212
// so insert a cast from the memref descriptor type (!llvm.struct) to the
197213
// original memref type.
198214
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199215
.getResult(0);
200216
});
201-
// Add generic source and target materializations to handle cases where
202-
// non-LLVM types persist after an LLVM conversion.
203-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
204-
ValueRange inputs, Location loc) {
205-
if (inputs.size() != 1)
206-
return Value();
217+
addTargetMaterialization([&](OpBuilder &builder, LLVM::LLVMStructType resultType,
218+
ValueRange inputs, Location loc,
219+
Type originalType) -> Value {
220+
if (auto memrefType = dyn_cast_or_null<MemRefType>(originalType)) {
221+
if (inputs.size() == 1) {
222+
Value input = inputs.front();
223+
//if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
224+
// if (castOp.getInputs().size() == 1 &&
225+
// isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
226+
// input = castOp.getInputs()[0];
227+
// }
228+
//}
229+
if (!isa<LLVM::LLVMPointerType>(input.getType()))
230+
return Value();
231+
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
232+
if (!barePtr)
233+
return Value();
234+
//Block *block = barePtr.getOwner();
235+
//if (!block->isEntryBlock() ||
236+
// !isa<FunctionOpInterface>(block->getParentOp()))
237+
// return Value();
238+
// Bare ptr
239+
return MemRefDescriptor::fromStaticShape(builder, loc, *this,
240+
memrefType, input);
241+
}
242+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
243+
}
207244

208-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
209-
.getResult(0);
210-
});
211-
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212-
ValueRange inputs, Location loc) {
213-
if (inputs.size() != 1)
214-
return Value();
245+
if (auto memrefType = dyn_cast_or_null<UnrankedMemRefType>(originalType)) {
246+
if (inputs.size() == 1) {
247+
// Bare pointers are not supported for unranked memrefs because a
248+
// memref descriptor cannot be built just from a bare pointer.
249+
return Value();
250+
}
251+
return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
252+
inputs);
253+
}
215254

216-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217-
.getResult(0);
255+
return Value();
218256
});
219257

220258
// Integer memory spaces map to themselves.

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6161
addConversion([](UnrankedTensorType type) -> Type {
6262
return UnrankedMemRefType::get(type.getElementType(), 0);
6363
});
64-
addArgumentMaterialization(materializeToTensor);
6564
addSourceMaterialization(materializeToTensor);
6665
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
6766
ValueRange inputs, Location loc) -> Value {

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
3333

3434
converter.addSourceMaterialization(materializeAsUnrealizedCast);
3535
converter.addTargetMaterialization(materializeAsUnrealizedCast);
36-
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
3736
}
3837

3938
/// Get an unsigned integer or size data type corresponding to \p ty.

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
154154
});
155155

156156
addSourceMaterialization(sourceMaterializationCallback);
157-
addArgumentMaterialization(sourceMaterializationCallback);
158157
}
159158
};
160159

mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
3636
static Type convertQuantizedType(QuantizedType quantizedType) {
3737
return quantizedType.getStorageType();
3838
}
39-
39+
4040
static Type convertTensorType(TensorType tensorType) {
41-
if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
41+
if (auto quantizedType =
42+
dyn_cast<QuantizedType>(tensorType.getElementType()))
4243
return tensorType.clone(convertQuantizedType(quantizedType));
4344
return tensorType;
4445
}
@@ -50,20 +51,19 @@ class QuantizedTypeConverter : public TypeConverter {
5051
}
5152

5253
public:
53-
5454
explicit QuantizedTypeConverter() {
5555
addConversion([](Type type) { return type; });
5656
addConversion(convertQuantizedType);
5757
addConversion(convertTensorType);
5858

59-
addArgumentMaterialization(materializeConversion);
6059
addSourceMaterialization(materializeConversion);
6160
addTargetMaterialization(materializeConversion);
6261
}
6362
};
6463

6564
// Conversion pass
66-
class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
65+
class StripFuncQuantTypes
66+
: public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
6767

6868
// Return whether a type is considered legal when occurring in the header of
6969
// a function or as an operand to a 'return' op.
@@ -74,11 +74,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
7474
}
7575

7676
public:
77-
7877
void runOnOperation() override {
79-
78+
8079
auto moduleOp = cast<ModuleOp>(getOperation());
81-
auto* context = &getContext();
80+
auto *context = &getContext();
8281

8382
QuantizedTypeConverter typeConverter;
8483
ConversionTarget target(*context);
@@ -111,4 +110,3 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
111110

112111
} // namespace quant
113112
} // namespace mlir
114-

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6969

7070
// Required by scf.for 1:N type conversion.
7171
addSourceMaterialization(materializeTuple);
72-
73-
// Required as a workaround until we have full 1:N support.
74-
addArgumentMaterialization(materializeTuple);
7572
}
7673

7774
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
481481

482482
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
483483
};
484-
typeConverter.addArgumentMaterialization(materializeCast);
485484
typeConverter.addSourceMaterialization(materializeCast);
486485
typeConverter.addTargetMaterialization(materializeCast);
487486
target.markUnknownOpDynamicallyLegal(

0 commit comments

Comments
 (0)