Skip to content

Re-introduce Type Conversion on EmitC #121476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mlir/include/mlir/Conversion/SCFToEmitC/SCFToEmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
#define MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H

#include "mlir/Transforms/DialectConversion.h"
#include <memory>

namespace mlir {
Expand All @@ -19,7 +20,8 @@ class RewritePatternSet;
#include "mlir/Conversion/Passes.h.inc"

/// Collect a set of patterns to convert SCF operations to the EmitC dialect.
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns);
void populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter);
} // namespace mlir

#endif // MLIR_CONVERSION_SCFTOEMITC_SCFTOEMITC_H
1 change: 1 addition & 0 deletions mlir/lib/Conversion/SCFToEmitC/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRSCFToEmitC
LINK_LIBS PUBLIC
MLIRArithDialect
MLIREmitCDialect
MLIREmitCTransforms
MLIRSCFDialect
MLIRTransforms
)
206 changes: 140 additions & 66 deletions mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
Expand All @@ -39,21 +40,22 @@ struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {

// Lower scf::for to emitc::for, implementing result values using
// emitc::variable's updated within the loop body.
struct ForLowering : public OpRewritePattern<ForOp> {
using OpRewritePattern<ForOp>::OpRewritePattern;
struct ForLowering : public OpConversionPattern<ForOp> {
using OpConversionPattern<ForOp>::OpConversionPattern;

LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

// Create an uninitialized emitc::variable op for each result of the given op.
template <typename T>
static SmallVector<Value> createVariablesForResults(T op,
PatternRewriter &rewriter) {
SmallVector<Value> resultVariables;

static LogicalResult
createVariablesForResults(T op, const TypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
SmallVector<Value> &resultVariables) {
if (!op.getNumResults())
return resultVariables;
return success();

Location loc = op->getLoc();
MLIRContext *context = op.getContext();
Expand All @@ -62,21 +64,23 @@ static SmallVector<Value> createVariablesForResults(T op,
rewriter.setInsertionPoint(op);

for (OpResult result : op.getResults()) {
Type resultType = result.getType();
Type resultType = typeConverter->convertType(result.getType());
if (!resultType)
return rewriter.notifyMatchFailure(op, "result type conversion failed");
Type varType = emitc::LValueType::get(resultType);
emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
emitc::VariableOp var =
rewriter.create<emitc::VariableOp>(loc, varType, noInit);
resultVariables.push_back(var);
}

return resultVariables;
return success();
}

// Create a series of assign ops assigning given values to given variables at
// the current insertion point of given rewriter.
static void assignValues(ValueRange values, SmallVector<Value> &variables,
PatternRewriter &rewriter, Location loc) {
static void assignValues(ValueRange values, ValueRange variables,
ConversionPatternRewriter &rewriter, Location loc) {
for (auto [value, var] : llvm::zip(values, variables))
rewriter.create<emitc::AssignOp>(loc, var, value);
}
Expand All @@ -89,46 +93,58 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
});
}

static void lowerYield(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, scf::YieldOp yield) {
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
scf::YieldOp yield) {
Location loc = yield.getLoc();
ValueRange operands = yield.getOperands();

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);

assignValues(operands, resultVariables, rewriter, loc);
SmallVector<Value> yieldOperands;
if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
}

assignValues(yieldOperands, resultVariables, rewriter, loc);

rewriter.create<emitc::YieldOp>(loc);
rewriter.eraseOp(yield);

return success();
}

// Lower the contents of an scf::if/scf::index_switch regions to an
// emitc::if/emitc::switch region. The contents of the lowering region is
// moved into the respective lowered region, but the scf::yield is replaced not
// only with an emitc::yield, but also with a sequence of emitc::assign ops that
// set the yielded values into the result variables.
static void lowerRegion(SmallVector<Value> &resultVariables,
PatternRewriter &rewriter, Region &region,
Region &loweredRegion) {
static LogicalResult lowerRegion(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
lowerYield(resultVariables, rewriter, cast<scf::YieldOp>(terminator));
return lowerYield(op, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
}

LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
LogicalResult
ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables =
createVariablesForResults(forOp, rewriter);
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(forOp, getTypeConverter(), rewriter,
resultVariables)))
return rewriter.notifyMatchFailure(forOp,
"create variables for results failed");

assignValues(forOp.getInits(), resultVariables, rewriter, loc);
assignValues(adaptor.getInitArgs(), resultVariables, rewriter, loc);

emitc::ForOp loweredFor = rewriter.create<emitc::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
loc, adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep());

Block *loweredBody = loweredFor.getBody();

Expand All @@ -143,13 +159,27 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,

rewriter.restoreInsertionPoint(ip);

// Convert the original region types into the new types by adding unrealized
// casts in the beginning of the loop. This performs the conversion in place.
if (failed(rewriter.convertRegionTypes(&forOp.getRegion(),
*getTypeConverter(), nullptr))) {
return rewriter.notifyMatchFailure(forOp, "region types conversion failed");
}

// Register the replacements for the block arguments and inline the body of
// the scf.for loop into the body of the emitc::for loop.
Block *scfBody = &(forOp.getRegion().front());
SmallVector<Value> replacingValues;
replacingValues.push_back(loweredFor.getInductionVar());
replacingValues.append(iterArgsValues.begin(), iterArgsValues.end());
rewriter.mergeBlocks(scfBody, loweredBody, replacingValues);

rewriter.mergeBlocks(forOp.getBody(), loweredBody, replacingValues);
lowerYield(resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));
auto result = lowerYield(forOp, resultVariables, rewriter,
cast<scf::YieldOp>(loweredBody->getTerminator()));

if (failed(result)) {
return result;
}

// Load variables into SSA values after the for loop.
SmallVector<Value> resultValues = loadValues(resultVariables, rewriter, loc);
Expand All @@ -160,38 +190,66 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,

