Skip to content

Commit f9e8758

Browse files
[mlir][Transforms] Dialect Conversion: Do not build target mat. during 1:N replacement
fix test experiement
1 parent 74fb992 commit f9e8758

File tree

4 files changed

+128
-103
lines changed

4 files changed

+128
-103
lines changed

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 86 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -153,70 +153,112 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
153153
type.isVarArg());
154154
});
155155

156+
// Add generic source and target materializations to handle cases where
157+
// non-LLVM types persist after an LLVM conversion.
158+
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
159+
ValueRange inputs, Location loc) {
160+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
161+
.getResult(0);
162+
});
163+
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
164+
ValueRange inputs, Location loc) {
165+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
166+
.getResult(0);
167+
});
168+
156169
// Helper function that checks if the given value range is a bare pointer.
157170
auto isBarePointer = [](ValueRange values) {
158171
return values.size() == 1 &&
159172
isa<LLVM::LLVMPointerType>(values.front().getType());
160173
};
161174

162-
// Argument materializations convert from the new block argument types
163-
// (multiple SSA values that make up a memref descriptor) back to the
164-
// original block argument type. The dialect conversion framework will then
165-
// insert a target materialization from the original block argument type to
166-
// a legal type.
167-
addArgumentMaterialization([&](OpBuilder &builder,
168-
UnrankedMemRefType resultType,
169-
ValueRange inputs, Location loc) {
175+
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
176+
// must be passed explicitly.
177+
auto packUnrankedMemRefDesc =
178+
[&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
179+
Location loc, LLVMTypeConverter &converter) -> Value {
170180
// Note: Bare pointers are not supported for unranked memrefs because a
171181
// memref descriptor cannot be built just from a bare pointer.
172-
if (TypeRange(inputs) != getUnrankedMemRefDescriptorFields())
182+
if (TypeRange(inputs) != converter.getUnrankedMemRefDescriptorFields())
173183
return Value();
174-
Value desc =
175-
UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
184+
return UnrankedMemRefDescriptor::pack(builder, loc, converter, resultType,
185+
inputs);
186+
};
187+
188+
// MemRef descriptor elements -> UnrankedMemRefType
189+
auto unrakedMemRefMaterialization = [&](OpBuilder &builder,
190+
UnrankedMemRefType resultType,
191+
ValueRange inputs, Location loc) {
176192
// An argument materialization must return a value of type
177193
// `resultType`, so insert a cast from the memref descriptor type
178194
// (!llvm.struct) to the original memref type.
179-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
180-
.getResult(0);
181-
});
182-
addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
183-
ValueRange inputs, Location loc) {
184-
Value desc;
185-
if (isBarePointer(inputs)) {
186-
desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
187-
inputs[0]);
188-
} else if (TypeRange(inputs) ==
189-
getMemRefDescriptorFields(resultType,
190-
/*unpackAggregates=*/true)) {
191-
desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
192-
} else {
193-
// The inputs are neither a bare pointer nor an unpacked memref
194-
// descriptor. This materialization function cannot be used.
195+
Value packed =
196+
packUnrankedMemRefDesc(builder, resultType, inputs, loc, *this);
197+
if (!packed)
195198
return Value();
196-
}
199+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
200+
.getResult(0);
201+
};
202+
203+
// TODO: For some reason, `this` is nullptr in here, so the LLVMTypeConverter
204+
// must be passed explicitly.
205+
auto packRankedMemRefDesc = [&](OpBuilder &builder, MemRefType resultType,
206+
ValueRange inputs, Location loc,
207+
LLVMTypeConverter &converter) -> Value {
208+
assert(resultType && "expected non-null result type");
209+
if (isBarePointer(inputs))
210+
return MemRefDescriptor::fromStaticShape(builder, loc, converter,
211+
resultType, inputs[0]);
212+
if (TypeRange(inputs) ==
213+
converter.getMemRefDescriptorFields(resultType,
214+
/*unpackAggregates=*/true))
215+
return MemRefDescriptor::pack(builder, loc, converter, resultType,
216+
inputs);
217+
// The inputs are neither a bare pointer nor an unpacked memref descriptor.
218+
// This materialization function cannot be used.
219+
return Value();
220+
};
221+
222+
// MemRef descriptor elements -> MemRefType
223+
auto rankedMemRefMaterialization = [&](OpBuilder &builder,
224+
MemRefType resultType,
225+
ValueRange inputs, Location loc) {
197226
// An argument materialization must return a value of type `resultType`,
198227
// so insert a cast from the memref descriptor type (!llvm.struct) to the
199228
// original memref type.
200-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
201-
.getResult(0);
202-
});
203-
// Add generic source and target materializations to handle cases where
204-
// non-LLVM types persist after an LLVM conversion.
205-
addSourceMaterialization([&](OpBuilder &builder, Type resultType,
206-
ValueRange inputs, Location loc) {
207-
if (inputs.size() != 1)
229+
Value packed =
230+
packRankedMemRefDesc(builder, resultType, inputs, loc, *this);
231+
if (!packed)
208232
return Value();
209-
210-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
233+
return builder.create<UnrealizedConversionCastOp>(loc, resultType, packed)
211234
.getResult(0);
212-
});
235+
};
236+
237+
// Argument materializations convert from the new block argument types
238+
// (multiple SSA values that make up a memref descriptor) back to the
239+
// original block argument type.
240+
addArgumentMaterialization(unrakedMemRefMaterialization);
241+
addArgumentMaterialization(rankedMemRefMaterialization);
242+
addSourceMaterialization(unrakedMemRefMaterialization);
243+
addSourceMaterialization(rankedMemRefMaterialization);
244+
245+
// Bare pointer -> Packed MemRef descriptor
213246
addTargetMaterialization([&](OpBuilder &builder, Type resultType,
214-
ValueRange inputs, Location loc) {
215-
if (inputs.size() != 1)
247+
ValueRange inputs, Location loc,
248+
Type originalType) -> Value {
249+
// The original MemRef type is required to build a MemRef descriptor
250+
// because the sizes/strides of the MemRef cannot be inferred from just the
251+
// bare pointer.
252+
if (!originalType)
216253
return Value();
217-
218-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
219-
.getResult(0);
254+
if (resultType != convertType(originalType))
255+
return Value();
256+
if (auto memrefType = dyn_cast<MemRefType>(originalType))
257+
return packRankedMemRefDesc(builder, memrefType, inputs, loc, *this);
258+
if (auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(originalType))
259+
return packUnrankedMemRefDesc(builder, unrankedMemrefType, inputs, loc,
260+
*this);
261+
return Value();
220262
});
221263

