Skip to content

Commit 16e58e7

Browse files
[mlir][Transforms] Support 1:N mappings in ConversionValueMapping
1 parent a496ab4 commit 16e58e7

File tree

12 files changed

+384
-288
lines changed

12 files changed

+384
-288
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -153,31 +153,42 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156-
// Argument materializations convert from the new block argument types
157-
// (multiple SSA values that make up a memref descriptor) back to the
158-
// original block argument type. The dialect conversion framework will then
159-
// insert a target materialization from the original block argument type to
160-
// a legal type.
161-
addArgumentMaterialization([&](OpBuilder &builder,
162-
UnrankedMemRefType resultType,
163-
ValueRange inputs, Location loc) {
156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
164+
// Source materializations convert the MemrRef descriptor elements
165+
// (multiple SSA values that make up a MemrRef descriptor) back to the
166+
// original MemRef type.
167+
addSourceMaterialization([&](OpBuilder &builder,
168+
UnrankedMemRefType resultType, ValueRange inputs,
169+
Location loc) {
164170
if (inputs.size() == 1) {
165171
// Bare pointers are not supported for unranked memrefs because a
166172
// memref descriptor cannot be built just from a bare pointer.
167173
return Value();
168174
}
169175
Value desc =
170176
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
171-
// An argument materialization must return a value of type
177+
// A source materialization must return a value of type
172178
// `resultType`, so insert a cast from the memref descriptor type
173179
// (!llvm.struct) to the original memref type.
174180
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
175181
.getResult(0);
176182
});
177-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
178-
ValueRange inputs, Location loc) {
183+
addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
184+
ValueRange inputs, Location loc) {
185+
if (inputs.size() == 1 &&
186+
isa<LLVM::LLVMStructType>(inputs.front().getType()))
187+
return Value();
188+
179189
Value desc;
180-
if (inputs.size() == 1) {
190+
if (inputs.size() == 1 &&
191+
isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
181192
// This is a bare pointer. We allow bare pointers only for function entry
182193
// blocks.
183194
BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
@@ -192,15 +203,13 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
192203
} else {
193204
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
194205
}
195-
// An argument materialization must return a value of type `resultType`,
206+
// A source materialization must return a value of type `resultType`,
196207
// so insert a cast from the memref descriptor type (!llvm.struct) to the
197208
// original memref type.
198209
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
199210
.getResult(0);
200211
});
201-
// Add generic source and target materializations to handle cases where
202-
// non-LLVM types persist after an LLVM conversion.
203-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
212+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
204213
ValueRange inputs, Location loc) {
205214
if (inputs.size() != 1)
206215
return Value();
@@ -209,12 +218,50 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
209218
.getResult(0);
210219
});
211220
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212-
ValueRange inputs, Location loc) {
213-
if (inputs.size() != 1)
221+
ValueRange inputs, Location loc,
222+
Type originalType) -> Value {
223+
llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
224+
if (!originalType) {
225+
llvm::errs() << " -- no orig\n";
214226
return Value();
227+
}
228+
if (auto memrefType = dyn_cast<MemRefType>(originalType)) {
229+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
230+
if (inputs.size() == 1) {
231+
Value input = inputs.front();
232+
if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
233+
if (castOp.getInputs().size() == 1 &&
234+
isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
235+
input = castOp.getInputs()[0];
236+
}
237+
}
238+
if (!isa<LLVM::LLVMPointerType>(input.getType()))
239+
return Value();
240+
BlockArgument barePtr = dyn_cast<BlockArgument>(input);
241+
if (!barePtr)
242+
return Value();
243+
Block *block = barePtr.getOwner();
244+
if (!block->isEntryBlock() ||
245+
!isa<FunctionOpInterface>(block->getParentOp()))
246+
return Value();
247+
// Bare ptr
248+
return MemRefDescriptor::fromStaticShape(builder, loc, *this,
249+
memrefType, input);
250+
}
251+
return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
252+
}
253+
if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
254+
assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
255+
if (inputs.size() == 1) {
256+
// Bare pointers are not supported for unranked memrefs because a
257+
// memref descriptor cannot be built just from a bare pointer.
258+
return Value();
259+
}
260+
return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
261+
inputs);
262+
}
215263

