Skip to content

[mlir][RFC] Bytecode: op fallback path #129784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Bytecode/BytecodeImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
//===--------------------------------------------------------------------===//
Expand Down Expand Up @@ -445,6 +446,14 @@ class BytecodeDialectInterface
return Type();
}

/// Return the name of the operation to fall back to if parsing an op from
/// bytecode fails for any reason. This can be used to handle new ops emitted
/// from a different version of the dialect, that cannot be read by an older
/// version of the dialect.
///
/// The default implementation returns failure, i.e. the dialect provides no
/// fallback operation.
virtual FailureOr<OperationName> getFallbackOperationName() const {
return failure();
}

//===--------------------------------------------------------------------===//
// Writing
//===--------------------------------------------------------------------===//
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/mlir/Bytecode/BytecodeOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,28 @@ def BytecodeOpInterface : OpInterface<"BytecodeOpInterface"> {
];
}

// `FallbackBytecodeOpInterface`
//
// Op interface for fallback operations. It declares two methods:
//   * `setOriginalOperationName` (static): receives the original operation
//     name and the `OperationState` being built for the op.
//   * `getOriginalOperationName`: returns the original operation name.
def FallbackBytecodeOpInterface : OpInterface<"FallbackBytecodeOpInterface"> {
let description = [{
This interface allows fallback operations sideband access to the
original operation's intrinsic details.
}];
let cppNamespace = "::mlir";

let methods = [
StaticInterfaceMethod<[{
Set the original name for this operation from the bytecode.
}],
"void", "setOriginalOperationName", (ins
"const ::mlir::Twine&":$opName,
"::mlir::OperationState &":$state)
>,
InterfaceMethod<[{
Get the original name for this operation from the bytecode.
}],
"::mlir::StringRef", "getOriginalOperationName", (ins)
>
];
}

#endif // MLIR_BYTECODE_BYTECODEOPINTERFACES
101 changes: 81 additions & 20 deletions mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/MemoryBufferRef.h"
#include "llvm/Support/SourceMgr.h"
Expand Down Expand Up @@ -292,6 +293,16 @@ class EncodingReader {

Location getLoc() const { return fileLoc; }

/// Snapshot the read position of an EncodingReader so that parsing can be
/// rewound if needed. Only the reader's `dataIt` is captured and restored;
/// `rewind` does not undo anything else that changed while parsing.
struct Snapshot {
/// The reader whose position was captured.
EncodingReader &reader;
/// The value of `reader.dataIt` at the time the snapshot was taken.
const uint8_t *dataIt;

/// Capture the current position of `reader`.
Snapshot(EncodingReader &reader) : reader(reader), dataIt(reader.dataIt) {}
/// Reset `reader` back to the captured position.
void rewind() { reader.dataIt = dataIt; }
};

private:
/// Parse a variable length encoded integer from the byte stream. This method
/// is a fallback when the number of bytes used to encode the value is greater
Expand Down Expand Up @@ -1410,8 +1421,9 @@ class mlir::BytecodeReader::Impl {
/// Parse an operation name reference using the given reader, and set the
/// `wasRegistered` flag that indicates if the bytecode was produced by a
/// context where opName was registered.
FailureOr<OperationName> parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered);
FailureOr<BytecodeOperationName *>
parseOpName(EncodingReader &reader, std::optional<bool> &wasRegistered,
bool useDialectFallback);

//===--------------------------------------------------------------------===//
// Attribute/Type Section
Expand Down Expand Up @@ -1476,7 +1488,8 @@ class mlir::BytecodeReader::Impl {
RegionReadState &readState);
FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove);
bool &isIsolatedFromAbove,
bool useDialectFallback);

LogicalResult parseRegion(RegionReadState &readState);
LogicalResult parseBlockHeader(EncodingReader &reader,
Expand Down Expand Up @@ -1506,7 +1519,7 @@ class mlir::BytecodeReader::Impl {
UseListOrderStorage(bool isIndexPairEncoding,
SmallVector<unsigned, 4> &&indices)
: indices(std::move(indices)),
isIndexPairEncoding(isIndexPairEncoding){};
isIndexPairEncoding(isIndexPairEncoding) {};
/// The vector containing the information required to reorder the
/// use-list of a value.
SmallVector<unsigned, 4> indices;
Expand Down Expand Up @@ -1843,16 +1856,20 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
return success();
}

FailureOr<OperationName>
FailureOr<BytecodeOperationName *>
BytecodeReader::Impl::parseOpName(EncodingReader &reader,
std::optional<bool> &wasRegistered) {
std::optional<bool> &wasRegistered,
bool useDialectFallback) {
BytecodeOperationName *opName = nullptr;
if (failed(parseEntry(reader, opNames, opName, "operation name")))
return failure();
wasRegistered = opName->wasRegistered;
// Check to see if this operation name has already been resolved. If we
// haven't, load the dialect and build the operation name.
if (!opName->opName) {
// If `useDialectFallback`, it's likely that parsing previously failed. We'll
// need to reset any previously resolved OperationName with that of the
// fallback op.
if (!opName->opName || useDialectFallback) {
// If the opName is empty, this is because we used to accept names such as
// `foo` without any `.` separator. We shouldn't tolerate this in textual
// format anymore but for now we'll be backward compatible. This can only
Expand All @@ -1865,11 +1882,26 @@ BytecodeReader::Impl::parseOpName(EncodingReader &reader,
dialectsMap, reader, version);
if (failed(opName->dialect->load(dialectReader, getContext())))
return failure();
opName->opName.emplace((opName->dialect->name + "." + opName->name).str(),
getContext());

const BytecodeDialectInterface *dialectIface = opName->dialect->interface;
if (useDialectFallback) {
FailureOr<OperationName> fallbackOp =
dialectIface ? dialectIface->getFallbackOperationName()
: FailureOr<OperationName>{};

// If the dialect doesn't have a fallback operation, we can't parse as
// instructed.
if (failed(fallbackOp))
return failure();

opName->opName.emplace(*fallbackOp);
} else {
opName->opName.emplace(
(opName->dialect->name + "." + opName->name).str(), getContext());
}
}
}
return *opName->opName;
return opName;
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2143,10 +2175,30 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
// Read in the next operation. We don't read its regions directly, we
// handle those afterwards as necessary.
bool isIsolatedFromAbove = false;
FailureOr<Operation *> op =
parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
if (failed(op))
return failure();
FailureOr<Operation *> op;

// Parse the bytecode.
{
// If the op is registered (and serialized in a compatible manner), or
// unregistered but uses standard properties encoding, parsing without
// going through the fallback path should work.
EncodingReader::Snapshot snapshot(reader);
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
/*useDialectFallback=*/false);

// If reading fails, try parsing the op again as a dialect fallback
// op (if supported).
if (failed(op)) {
snapshot.rewind();
op = parseOpWithoutRegions(reader, readState, isIsolatedFromAbove,
/*useDialectFallback=*/true);
}

// If the dialect doesn't have a fallback op, or parsing as a fallback
// op fails, we can no longer continue.
if (failed(op))
return failure();
}

// If the op has regions, add it to the stack for processing and return:
// we stop the processing of the current region and resume it after the
Expand Down Expand Up @@ -2208,14 +2260,17 @@ BytecodeReader::Impl::parseRegions(std::vector<RegionReadState> &regionStack,
return success();
}

FailureOr<Operation *>
BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
RegionReadState &readState,
bool &isIsolatedFromAbove) {
FailureOr<Operation *> BytecodeReader::Impl::parseOpWithoutRegions(
EncodingReader &reader, RegionReadState &readState,
bool &isIsolatedFromAbove, bool useDialectFallback) {
// Parse the name of the operation.
std::optional<bool> wasRegistered;
FailureOr<OperationName> opName = parseOpName(reader, wasRegistered);
if (failed(opName))
FailureOr<BytecodeOperationName *> bytecodeOp =
parseOpName(reader, wasRegistered, useDialectFallback);
if (failed(bytecodeOp))
return failure();
auto opName = (*bytecodeOp)->opName;
if (!opName)
return failure();

// Parse the operation mask, which indicates which components of the operation
Expand All @@ -2232,6 +2287,12 @@ BytecodeReader::Impl::parseOpWithoutRegions(EncodingReader &reader,
// With the location and name resolved, we can start building the operation
// state.
OperationState opState(opLoc, *opName);
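// If this is a fallback op, provide the original name of the operation.
// NOTE(review): `originalName` below is a named `Twine` built from temporary
// sub-expressions. An llvm::Twine holds pointers to its operands and must not
// outlive the full expression that created it; materialize the name (e.g. into
// a std::string) or pass the concatenation directly as the call argument.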
if (auto *iface = opName->getInterface<FallbackBytecodeOpInterface>()) {
const Twine originalName =
opName->getDialect()->getNamespace() + "." + (*bytecodeOp)->name;
iface->setOriginalOperationName(originalName, opState);
}

// Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) {
Expand Down
13 changes: 10 additions & 3 deletions mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,12 +841,12 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {

// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
const bool isKnownOp = name.isOpaqueEntry || name.name.isRegistered();
size_t stringId = stringSection.insert(name.name.stripDialect());
if (config.bytecodeVersion < bytecode::kNativePropertiesEncoding)
dialectEmitter.emitVarInt(stringId, "dialect op name");
else
dialectEmitter.emitVarIntWithFlag(stringId, name.name.isRegistered(),
"dialect op name");
dialectEmitter.emitVarIntWithFlag(stringId, isKnownOp, "dialect op name");
};
writeDialectGrouping(dialectEmitter, numberingState.getOpNames(), emitOpName);

Expand Down Expand Up @@ -984,7 +984,14 @@ LogicalResult BytecodeWriter::writeBlock(EncodingEmitter &emitter,
}

LogicalResult BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {
emitter.emitVarInt(numberingState.getNumber(op->getName()), "op name ID");
OperationName opName = op->getName();
// For fallback ops, create a new operation name referencing the original op
// instead.
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op))
opName =
OperationName(fallback.getOriginalOperationName(), op->getContext());

emitter.emitVarInt(numberingState.getNumber(opName), "op name ID");

// Emit a mask for the operation components. We need to fill this in later
// (when we actually know what needs to be emitted), so emit a placeholder for
Expand Down
17 changes: 13 additions & 4 deletions mlir/lib/Bytecode/Writer/IRNumbering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,16 @@ void IRNumberingState::number(Region &region) {
void IRNumberingState::number(Operation &op) {
// Number the components of an operation that won't be numbered elsewhere
// (e.g. we don't number operands, regions, or successors here).
number(op.getName());

// For fallback ops, create a new OperationName referencing the original op
// instead.
if (auto fallback = dyn_cast<FallbackBytecodeOpInterface>(op)) {
OperationName opName(fallback.getOriginalOperationName(), op.getContext());
number(opName, /*isOpaque=*/true);
} else {
number(op.getName(), /*isOpaque=*/false);
}

for (OpResult result : op.getResults()) {
valueIDs.try_emplace(result, nextValueID++);
number(result.getType());
Expand Down Expand Up @@ -457,7 +466,7 @@ void IRNumberingState::number(Operation &op) {
number(op.getLoc());
}

void IRNumberingState::number(OperationName opName) {
void IRNumberingState::number(OperationName opName, bool isOpaque) {
OpNameNumbering *&numbering = opNames[opName];
if (numbering) {
++numbering->refCount;
Expand All @@ -469,8 +478,8 @@ void IRNumberingState::number(OperationName opName) {
else
dialectNumber = &numberDialect(opName.getDialectNamespace());

numbering =
new (opNameAllocator.Allocate()) OpNameNumbering(dialectNumber, opName);
numbering = new (opNameAllocator.Allocate())
OpNameNumbering(dialectNumber, opName, isOpaque);
orderedOpNames.push_back(numbering);
}

Expand Down
11 changes: 7 additions & 4 deletions mlir/lib/Bytecode/Writer/IRNumbering.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,18 @@ struct TypeNumbering : public AttrTypeNumbering {

/// This class represents the numbering entry of an operation name.
struct OpNameNumbering {
OpNameNumbering(DialectNumbering *dialect, OperationName name)
: dialect(dialect), name(name) {}
OpNameNumbering(DialectNumbering *dialect, OperationName name, bool isOpaque)
: dialect(dialect), name(name), isOpaqueEntry(isOpaque) {}

/// The dialect of this value.
DialectNumbering *dialect;

/// The concrete name.
OperationName name;

/// This entry represents an opaque operation entry.
bool isOpaqueEntry = false;

/// The number assigned to this name.
unsigned number = 0;

Expand Down Expand Up @@ -210,7 +213,7 @@ class IRNumberingState {

/// Get the set desired bytecode version to emit.
int64_t getDesiredBytecodeVersion() const;

private:
/// This class is used to provide a fake dialect writer for numbering nested
/// attributes and types.
Expand All @@ -225,7 +228,7 @@ class IRNumberingState {
DialectNumbering &numberDialect(Dialect *dialect);
DialectNumbering &numberDialect(StringRef dialect);
void number(Operation &op);
void number(OperationName opName);
void number(OperationName opName, bool isOpaque);
void number(Region &region);
void number(Type type);

Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Bytecode/versioning/versioning-fallback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// COM: emit this file as bytecode into a per-test temporary file. `%t` is
// COM: unique to this test, unlike the deprecated shared-directory `%T`.
// RUN: mlir-opt %s --emit-bytecode > %t.mlirbc
"test.versionedD"() <{
attribute = #test.compound_attr_no_reading<
noReadingNested = #test.compound_attr_no_reading_nested<
value = "foo",
payload = [24, "bar"]
>,
supportsReading = #test.attr_params<42, 24>
>
}> : () -> ()

// COM: check that versionedD was parsed as a fallback op.
// RUN: mlir-opt %t.mlirbc | FileCheck %s --check-prefix=CHECK-PARSE
// CHECK-PARSE: test.bytecode.fallback
// CHECK-PARSE-SAME: encodedReqdAttributes = [#test.bytecode_fallback<attrIndex = 100,
// CHECK-PARSE-SAME: encodedReqdAttributes = [#test.bytecode_fallback<attrIndex = 101,
// CHECK-PARSE-SAME: encodedReqdAttributes = ["foo", [24, "bar"]],
// CHECK-PARSE-SAME: #test.attr_params<42, 24>
// CHECK-PARSE-SAME: opname = "test.versionedD",
// CHECK-PARSE-SAME: opversion = 1

// COM: check that the bytecode roundtrip was successful
// RUN: mlir-opt %t.mlirbc --verify-roundtrip

// COM: check that the bytecode roundtrip is bitwise exact
// RUN: mlir-opt %t.mlirbc --emit-bytecode | diff %t.mlirbc -
Loading
Loading