// Lower scf::if to emitc::if, implementing result values as emitc::variable's
// updated within the then and else regions.
struct IfLowering : public OpRewritePattern<IfOp> {
using OpRewritePattern<IfOp>::OpRewritePattern;
struct IfLowering : public OpConversionPattern<IfOp> {
using OpConversionPattern<IfOp>::OpConversionPattern;

LogicalResult matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

} // namespace

LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
LogicalResult
IfLowering::matchAndRewrite(IfOp ifOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = ifOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the then & else regions.
SmallVector<Value> resultVariables =
createVariablesForResults(ifOp, rewriter);

Region &thenRegion = ifOp.getThenRegion();
Region &elseRegion = ifOp.getElseRegion();
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(ifOp, getTypeConverter(), rewriter,
resultVariables)))
return rewriter.notifyMatchFailure(ifOp,
"create variables for results failed");

// Utility function to lower the contents of an scf::if region to an emitc::if
// region. The contents of the scf::if regions is moved into the respective
// emitc::if regions, but the scf::yield is replaced not only with an
// emitc::yield, but also with a sequence of emitc::assign ops that set the
// yielded values into the result variables.
auto lowerRegion = [&resultVariables, &rewriter,
&ifOp](Region &region, Region &loweredRegion) {
rewriter.inlineRegionBefore(region, loweredRegion, loweredRegion.end());
Operation *terminator = loweredRegion.back().getTerminator();
auto result = lowerYield(ifOp, resultVariables, rewriter,
cast<scf::YieldOp>(terminator));
if (failed(result)) {
return result;
}
return success();
};

Region &thenRegion = adaptor.getThenRegion();
Region &elseRegion = adaptor.getElseRegion();

bool hasElseBlock = !elseRegion.empty();

auto loweredIf =
rewriter.create<emitc::IfOp>(loc, ifOp.getCondition(), false, false);
rewriter.create<emitc::IfOp>(loc, adaptor.getCondition(), false, false);

Region &loweredThenRegion = loweredIf.getThenRegion();
lowerRegion(resultVariables, rewriter, thenRegion, loweredThenRegion);
auto result = lowerRegion(thenRegion, loweredThenRegion);
if (failed(result)) {
return result;
}

if (hasElseBlock) {
Region &loweredElseRegion = loweredIf.getElseRegion();
lowerRegion(resultVariables, rewriter, elseRegion, loweredElseRegion);
auto result = lowerRegion(elseRegion, loweredElseRegion);
if (failed(result)) {
return result;
}
}

rewriter.setInsertionPointAfter(ifOp);
Expand All @@ -203,37 +261,46 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,

// Lower scf::index_switch to emitc::switch, implementing result values as
// emitc::variable's updated within the case and default regions.
struct IndexSwitchOpLowering : public OpRewritePattern<IndexSwitchOp> {
using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
struct IndexSwitchOpLowering : public OpConversionPattern<IndexSwitchOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const override;
LogicalResult
matchAndRewrite(IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

LogicalResult
IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
PatternRewriter &rewriter) const {
LogicalResult IndexSwitchOpLowering::matchAndRewrite(
IndexSwitchOp indexSwitchOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = indexSwitchOp.getLoc();

// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the case and default regions.
SmallVector<Value> resultVariables =
createVariablesForResults(indexSwitchOp, rewriter);
SmallVector<Value> resultVariables;
if (failed(createVariablesForResults(indexSwitchOp, getTypeConverter(),
rewriter, resultVariables))) {
return rewriter.notifyMatchFailure(indexSwitchOp,
"create variables for results failed");
}

auto loweredSwitch = rewriter.create<emitc::SwitchOp>(
loc, indexSwitchOp.getArg(), indexSwitchOp.getCases(),
indexSwitchOp.getNumCases());
loc, adaptor.getArg(), adaptor.getCases(), indexSwitchOp.getNumCases());

// Lowering all case regions.
for (auto pair : llvm::zip(indexSwitchOp.getCaseRegions(),
loweredSwitch.getCaseRegions())) {
lowerRegion(resultVariables, rewriter, std::get<0>(pair),
std::get<1>(pair));
for (auto pair :
llvm::zip(adaptor.getCaseRegions(), loweredSwitch.getCaseRegions())) {
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
*std::get<0>(pair), std::get<1>(pair)))) {
return failure();
}
}

// Lowering default region.
lowerRegion(resultVariables, rewriter, indexSwitchOp.getDefaultRegion(),
loweredSwitch.getDefaultRegion());
if (failed(lowerRegion(indexSwitchOp, resultVariables, rewriter,
adaptor.getDefaultRegion(),
loweredSwitch.getDefaultRegion()))) {
return failure();
}

rewriter.setInsertionPointAfter(indexSwitchOp);
SmallVector<Value> results = loadValues(resultVariables, rewriter, loc);
Expand All @@ -242,15 +309,22 @@ IndexSwitchOpLowering::matchAndRewrite(IndexSwitchOp indexSwitchOp,
return success();
}

void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns) {
patterns.add<ForLowering>(patterns.getContext());
patterns.add<IfLowering>(patterns.getContext());
patterns.add<IndexSwitchOpLowering>(patterns.getContext());
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
}

void SCFToEmitCPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateSCFToEmitCConversionPatterns(patterns);
TypeConverter typeConverter;
// Fallback converter
// See note https://mlir.llvm.org/docs/DialectConversion/#type-converter
// Type converters are called most to least recently inserted
typeConverter.addConversion([](Type t) { return t; });
populateEmitCSizeTTypeConversions(typeConverter);
populateSCFToEmitCConversionPatterns(patterns, typeConverter);

// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
Expand Down
Loading
Loading