Skip to content

[mlir][Transforms] Dialect conversion: Simplify handling of dropped arguments #97213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 55 additions & 118 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're seeing asserts downstream after this PR that go away if the commit is reverted.

mlir/lib/IR/Region.cpp:25: MLIRContext *mlir::Region::getContext(): Assertion `container && "region is not attached to a container"' failed.

Logs:

IR before the pass that crashes for one test case: https://gist.github.com/ScottTodd/6a9fdc0976d7336291d61ccf24bcb22b

Our downstream code involved in the callstack is here: https://github.com/iree-org/iree/blob/main/compiler/src/iree/compiler/InputConversion/Common/ConvertPrimitiveType.cpp . I'm not sure yet if our usage downstream needs to change or if there is a bug in this upstream code. Do you have any suggestions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like an edge case with detached IR. can you try passing the context from the rewriter to “ buildUnresolvedMaterialization” as an extra argument (in the dialect conversion)? if that doesn’t fix it, please revert and i will get back to it when i have time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and create the builder with the context:

// Create an unresolved materialization. We use a new OpBuilder to avoid
// tracking the materialization like we do for other operations.
OpBuilder builder(insertBlock, insertPt);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! This patch appears to fix our crashes: https://gist.github.com/ScottTodd/7d05663c3180f5ae5711e278479f0146

I have us downstream set to carry a local revert as a bandaid fix. Some options:

  1. Fix-forward: I send this patch as a PR (not sure what tests to add, I don't understand this code well enough)
  2. Fix-forward: you take over the patch
  3. We revert this PR and then proceed with a rollforward including the fix patch

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If that's all that's necessary I can send a fix later today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Here's my patch as a commit, if it helps: ScottTodd@f766cd2

Original file line number Diff line number Diff line change
Expand Up @@ -432,34 +432,14 @@ class MoveBlockRewrite : public BlockRewrite {
Block *insertBeforeBlock;
};

/// This structure contains the information pertaining to an argument that has
/// been converted.
struct ConvertedArgInfo {
ConvertedArgInfo(unsigned newArgIdx, unsigned newArgSize,
Value castValue = nullptr)
: newArgIdx(newArgIdx), newArgSize(newArgSize), castValue(castValue) {}

/// The start index of in the new argument list that contains arguments that
/// replace the original.
unsigned newArgIdx;

/// The number of arguments that replaced the original argument.
unsigned newArgSize;

/// The cast value that was created to cast from the new arguments to the
/// old. This only used if 'newArgSize' > 1.
Value castValue;
};

/// Block type conversion. This rewrite is partially reflected in the IR.
class BlockTypeConversionRewrite : public BlockRewrite {
public:
BlockTypeConversionRewrite(
ConversionPatternRewriterImpl &rewriterImpl, Block *block,
Block *origBlock, SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo,
const TypeConverter *converter)
BlockTypeConversionRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Block *block, Block *origBlock,
const TypeConverter *converter)
: BlockRewrite(Kind::BlockTypeConversion, rewriterImpl, block),
origBlock(origBlock), argInfo(argInfo), converter(converter) {}
origBlock(origBlock), converter(converter) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::BlockTypeConversion;
Expand All @@ -479,10 +459,6 @@ class BlockTypeConversionRewrite : public BlockRewrite {
/// The original block that was requested to have its signature converted.
Block *origBlock;

/// The conversion information for each of the arguments. The information is
/// std::nullopt if the argument was dropped during conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;

/// The type converter used to convert the arguments.
const TypeConverter *converter;
};
Expand Down Expand Up @@ -691,12 +667,16 @@ class CreateOperationRewrite : public OperationRewrite {
/// The type of materialization.
enum MaterializationKind {
/// This materialization materializes a conversion for an illegal block
/// argument type, to a legal one.
/// argument type, to the original one.
Argument,

/// This materialization materializes a conversion from an illegal type to a
/// legal one.
Target
Target,

/// This materialization materializes a conversion from a legal type back to
/// an illegal one.
Source
};

/// An unresolved materialization, i.e., a "builtin.unrealized_conversion_cast"
Expand Down Expand Up @@ -736,7 +716,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
private:
/// The corresponding type converter to use when resolving this
/// materialization, and the kind of this materialization.
llvm::PointerIntPair<const TypeConverter *, 1, MaterializationKind>
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
converterAndKind;
};
} // namespace
Expand Down Expand Up @@ -855,11 +835,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ValueRange inputs, Type outputType,
const TypeConverter *converter);

Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
ValueRange inputs,
Type outputType,
const TypeConverter *converter);

Value buildUnresolvedTargetMaterialization(Location loc, Value input,
Type outputType,
const TypeConverter *converter);
Expand Down Expand Up @@ -989,28 +964,6 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
for (Operation *op : block->getUsers())
listener->notifyOperationModified(op);

// Process the remapping for each of the original arguments.
for (auto [origArg, info] :
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
// Handle the case of a 1->0 value mapping.
if (!info) {
if (Value newArg =
rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}

// Otherwise this is a 1->1+ value mapping.
Value castValue = info->castValue;
assert(info->newArgSize >= 1 && castValue && "expected 1->1+ mapping");

// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
castValue, origArg.getType()));
}
}
}

void BlockTypeConversionRewrite::rollback() {
Expand All @@ -1035,14 +988,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
continue;

Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
assert(replacementValue && "replacement value not found");
Value newArg;
if (converter) {
builder.setInsertionPointAfterValue(replacementValue);
newArg = converter->materializeSourceConversion(
builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
builder, origArg.getLoc(), origArg.getType(), replacementValue);
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
"type");
Expand All @@ -1053,8 +1004,6 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "failed to materialize conversion for block argument #"
<< it.index() << " that remained live after conversion, type was "
<< origArg.getType();
if (!isDroppedArg)
diag << ", with target type " << replacementValue.getType();
diag.attachNote(liveUser->getLoc())
<< "see existing live user here: " << *liveUser;
return failure();
Expand Down Expand Up @@ -1340,73 +1289,64 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Replace all uses of the old block with the new block.
block->replaceAllUsesWith(newBlock);

// Remap each of the original arguments as determined by the signature
// conversion.
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);

for (unsigned i = 0; i != origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
continue;
BlockArgument origArg = block->getArgument(i);
Type origArgType = origArg.getType();

std::optional<TypeConverter::SignatureConversion::InputMapping> inputMap =
signatureConversion.getInputMapping(i);
if (!inputMap) {
// This block argument was dropped and no replacement value was provided.
// Materialize a replacement value "out of thin air".
Value repl = buildUnresolvedMaterialization(
MaterializationKind::Source, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/origArgType, converter);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}

// If inputMap->replacementValue is not nullptr, then the argument is
// dropped and a replacement value is provided to be the remappedValue.
if (inputMap->replacementValue) {
if (Value repl = inputMap->replacementValue) {
// This block argument was dropped and a replacement value was provided.
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
mapping.map(origArg, inputMap->replacementValue);
mapping.map(origArg, repl);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
continue;
}

// Otherwise, this is a 1->1+ mapping.
// This is a 1->1+ mapping. 1->N mappings are not fully supported in the
// dialect conversion. Therefore, we need an argument materialization to
// turn the replacement block arguments into a single SSA value that can be
// used as a replacement.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
Value newArg;
Value argMat = buildUnresolvedMaterialization(
MaterializationKind::Argument, newBlock, newBlock->begin(),
origArg.getLoc(), /*inputs=*/replArgs, origArgType, converter);
mapping.map(origArg, argMat);
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);

// If this is a 1->1 mapping and the types of new and replacement arguments
// match (i.e. it's an identity map), then the argument is mapped to its
// original type.
// FIXME: We simply pass through the replacement argument if there wasn't a
// converter, which isn't great as it allows implicit type conversions to
// appear. We should properly restructure this code to handle cases where a
// converter isn't provided and also to properly handle the case where an
// argument materialization is actually a temporary source materialization
// (e.g. in the case of 1->N).
if (replArgs.size() == 1 &&
(!converter || replArgs[0].getType() == origArg.getType())) {
newArg = replArgs.front();
mapping.map(origArg, newArg);
} else {
// Build argument materialization: new block arguments -> old block
// argument type.
Value argMat = buildUnresolvedArgumentMaterialization(
newBlock, origArg.getLoc(), replArgs, origArg.getType(), converter);
mapping.map(origArg, argMat);

// Build target materialization: old block argument type -> legal type.
// Note: This function returns an "empty" type if no valid conversion to
// a legal type exists. In that case, we continue the conversion with the
// original block argument type.
Type legalOutputType = converter->convertType(origArg.getType());
if (legalOutputType && legalOutputType != origArg.getType()) {
newArg = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, newArg);
} else {
newArg = argMat;
}
Type legalOutputType;
if (converter)
legalOutputType = converter->convertType(origArgType);
if (legalOutputType && legalOutputType != origArgType) {
Value targetMat = buildUnresolvedTargetMaterialization(
origArg.getLoc(), argMat, legalOutputType, converter);
mapping.map(argMat, targetMat);
}

appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg);
}

appendRewrite<BlockTypeConversionRewrite>(newBlock, block, argInfo,
converter);
appendRewrite<BlockTypeConversionRewrite>(newBlock, block, converter);

// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
Expand Down Expand Up @@ -1437,13 +1377,6 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
Block *block, Location loc, ValueRange inputs, Type outputType,
const TypeConverter *converter) {
return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
block->begin(), loc, inputs, outputType,
converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
const TypeConverter *converter) {
Expand Down Expand Up @@ -2862,6 +2795,10 @@ static LogicalResult legalizeUnresolvedMaterialization(
newMaterialization = converter->materializeTargetConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
case MaterializationKind::Source:
newMaterialization = converter->materializeSourceConversion(
rewriter, op->getLoc(), outputType, inputOperands);
break;
}
if (newMaterialization) {
assert(newMaterialization.getType() == outputType &&
Expand All @@ -2874,8 +2811,8 @@ static LogicalResult legalizeUnresolvedMaterialization(

InFlightDiagnostic diag = op->emitError()
<< "failed to legalize unresolved materialization "
"from "
<< inputOperands.getTypes() << " to " << outputType
"from ("
<< inputOperands.getTypes() << ") to " << outputType
<< " that remained live after conversion";
if (Operation *liveUser = findLiveUser(op->getUsers())) {
diag.attachNote(liveUser->getLoc())
Expand Down
6 changes: 2 additions & 4 deletions mlir/test/Transforms/test-legalize-type-conversion.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@


func.func @test_invalid_arg_materialization(
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'i16'}}
// expected-error@below {{failed to legalize unresolved materialization from () to 'i16' that remained live after conversion}}
%arg0: i16) {
// expected-note@below {{see existing live user here}}
"foo.return"(%arg0) : (i16) -> ()
}

Expand Down Expand Up @@ -104,9 +103,8 @@ func.func @test_block_argument_not_converted() {
// Make sure argument type changes aren't implicitly forwarded.
func.func @test_signature_conversion_no_converter() {
"test.signature_conversion_no_converter"() ({
// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}}
// expected-error@below {{failed to legalize unresolved materialization from ('f64') to 'f32' that remained live after conversion}}
^bb0(%arg0: f32):
// expected-note@below {{see existing live user here}}
"test.type_consumer"(%arg0) : (f32) -> ()
"test.return"(%arg0) : (f32) -> ()
}) : () -> ()
Expand Down
Loading