Commit acc159a

[mlir][Transforms] Dialect conversion: Fix missing source materialization (#97903)
This commit fixes a bug in the dialect conversion. During a 1:N signature conversion, the dialect conversion did not insert a cast back to the original block argument type, producing invalid IR.

See `test-block-legalization.mlir`: Without this commit, the operand type of the op changes because an `unrealized_conversion_cast` is missing:
```
"test.consumer_of_complex"(%v) : (!llvm.struct<(f64, f64)>) -> ()
```

To implement this fix, it was necessary to change the meaning of argument materializations. An argument materialization now maps from the new block argument types to the original block argument type. (It now behaves almost like a source materialization.) This also addresses a `FIXME` in the code base:
```
// FIXME: The current argument materialization hook expects the original
// output type, even though it doesn't use that as the actual output type
// of the generated IR. The output type is just used as an indicator of
// the type of materialization to do. This behavior is really awkward in
// that it diverges from the behavior of the other hooks, and can be
// easily misunderstood. We should clean up the argument hooks to better
// represent the desired invariants we actually care about.
```

It is no longer necessary to distinguish between the "output type" and the "original output type".

Most type converters are already written according to the new API. (Most implementations use the same conversion functions as for source materializations.) One exception is the MemRef-to-LLVM type converter, which materialized an `!llvm.struct` based on the elements of a memref descriptor. It still does that, but casts the `!llvm.struct` back to the original memref type. The dialect conversion inserts a target materialization (to `!llvm.struct`) which cancels out with the other cast.

This commit also fixes a bug in `computeNecessaryMaterializations`. The implementation did not account for the possibility that a value was replaced multiple times. E.g., replace `a` by `b`, then `b` by `c`.

This commit also adds a transform dialect op to populate SCF-to-CF patterns. This transform op was needed to write a test case. The bug described here appears only during a complex interplay of 1:N signature conversions and op replacements. (I was not able to trigger it with ops and patterns from the `test` dialect without duplicating the `scf.if` pattern.)

Note for LLVM integration: Make sure that all `addArgument/Source/TargetMaterialization` functions produce an SSA value of the specified type.

Depends on #98743.
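The new handling of a converted block argument can be pictured with a small standalone model. This is only a sketch, not the MLIR implementation: types are plain strings, and `Cast` and `convertBlockArgument` are hypothetical names introduced for illustration. It records which casts would be built: first an argument materialization back to the original type, then, if a different legal type exists, a target materialization to that type.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy record of one cast the framework would build (hypothetical type).
struct Cast {
  std::string kind;              // "argument" or "target"
  std::vector<std::string> from; // input types
  std::string to;                // single result type
};

// Hypothetical model of converting one block argument. `legalType` stands for
// what the type converter returns for `origType`; an empty string means that
// no legal type exists.
std::vector<Cast>
convertBlockArgument(const std::string &origType,
                     const std::vector<std::string> &newArgTypes,
                     const std::string &legalType) {
  std::vector<Cast> casts;
  // A single replacement of the same type needs no cast.
  if (newArgTypes.size() == 1 && newArgTypes.front() == origType)
    return casts;
  // Argument materialization: new block argument types -> original type.
  casts.push_back({"argument", newArgTypes, origType});
  // Target materialization: original type -> legal type, if one exists.
  if (!legalType.empty() && legalType != origType)
    casts.push_back({"target", {origType}, legalType});
  return casts;
}
```

In this model, a `complex<f64>` argument that becomes an `!llvm.struct<(f64, f64)>` argument yields two casts. Users that still expect `complex<f64>` can use the result of the first cast, which is the cast that was missing before this commit.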
1 parent dd7d81e commit acc159a

9 files changed: +141 -60 lines changed

mlir/docs/DialectConversion.md

Lines changed: 2 additions & 1 deletion
@@ -352,7 +352,8 @@ class TypeConverter {
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old argument type.
   template <typename FnT,
             typename T = typename llvm::function_traits<FnT>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 11 additions & 0 deletions
@@ -38,6 +38,17 @@ def ApplySCFStructuralConversionPatternsOp : Op<Transform_Dialect,
   let assemblyFormat = "attr-dict";
 }
 
+def ApplySCFToControlFlowPatternsOp : Op<Transform_Dialect,
+    "apply_conversion_patterns.scf.scf_to_control_flow",
+    [DeclareOpInterfaceMethods<ConversionPatternDescriptorOpInterface>]> {
+  let description = [{
+    Collects patterns that lower structured control flow ops to unstructured
+    control flow.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;
 
 def ForallToForOp : Op<Transform_Dialect, "loop.forall_to_for",

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 5 additions & 5 deletions
@@ -174,15 +174,15 @@ class TypeConverter {
   /// where `T` is any subclass of `Type`. This function is responsible for
   /// creating an operation, using the OpBuilder and Location provided, that
   /// "casts" a range of values into a single value of the given type `T`. It
-  /// must return a Value of the converted type on success, an `std::nullopt` if
+  /// must return a Value of the type `T` on success, an `std::nullopt` if
   /// it failed but other materialization can be attempted, and `nullptr` on
-  /// unrecoverable failure. It will only be called for (sub)types of `T`.
-  /// Materialization functions must be provided when a type conversion may
-  /// persist after the conversion has finished.
+  /// unrecoverable failure. Materialization functions must be provided when a
+  /// type conversion may persist after the conversion has finished.
 
   /// This method registers a materialization that will be called when
   /// converting (potentially multiple) block arguments that were the result of
-  /// a signature conversion of a single block argument, to a single SSA value.
+  /// a signature conversion of a single block argument, to a single SSA value
+  /// with the old block argument type.
   template <typename FnT, typename T = typename llvm::function_traits<
                               std::decay_t<FnT>>::template arg_t<1>>
   void addArgumentMaterialization(FnT &&callback) {
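
The three-way return protocol described in the doc comment above (a value of type `T`, `std::nullopt`, or a null value) can be modelled in plain C++. The sketch below is an assumption-laden toy, not the MLIR API: `Value`, `MaterializationFn`, `Outcome`, and `materialize` are hypothetical stand-ins, types are strings, and callbacks are tried in vector order, which is not meant to reflect the order used by the real `TypeConverter`.

```cpp
#include <cassert>
#include <functional>
#include <optional>
#include <string>
#include <vector>

// Toy stand-in for mlir::Value: an empty `type` plays the role of a null Value.
struct Value {
  std::string type;
  explicit operator bool() const { return !type.empty(); }
};

// Hypothetical callback signature: (requested result type, inputs) -> result.
using MaterializationFn = std::function<std::optional<Value>(
    const std::string &, const std::vector<Value> &)>;

enum class Outcome { Success, NoCallbackApplied, HardFailure, WrongType };

// Tries each callback in turn and interprets its return value.
Outcome materialize(const std::vector<MaterializationFn> &callbacks,
                    const std::string &resultType,
                    const std::vector<Value> &inputs, Value &result) {
  for (const MaterializationFn &fn : callbacks) {
    std::optional<Value> v = fn(resultType, inputs);
    if (!v)
      continue; // std::nullopt: this callback declined, try the next one.
    if (!*v)
      return Outcome::HardFailure; // Null value: unrecoverable failure.
    if (v->type != resultType)
      return Outcome::WrongType; // The commit adds an assert for this case.
    result = *v;
    return Outcome::Success;
  }
  return Outcome::NoCallbackApplied;
}
```

A callback that returns a value of some other type (for example an `!llvm.struct` when `complex<f64>` was requested) ends up in the `WrongType` branch of this model, which corresponds to the situation the integration note in the commit message warns about.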

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 21 additions & 7 deletions
@@ -153,9 +153,11 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      type.isVarArg());
   });
 
-  // Materialization for memrefs creates descriptor structs from individual
-  // values constituting them, when descriptors are used, i.e. more than one
-  // value represents a memref.
+  // Argument materializations convert from the new block argument types
+  // (multiple SSA values that make up a memref descriptor) back to the
+  // original block argument type. The dialect conversion framework will then
+  // insert a target materialization from the original block argument type to
+  // a legal type.
   addArgumentMaterialization(
       [&](OpBuilder &builder, UnrankedMemRefType resultType, ValueRange inputs,
           Location loc) -> std::optional<Value> {
@@ -164,12 +166,18 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
           // memref descriptor cannot be built just from a bare pointer.
           return std::nullopt;
         }
-        return UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
-                                              inputs);
+        Value desc = UnrankedMemRefDescriptor::pack(builder, loc, *this,
+                                                    resultType, inputs);
+        // An argument materialization must return a value of type
+        // `resultType`, so insert a cast from the memref descriptor type
+        // (!llvm.struct) to the original memref type.
+        return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+            .getResult(0);
       });
   addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                  ValueRange inputs,
                                  Location loc) -> std::optional<Value> {
+    Value desc;
     if (inputs.size() == 1) {
       // This is a bare pointer. We allow bare pointers only for function entry
       // blocks.
@@ -180,10 +188,16 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       if (!block->isEntryBlock() ||
           !isa<FunctionOpInterface>(block->getParentOp()))
         return std::nullopt;
-      return MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
+      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
+    } else {
+      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
     }
-    return MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
+    // An argument materialization must return a value of type `resultType`,
+    // so insert a cast from the memref descriptor type (!llvm.struct) to the
+    // original memref type.
+    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+        .getResult(0);
   });
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.

mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRSCFTransformOps
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSCFDialect
+  MLIRSCFToControlFlow
   MLIRSCFTransforms
   MLIRSCFUtils
   MLIRTransformDialect

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 11 additions & 2 deletions
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -49,6 +51,11 @@ void transform::ApplySCFStructuralConversionPatternsOp::
                                                        conversionTarget);
 }
 
+void transform::ApplySCFToControlFlowPatternsOp::populatePatterns(
+    TypeConverter &typeConverter, RewritePatternSet &patterns) {
+  populateSCFToControlFlowConversionPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // ForallToForOp
 //===----------------------------------------------------------------------===//
@@ -261,8 +268,10 @@ loopScheduling(scf::ForOp forOp,
     return 1;
   };
 
-  std::optional<int64_t> ubConstant = getConstantIntValue(forOp.getUpperBound());
-  std::optional<int64_t> lbConstant = getConstantIntValue(forOp.getLowerBound());
+  std::optional<int64_t> ubConstant =
+      getConstantIntValue(forOp.getUpperBound());
+  std::optional<int64_t> lbConstant =
+      getConstantIntValue(forOp.getLowerBound());
   DenseMap<Operation *, unsigned> opCycles;
   std::map<unsigned, std::vector<Operation *>> wrappedSchedule;
   for (Operation &op : forOp.getBody()->getOperations()) {

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 44 additions & 43 deletions
@@ -707,10 +707,9 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
   UnresolvedMaterializationRewrite(
       ConversionPatternRewriterImpl &rewriterImpl,
       UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
-      MaterializationKind kind = MaterializationKind::Target,
-      Type origOutputType = nullptr)
+      MaterializationKind kind = MaterializationKind::Target)
       : OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
-        converterAndKind(converter, kind), origOutputType(origOutputType) {}
+        converterAndKind(converter, kind) {}
 
   static bool classof(const IRRewrite *rewrite) {
     return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -734,17 +733,11 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
     return converterAndKind.getInt();
   }
 
-  /// Return the original illegal output type of the input values.
-  Type getOrigOutputType() const { return origOutputType; }
-
 private:
   /// The corresponding type converter to use when resolving this
   /// materialization, and the kind of this materialization.
   llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
       converterAndKind;
-
-  /// The original output type. This is only used for argument conversions.
-  Type origOutputType;
 };
 } // namespace
 
@@ -860,12 +853,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                                        Block *insertBlock,
                                        Block::iterator insertPt, Location loc,
                                        ValueRange inputs, Type outputType,
-                                       Type origOutputType,
                                        const TypeConverter *converter);
 
   Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
                                                ValueRange inputs,
-                                               Type origOutputType,
                                                Type outputType,
                                                const TypeConverter *converter);
 
@@ -1388,20 +1379,28 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
     if (replArgs.size() == 1 &&
         (!converter || replArgs[0].getType() == origArg.getType())) {
       newArg = replArgs.front();
+      mapping.map(origArg, newArg);
     } else {
-      Type origOutputType = origArg.getType();
-
-      // Legalize the argument output type.
-      Type outputType = origOutputType;
-      if (Type legalOutputType = converter->convertType(outputType))
-        outputType = legalOutputType;
-
-      newArg = buildUnresolvedArgumentMaterialization(
-          newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
-          converter);
+      // Build argument materialization: new block arguments -> old block
+      // argument type.
+      Value argMat = buildUnresolvedArgumentMaterialization(
+          newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
+      mapping.map(origArg, argMat);
+
+      // Build target materialization: old block argument type -> legal type.
+      // Note: This function returns an "empty" type if no valid conversion to
+      // a legal type exists. In that case, we continue the conversion with the
+      // original block argument type.
+      Type legalOutputType = converter->convertType(origArg.getType());
+      if (legalOutputType && legalOutputType != origArg.getType()) {
+        newArg = buildUnresolvedTargetMaterialization(
+            origArg.getLoc(), argMat, legalOutputType, converter);
+        mapping.map(argMat, newArg);
+      } else {
+        newArg = argMat;
+      }
     }
 
-    mapping.map(origArg, newArg);
     appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
     argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
   }
@@ -1424,7 +1423,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, Block *insertBlock, Block::iterator insertPt,
-    Location loc, ValueRange inputs, Type outputType, Type origOutputType,
+    Location loc, ValueRange inputs, Type outputType,
     const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (inputs.size() == 1 && inputs.front().getType() == outputType)
@@ -1435,16 +1434,15 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   OpBuilder builder(insertBlock, insertPt);
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
-                                                  origOutputType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
   return convertOp.getResult(0);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
-    Block *block, Location loc, ValueRange inputs, Type origOutputType,
-    Type outputType, const TypeConverter *converter) {
+    Block *block, Location loc, ValueRange inputs, Type outputType,
+    const TypeConverter *converter) {
   return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
                                         block->begin(), loc, inputs, outputType,
-                                        origOutputType, converter);
+                                        converter);
 }
 Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
     Location loc, Value input, Type outputType,
@@ -1456,7 +1454,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
 
   return buildUnresolvedMaterialization(MaterializationKind::Target,
                                         insertBlock, insertPt, loc, input,
-                                        outputType, outputType, converter);
+                                        outputType, converter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2672,19 +2670,28 @@ static void computeNecessaryMaterializations(
     ConversionPatternRewriterImpl &rewriterImpl,
     DenseMap<Value, SmallVector<Value>> &inverseMapping,
     SetVector<UnresolvedMaterializationRewrite *> &necessaryMaterializations) {
+  // Helper function to check if the given value or a not yet materialized
+  // replacement of the given value is live.
+  // Note: `inverseMapping` maps from replaced values to original values.
   auto isLive = [&](Value value) {
     auto findFn = [&](Operation *user) {
       auto matIt = materializationOps.find(user);
       if (matIt != materializationOps.end())
         return !necessaryMaterializations.count(matIt->second);
       return rewriterImpl.isOpIgnored(user);
     };
-    // This value may be replacing another value that has a live user.
-    for (Value inv : inverseMapping.lookup(value))
-      if (llvm::find_if_not(inv.getUsers(), findFn) != inv.user_end())
+    // A worklist is needed because a value may have gone through a chain of
+    // replacements and each of the replaced values may have live users.
+    SmallVector<Value> worklist;
+    worklist.push_back(value);
+    while (!worklist.empty()) {
+      Value next = worklist.pop_back_val();
+      if (llvm::find_if_not(next.getUsers(), findFn) != next.user_end())
         return true;
-    // Or have live users itself.
-    return llvm::find_if_not(value.getUsers(), findFn) != value.user_end();
+      // This value may be replacing another value that has a live user.
+      llvm::append_range(worklist, inverseMapping.lookup(next));
+    }
+    return false;
   };
 
   llvm::unique_function<Value(Value, Value, Type)> lookupRemappedValue =
@@ -2844,18 +2851,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
     switch (mat.getMaterializationKind()) {
     case MaterializationKind::Argument:
       // Try to materialize an argument conversion.
-      // FIXME: The current argument materialization hook expects the original
-      // output type, even though it doesn't use that as the actual output type
-      // of the generated IR. The output type is just used as an indicator of
-      // the type of materialization to do. This behavior is really awkward in
-      // that it diverges from the behavior of the other hooks, and can be
-      // easily misunderstood. We should clean up the argument hooks to better
-      // represent the desired invariants we actually care about.
       newMaterialization = converter->materializeArgumentConversion(
-          rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands);
+          rewriter, op->getLoc(), outputType, inputOperands);
       if (newMaterialization)
         break;
-
       // If an argument materialization failed, fallback to trying a target
       // materialization.
       [[fallthrough]];
@@ -2865,6 +2864,8 @@ static LogicalResult legalizeUnresolvedMaterialization(
       break;
     }
     if (newMaterialization) {
+      assert(newMaterialization.getType() == outputType &&
+             "materialization callback produced value of incorrect type");
       replaceMaterialization(rewriterImpl, opResult, newMaterialization,
                              inverseMapping);
       return success();
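
The worklist change in `isLive` can be illustrated outside of MLIR. The following is a minimal sketch under stated assumptions: values are plain strings, `inverseMapping` maps a replacement to the values it replaced, `liveUsers` is a hypothetical set of values that still have a live user, and replacement chains are assumed to be acyclic. `isLiveOneLevel` is a made-up name that mimics the earlier one-step lookup for comparison.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-ins for the real data structures.
using Value = std::string;
using InverseMapping = std::unordered_map<Value, std::vector<Value>>;

// Returns true if `value`, or any value it transitively replaced, still has a
// live user. Assumes that replacement chains contain no cycles.
bool isLive(const Value &value, const InverseMapping &inverseMapping,
            const std::unordered_set<Value> &liveUsers) {
  std::vector<Value> worklist{value};
  while (!worklist.empty()) {
    Value next = worklist.back();
    worklist.pop_back();
    if (liveUsers.count(next))
      return true;
    // `next` may itself be replacing other values: follow the chain.
    auto it = inverseMapping.find(next);
    if (it != inverseMapping.end())
      worklist.insert(worklist.end(), it->second.begin(), it->second.end());
  }
  return false;
}

// One-step variant: checks the value and only its direct predecessors.
bool isLiveOneLevel(const Value &value, const InverseMapping &inverseMapping,
                    const std::unordered_set<Value> &liveUsers) {
  if (liveUsers.count(value))
    return true;
  auto it = inverseMapping.find(value);
  if (it == inverseMapping.end())
    return false;
  for (const Value &inv : it->second)
    if (liveUsers.count(inv))
      return true;
  return false;
}
```

With `a` replaced by `b` and `b` replaced by `c`, and only `a` having a live user, the one-step variant reports `c` as dead while the worklist variant follows the chain back to `a`.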

mlir/test/Conversion/FuncToLLVM/func-memref-return.mlir

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 // RUN: mlir-opt -convert-func-to-llvm -reconcile-unrealized-casts %s | FileCheck %s
 
-// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -convert-func-to-llvm='use-bare-ptr-memref-call-conv=1' -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
-// RUN: mlir-opt -transform-interpreter %s | FileCheck %s --check-prefix=BAREPTR
+// RUN: mlir-opt -transform-interpreter -reconcile-unrealized-casts %s | FileCheck %s --check-prefix=BAREPTR
 
 // These tests were separated from func-memref.mlir because applying
 // -reconcile-unrealized-casts resulted in `llvm.extractvalue` ops getting
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
+
+// CHECK-LABEL: func @complex_block_signature_conversion(
+//       CHECK:   %[[cst:.*]] = complex.constant
+//       CHECK:   %[[complex_llvm:.*]] = builtin.unrealized_conversion_cast %[[cst]] : complex<f64> to !llvm.struct<(f64, f64)>
+// Note: Some blocks are omitted.
+//       CHECK:   llvm.br ^[[block1:.*]](%[[complex_llvm]]
+//       CHECK: ^[[block1]](%[[arg:.*]]: !llvm.struct<(f64, f64)>):
+//       CHECK:   %[[cast:.*]] = builtin.unrealized_conversion_cast %[[arg]] : !llvm.struct<(f64, f64)> to complex<f64>
+//       CHECK:   llvm.br ^[[block2:.*]]
+//       CHECK: ^[[block2]]:
+//       CHECK:   "test.consumer_of_complex"(%[[cast]]) : (complex<f64>) -> ()
+func.func @complex_block_signature_conversion() {
+  %cst = complex.constant [0.000000e+00, 0.000000e+00] : complex<f64>
+  %true = arith.constant true
+  %0 = scf.if %true -> complex<f64> {
+    scf.yield %cst : complex<f64>
+  } else {
+    scf.yield %cst : complex<f64>
+  }
+
+  // Regression test to ensure that the a source materialization is inserted.
+  // The operand of "test.consumer_of_complex" must not change.
+  "test.consumer_of_complex"(%0) : (complex<f64>) -> ()
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %toplevel_module
+      : (!transform.any_op) -> !transform.any_op
+    transform.apply_conversion_patterns to %func {
+      transform.apply_conversion_patterns.dialect_to_llvm "cf"
+      transform.apply_conversion_patterns.func.func_to_llvm
+      transform.apply_conversion_patterns.scf.scf_to_control_flow
+    } with type_converter {
+      transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+    } {
+      legal_dialects = ["llvm"],
+      partial_conversion
+    } : !transform.any_op
+    transform.yield
+  }
+}
