Skip to content

[MLIR] Testing arith-to-emitc conversions using opaque types #137936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ bool isIntegerIndexOrOpaqueType(Type type);
/// Determines whether \p type is a valid floating-point type in EmitC.
bool isSupportedFloatType(mlir::Type type);

/// Determines whether \p type is a valid floating-point or opaque type in
/// EmitC.
bool isFloatOrOpaqueType(mlir::Type type);

/// Determines whether \p type is a valid integer or opaque type in
/// EmitC.
bool isIntegerOrOpaqueType(mlir::Type type);

/// Determines whether \p type is a emitc.size_t/ssize_t type.
bool isPointerWideType(mlir::Type type);

Expand Down
104 changes: 69 additions & 35 deletions mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -39,8 +40,17 @@ class ArithConstantOpConversionPattern
Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
if (!newTy)
return rewriter.notifyMatchFailure(arithConst, "type conversion failed");

std::optional<Attribute> opAttrib =
this->getTypeConverter()->convertTypeAttribute(
adaptor.getValue().getType(), adaptor.getValue());
if (!opAttrib) {
return rewriter.notifyMatchFailure(arithConst,
"attribute conversion failed");
}

rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
adaptor.getValue());
opAttrib.value());
return success();
}
};
Expand All @@ -67,6 +77,7 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {

/// Insert a cast operation to type \p ty if \p val does not have this type.
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
assert(emitc::isSupportedEmitCType(val.getType()));
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
}

Expand All @@ -78,7 +89,7 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

if (!isa<FloatType>(adaptor.getRhs().getType())) {
if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) {
return rewriter.notifyMatchFailure(op.getLoc(),
"cmpf currently only supported on "
"floats, not tensors/vectors thereof");
Expand Down Expand Up @@ -273,7 +284,8 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = adaptor.getLhs().getType();
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand Down Expand Up @@ -307,7 +319,7 @@ class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
"negf currently only supports scalar types, not vectors or tensors");
}