222264
// Integer memory spaces map to themselves.

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -849,8 +849,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
849849
/// function will be deleted when full 1:N support has been added.
850850
///
851851
/// This function inserts an argument materialization back to the original
852-
/// type, followed by a target materialization to the legalized type (if
853-
/// applicable).
852+
/// type.
854853
void insertNTo1Materialization(OpBuilder::InsertPoint ip, Location loc,
855854
ValueRange replacements, Value originalValue,
856855
const TypeConverter *converter);
@@ -1376,9 +1375,13 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13761375
// used as a replacement.
13771376
auto replArgs =
13781377
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
1379-
insertNTo1Materialization(
1380-
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1381-
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1378+
if (replArgs.size() == 1) {
1379+
mapping.map(origArg, replArgs.front());
1380+
} else {
1381+
insertNTo1Materialization(
1382+
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1383+
/*replacements=*/replArgs, /*outputValue=*/origArg, converter);
1384+
}
13821385
appendRewrite<ReplaceBlockArgRewrite>(block, origArg, converter);
13831386
}
13841387

@@ -1437,36 +1440,12 @@ void ConversionPatternRewriterImpl::insertNTo1Materialization(
14371440
// Insert argument materialization back to the original type.
14381441
Type originalType = originalValue.getType();
14391442
UnrealizedConversionCastOp argCastOp;
1440-
Value argMat = buildUnresolvedMaterialization(
1443+
buildUnresolvedMaterialization(
14411444
MaterializationKind::Argument, ip, loc, /*valueToMap=*/originalValue,
1442-
/*inputs=*/replacements, originalType, /*originalType=*/Type(), converter,
1443-
&argCastOp);
1445+
/*inputs=*/replacements, originalType,
1446+
/*originalType=*/Type(), converter, &argCastOp);
14441447
if (argCastOp)
14451448
nTo1TempMaterializations.insert(argCastOp);
1446-
1447-
// Insert target materialization to the legalized type.
1448-
Type legalOutputType;
1449-
if (converter) {
1450-
legalOutputType = converter->convertType(originalType);
1451-
} else if (replacements.size() == 1) {
1452-
// When there is no type converter, assume that the replacement value
1453-
// types are legal. This is reasonable to assume because they were
1454-
// specified by the user.
1455-
// FIXME: This won't work for 1->N conversions because multiple output
1456-
// types are not supported in parts of the dialect conversion. In such a
1457-
// case, we currently use the original value type.
1458-
legalOutputType = replacements[0].getType();
1459-
}
1460-
if (legalOutputType && legalOutputType != originalType) {
1461-
UnrealizedConversionCastOp targetCastOp;
1462-
buildUnresolvedMaterialization(
1463-
MaterializationKind::Target, computeInsertPoint(argMat), loc,
1464-
/*valueToMap=*/argMat, /*inputs=*/argMat,
1465-
/*outputType=*/legalOutputType, /*originalType=*/originalType,
1466-
converter, &targetCastOp);
1467-
if (targetCastOp)
1468-
nTo1TempMaterializations.insert(targetCastOp);
1469-
}
14701449
}
14711450