216-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
217-
.getResult(0);
264+
return Value();
218265
});
219266

220267
// Integer memory spaces map to themselves.

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
6161
addConversion([](UnrankedTensorType type) -> Type {
6262
return UnrankedMemRefType::get(type.getElementType(), 0);
6363
});
64-
addArgumentMaterialization(materializeToTensor);
6564
addSourceMaterialization(materializeToTensor);
6665
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
6766
ValueRange inputs, Location loc) -> Value {

mlir/lib/Dialect/EmitC/Transforms/TypeConversions.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ void mlir::populateEmitCSizeTTypeConversions(TypeConverter &converter) {
3333

3434
converter.addSourceMaterialization(materializeAsUnrealizedCast);
3535
converter.addTargetMaterialization(materializeAsUnrealizedCast);
36-
converter.addArgumentMaterialization(materializeAsUnrealizedCast);
3736
}
3837

3938
/// Get an unsigned integer or size data type corresponding to \p ty.

mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ class DetensorizeTypeConverter : public TypeConverter {
154154
});
155155

156156
addSourceMaterialization(sourceMaterializationCallback);
157-
addArgumentMaterialization(sourceMaterializationCallback);
158157
}
159158
};
160159

mlir/lib/Dialect/Quant/Transforms/StripFuncQuantTypes.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ class QuantizedTypeConverter : public TypeConverter {
3636
static Type convertQuantizedType(QuantizedType quantizedType) {
3737
return quantizedType.getStorageType();
3838
}
39-
39+
4040
static Type convertTensorType(TensorType tensorType) {
41-
if (auto quantizedType = dyn_cast<QuantizedType>(tensorType.getElementType()))
41+
if (auto quantizedType =
42+
dyn_cast<QuantizedType>(tensorType.getElementType()))
4243
return tensorType.clone(convertQuantizedType(quantizedType));
4344
return tensorType;
4445
}
@@ -50,20 +51,19 @@ class QuantizedTypeConverter : public TypeConverter {
5051
}
5152

5253
public:
53-
5454
explicit QuantizedTypeConverter() {
5555
addConversion([](Type type) { return type; });
5656
addConversion(convertQuantizedType);
5757
addConversion(convertTensorType);
5858

59-
addArgumentMaterialization(materializeConversion);
6059
addSourceMaterialization(materializeConversion);
6160
addTargetMaterialization(materializeConversion);
6261
}
6362
};
6463

6564
// Conversion pass
66-
class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
65+
class StripFuncQuantTypes
66+
: public impl::StripFuncQuantTypesBase<StripFuncQuantTypes> {
6767

6868
// Return whether a type is considered legal when occurring in the header of
6969
// a function or as an operand to a 'return' op.
@@ -74,11 +74,10 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
7474
}
7575

7676
public:
77-
7877
void runOnOperation() override {
79-
78+
8079
auto moduleOp = cast<ModuleOp>(getOperation());
81-
auto* context = &getContext();
80+
auto *context = &getContext();
8281

8382
QuantizedTypeConverter typeConverter;
8483
ConversionTarget target(*context);
@@ -111,4 +110,3 @@ class StripFuncQuantTypes : public impl::StripFuncQuantTypesBase<StripFuncQuantT
111110

112111
} // namespace quant
113112
} // namespace mlir
114-

mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,6 @@ SparseTensorTypeToBufferConverter::SparseTensorTypeToBufferConverter() {
6969

7070
// Required by scf.for 1:N type conversion.
7171
addSourceMaterialization(materializeTuple);
72-
73-
// Required as a workaround until we have full 1:N support.
74-
addArgumentMaterialization(materializeTuple);
7572
}
7673

7774
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,6 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality(
481481

482482
return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
483483
};
484-
typeConverter.addArgumentMaterialization(materializeCast);
485484
typeConverter.addSourceMaterialization(materializeCast);
486485
typeConverter.addTargetMaterialization(materializeCast);
487486
target.markUnknownOpDynamicallyLegal(

0 commit comments

Comments
 (0)