if (!emitc::isSupportedFloatType(adaptedOpType)) {
if (!emitc::isFloatOrOpaqueType(adaptedOpType)) {
return rewriter.notifyMatchFailure(
op.getLoc(), "floating-point type is not supported by EmitC");
}
Expand All @@ -328,7 +340,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type opReturnType = this->getTypeConverter()->convertType(op.getType());
if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
if (!opReturnType || !(emitc::isIntegerOrOpaqueType(opReturnType) ||
emitc::isPointerWideType(opReturnType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
Expand All @@ -339,7 +351,7 @@ class CastConversion : public OpConversionPattern<ArithOp> {
}

Type operandType = adaptor.getIn().getType();
if (!operandType || !(isa<IntegerType>(operandType) ||
if (!operandType || !(emitc::isIntegerOrOpaqueType(operandType) ||
emitc::isPointerWideType(operandType)))
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
Expand Down Expand Up @@ -433,16 +445,17 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
if (!newRetTy)
return rewriter.notifyMatchFailure(uiBinOp,
"converting result type failed");
if (!isa<IntegerType>(newRetTy)) {

if (!emitc::isIntegerOrOpaqueType(newRetTy)) {
return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
}
Type unsignedType =
adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
if (!unsignedType)
return rewriter.notifyMatchFailure(uiBinOp,
"converting result type failed");
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
Value lhsAdapted = adaptValueType(adaptor.getLhs(), rewriter, unsignedType);
Value rhsAdapted = adaptValueType(adaptor.getRhs(), rewriter, unsignedType);

auto newDivOp =
rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
Expand All @@ -463,7 +476,8 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
if (!type || !(emitc::isIntegerOrOpaqueType(type) ||
emitc::isPointerWideType(type))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand Down Expand Up @@ -506,7 +520,7 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!isa_and_nonnull<IntegerType>(type)) {
if (!type || !emitc::isIntegerOrOpaqueType(type)) {
return rewriter.notifyMatchFailure(
op,
"expected integer type, vector/tensor support not yet implemented");
Expand Down Expand Up @@ -546,7 +560,9 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
ConversionPatternRewriter &rewriter) const override {

Type type = this->getTypeConverter()->convertType(op.getType());
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
bool retIsOpaque = isa_and_nonnull<emitc::OpaqueType>(type);
if (!type || (!retIsOpaque && !(isa<IntegerType>(type) ||
emitc::isPointerWideType(type)))) {
return rewriter.notifyMatchFailure(
op, "expected integer or size_t/ssize_t/ptrdiff_t type");
}
Expand All @@ -572,21 +588,33 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> {
op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
sizeOfCall.getResult(0));
} else {
} else if (!retIsOpaque) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could invert the condition and swap the branches.

width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
} else {
width = rewriter.create<emitc::ConstantOp>(
op.getLoc(), rhsType,
emitc::OpaqueAttr::get(rhsType.getContext(),
"opaque_shift_bitwidth"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does opaque_shift_bitwidth come from?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If opaque types are used, the bitwidth, which is needed for the shiftOp, can't be determined. So the opaque attribute serves as a reference point for where to enter the bitwidth of the type later on.

}

Value excessCheck = rewriter.create<emitc::CmpOp>(
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);

// Any concrete value is a valid refinement of poison.
Value poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
Value poison;
if (retIsOpaque) {
poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
emitc::OpaqueAttr::get(rhsType.getContext(), "opaque_shift_poison"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, where is this defined and why is it needed?

} else {
poison = rewriter.create<emitc::ConstantOp>(
op.getLoc(), arithmeticType,
(isa<IntegerType>(arithmeticType)
? rewriter.getIntegerAttr(arithmeticType, 0)
: rewriter.getIndexAttr(0)));
}

emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
op.getLoc(), arithmeticType, /*do_not_inline=*/false);
Expand Down Expand Up @@ -655,27 +683,31 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
ConversionPatternRewriter &rewriter) const override {

Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedFloatType(operandType))
if (!emitc::isFloatOrOpaqueType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

Type actualResultType = dstType;

// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
// truncated to 0, whereas a boolean conversion would return true.
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualResultType = dstType;
if (isa<arith::FPToUIOp>(castOp)) {
actualResultType =
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
bool dstIsOpaque = isa<emitc::OpaqueType>(dstType);
if (!dstIsOpaque) {
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
if (isa<arith::FPToUIOp>(castOp)) {
actualResultType =
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
}
}

Value result = rewriter.create<emitc::CastOp>(
Expand All @@ -702,22 +734,24 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedIntegerType(operandType))
bool opIsOpaque = isa<emitc::OpaqueType>(operandType);

if (!(opIsOpaque || emitc::isSupportedIntegerType(operandType)))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");

Type dstType = this->getTypeConverter()->convertType(castOp.getType());
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedFloatType(dstType))
if (!emitc::isFloatOrOpaqueType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

// Convert to unsigned if it's the "ui" variant
// Signless is interpreted as signed, so no need to cast for "si"
Type actualOperandType = operandType;
if (isa<arith::UIToFPOp>(castOp)) {
if (!opIsOpaque && isa<arith::UIToFPOp>(castOp)) {
actualOperandType =
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
/*isSigned=*/false);
Expand Down Expand Up @@ -745,7 +779,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
ConversionPatternRewriter &rewriter) const override {
// Vectors in particular are not supported.
Type operandType = adaptor.getIn().getType();
if (!emitc::isSupportedFloatType(operandType))
if (!emitc::isFloatOrOpaqueType(operandType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast source type");
if (auto roundingModeOp =
Expand All @@ -759,7 +793,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> {
if (!dstType)
return rewriter.notifyMatchFailure(castOp, "type conversion failed");

if (!emitc::isSupportedFloatType(dstType))
if (!emitc::isFloatOrOpaqueType(dstType))
return rewriter.notifyMatchFailure(castOp,
"unsupported cast destination type");

Expand Down
35 changes: 34 additions & 1 deletion mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,42 @@ namespace {
struct ConvertArithToEmitC
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
void runOnOperation() override;

/// Applies conversion to opaque types for f80 and i80 types, both unsupported
/// in emitc. Used to test the pass with opaque types.
void populateOpaqueTypeConversions(TypeConverter &converter);
};
} // namespace

void ConvertArithToEmitC::populateOpaqueTypeConversions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why these types should be unconditionally legalized, and why only for bitwidth 80?

TypeConverter &converter) {
converter.addConversion([](Type type) -> std::optional<Type> {
if (type.isF80())
return emitc::OpaqueType::get(type.getContext(), "f80");
if (type.isInteger() && type.getIntOrFloatBitWidth() == 80)
return emitc::OpaqueType::get(type.getContext(), "i80");
return type;
});

converter.addTypeAttributeConversion(
[](Type type,
Attribute attrToConvert) -> TypeConverter::AttributeConversionResult {
if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attrToConvert)) {
if (floatAttr.getType().isF80()) {
return emitc::OpaqueAttr::get(type.getContext(), "f80");
}
return attrToConvert;
}
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrToConvert)) {
if (intAttr.getType().isInteger() &&
intAttr.getType().getIntOrFloatBitWidth() == 80) {
return emitc::OpaqueAttr::get(type.getContext(), "i80");
}
}
return attrToConvert;
});
}

void ConvertArithToEmitC::runOnOperation() {
ConversionTarget target(getContext());

Expand All @@ -42,8 +75,8 @@ void ConvertArithToEmitC::runOnOperation() {
RewritePatternSet patterns(&getContext());

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });

populateOpaqueTypeConversions(typeConverter);
populateArithToEmitCPatterns(typeConverter, patterns);

if (failed(
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,14 @@ bool mlir::emitc::isSupportedFloatType(Type type) {
return false;
}

bool mlir::emitc::isIntegerOrOpaqueType(Type type) {
return isa<emitc::OpaqueType>(type) || isSupportedIntegerType(type);
}

bool mlir::emitc::isFloatOrOpaqueType(Type type) {
return isa<emitc::OpaqueType>(type) || isSupportedFloatType(type);
}

bool mlir::emitc::isPointerWideType(Type type) {
return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
type);
Expand Down
24 changes: 0 additions & 24 deletions mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
return %t: vector<5xi32>
}

// -----
func.func @arith_cast_f80(%arg0: f80) -> i32 {
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
%t = arith.fptosi %arg0 : f80 to i32
return %t: i32
}

// -----

func.func @arith_cast_f128(%arg0: f128) -> i32 {
Expand All @@ -29,15 +22,6 @@ func.func @arith_cast_f128(%arg0: f128) -> i32 {
return %t: i32
}


// -----

func.func @arith_cast_to_f80(%arg0: i32) -> f80 {
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
%t = arith.sitofp %arg0 : i32 to f80
return %t: f80
}

// -----

func.func @arith_cast_to_f128(%arg0: i32) -> f128 {
Expand Down Expand Up @@ -80,14 +64,6 @@ func.func @arith_cmpf_tensor(%arg0: tensor<5xf32>, %arg1: tensor<5xf32>) -> tens

// -----

func.func @arith_negf_f80(%arg0: f80) -> f80 {
// expected-error @+1 {{failed to legalize operation 'arith.negf'}}
%n = arith.negf %arg0 : f80
return %n: f80
}

// -----

func.func @arith_negf_tensor(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// expected-error @+1 {{failed to legalize operation 'arith.negf'}}
%n = arith.negf %arg0 : tensor<5xf32>
Expand Down
Loading