14721451
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
@@ -2864,6 +2843,9 @@ void TypeConverter::SignatureConversion::remapInput(unsigned origInputNo,
28642843

28652844
LogicalResult TypeConverter::convertType(Type t,
28662845
SmallVectorImpl<Type> &results) const {
2846+
assert(this && "expected non-null type converter");
2847+
assert(t && "expected non-null type");
2848+
28672849
{
28682850
std::shared_lock<decltype(cacheMutex)> cacheReadLock(cacheMutex,
28692851
std::defer_lock);

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ func.func @no_remap_nested() {
124124
// CHECK-NEXT: "foo.region"
125125
// expected-remark@+1 {{op 'foo.region' is not legalizable}}
126126
"foo.region"() ({
127-
// CHECK-NEXT: ^bb0(%{{.*}}: i64, %{{.*}}: i16, %{{.*}}: i64):
128-
^bb0(%i0: i64, %unused: i16, %i1: i64):
129-
// CHECK-NEXT: "test.valid"{{.*}} : (i64, i64)
130-
"test.invalid"(%i0, %i1) : (i64, i64) -> ()
127+
// CHECK-NEXT: ^bb0(%{{.*}}: f64, %{{.*}}: i16, %{{.*}}: f64):
128+
^bb0(%i0: f64, %unused: i16, %i1: f64):
129+
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64)
130+
"test.invalid"(%i0, %i1) : (f64, f64) -> ()
131131
}) : () -> ()
132132
// expected-remark@+1 {{op 'func.return' is not legalizable}}
133133
return

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -979,8 +979,8 @@ struct TestDropOpSignatureConversion : public ConversionPattern {
979979
};
980980
/// This pattern simply updates the operands of the given operation.
981981
struct TestPassthroughInvalidOp : public ConversionPattern {
982-
TestPassthroughInvalidOp(MLIRContext *ctx)
983-
: ConversionPattern("test.invalid", 1, ctx) {}
982+
TestPassthroughInvalidOp(MLIRContext *ctx, const TypeConverter &converter)
983+
: ConversionPattern(converter, "test.invalid", 1, ctx) {}
984984
LogicalResult
985985
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
986986
ConversionPatternRewriter &rewriter) const final {
@@ -1301,19 +1301,19 @@ struct TestLegalizePatternDriver
13011301
TestTypeConverter converter;
13021302
mlir::RewritePatternSet patterns(&getContext());
13031303
populateWithGenerated(patterns);
1304-
patterns.add<
1305-
TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1306-
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1307-
TestUndoBlockArgReplace, TestUndoBlockErase, TestPassthroughInvalidOp,
1308-
TestSplitReturnType, TestChangeProducerTypeI32ToF32,
1309-
TestChangeProducerTypeF32ToF64, TestChangeProducerTypeF32ToInvalid,
1310-
TestUpdateConsumerType, TestNonRootReplacement,
1311-
TestBoundedRecursiveRewrite, TestNestedOpCreationUndoRewrite,
1312-
TestReplaceEraseOp, TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1313-
TestUndoPropertiesModification, TestEraseOp,
1314-
TestRepetitive1ToNConsumer>(&getContext());
1315-
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp>(
1316-
&getContext(), converter);
1304+
patterns
1305+
.add<TestRegionRewriteBlockMovement, TestDetachedSignatureConversion,
1306+
TestRegionRewriteUndo, TestCreateBlock, TestCreateIllegalBlock,
1307+
TestUndoBlockArgReplace, TestUndoBlockErase, TestSplitReturnType,
1308+
TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
1309+
TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
1310+
TestNonRootReplacement, TestBoundedRecursiveRewrite,
1311+
TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
1312+
TestCreateUnregisteredOp, TestUndoMoveOpBefore,
1313+
TestUndoPropertiesModification, TestEraseOp,
1314+
TestRepetitive1ToNConsumer>(&getContext());
1315+
patterns.add<TestDropOpSignatureConversion, TestDropAndReplaceInvalidOp,
1316+
TestPassthroughInvalidOp>(&getContext(), converter);
13171317
patterns.add<TestDuplicateBlockArgs>(converter, &getContext());
13181318
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
13191319
converter);
@@ -1749,8 +1749,9 @@ struct TestTypeConversionAnotherProducer
17491749
};
17501750

17511751
struct TestReplaceWithLegalOp : public ConversionPattern {
1752-
TestReplaceWithLegalOp(MLIRContext *ctx)
1753-
: ConversionPattern("test.replace_with_legal_op", /*benefit=*/1, ctx) {}
1752+
TestReplaceWithLegalOp(const TypeConverter &converter, MLIRContext *ctx)
1753+
: ConversionPattern(converter, "test.replace_with_legal_op",
1754+
/*benefit=*/1, ctx) {}
17541755
LogicalResult
17551756
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
17561757
ConversionPatternRewriter &rewriter) const final {
@@ -1872,12 +1873,12 @@ struct TestTypeConversionDriver
18721873

18731874
// Initialize the set of rewrite patterns.
18741875
RewritePatternSet patterns(&getContext());
1875-
patterns.add<TestTypeConsumerForward, TestTypeConversionProducer,
1876-
TestSignatureConversionUndo,
1877-
TestTestSignatureConversionNoConverter>(converter,
1878-
&getContext());
1879-
patterns.add<TestTypeConversionAnotherProducer, TestReplaceWithLegalOp>(
1880-
&getContext());
1876+
patterns
1877+
.add<TestTypeConsumerForward, TestTypeConversionProducer,
1878+
TestSignatureConversionUndo,
1879+
TestTestSignatureConversionNoConverter, TestReplaceWithLegalOp>(
1880+
converter, &getContext());
1881+
patterns.add<TestTypeConversionAnotherProducer>(&getContext());
18811882
mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
18821883
converter);
18831884

0 commit comments

Comments
 (0)