-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[MLIR] Testing arith-to-emitc conversions using opaque types #137936
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
40f1409
94ae63f
da471d1
665eada
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/EmitC/IR/EmitC.h" | ||
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
@@ -39,8 +40,17 @@ class ArithConstantOpConversionPattern | |
Type newTy = this->getTypeConverter()->convertType(arithConst.getType()); | ||
if (!newTy) | ||
return rewriter.notifyMatchFailure(arithConst, "type conversion failed"); | ||
|
||
std::optional<Attribute> opAttrib = | ||
this->getTypeConverter()->convertTypeAttribute( | ||
adaptor.getValue().getType(), adaptor.getValue()); | ||
if (!opAttrib) { | ||
return rewriter.notifyMatchFailure(arithConst, | ||
"attribute conversion failed"); | ||
} | ||
|
||
rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy, | ||
adaptor.getValue()); | ||
opAttrib.value()); | ||
return success(); | ||
} | ||
}; | ||
|
@@ -67,6 +77,7 @@ Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) { | |
|
||
/// Insert a cast operation to type \p ty if \p val does not have this type. | ||
Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) { | ||
assert(emitc::isSupportedEmitCType(val.getType())); | ||
return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val); | ||
} | ||
|
||
|
@@ -78,7 +89,7 @@ class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> { | |
matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
if (!isa<FloatType>(adaptor.getRhs().getType())) { | ||
if (!emitc::isFloatOrOpaqueType(adaptor.getRhs().getType())) { | ||
return rewriter.notifyMatchFailure(op.getLoc(), | ||
"cmpf currently only supported on " | ||
"floats, not tensors/vectors thereof"); | ||
|
@@ -273,7 +284,8 @@ class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type type = adaptor.getLhs().getType(); | ||
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { | ||
if (!type || !(emitc::isIntegerOrOpaqueType(type) || | ||
emitc::isPointerWideType(type))) { | ||
return rewriter.notifyMatchFailure( | ||
op, "expected integer or size_t/ssize_t/ptrdiff_t type"); | ||
} | ||
|
@@ -307,7 +319,7 @@ class NegFOpConversion : public OpConversionPattern<arith::NegFOp> { | |
"negf currently only supports scalar types, not vectors or tensors"); | ||
} | ||
|
||
if (!emitc::isSupportedFloatType(adaptedOpType)) { | ||
if (!emitc::isFloatOrOpaqueType(adaptedOpType)) { | ||
return rewriter.notifyMatchFailure( | ||
op.getLoc(), "floating-point type is not supported by EmitC"); | ||
} | ||
|
@@ -328,7 +340,7 @@ class CastConversion : public OpConversionPattern<ArithOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type opReturnType = this->getTypeConverter()->convertType(op.getType()); | ||
if (!opReturnType || !(isa<IntegerType>(opReturnType) || | ||
if (!opReturnType || !(emitc::isIntegerOrOpaqueType(opReturnType) || | ||
emitc::isPointerWideType(opReturnType))) | ||
return rewriter.notifyMatchFailure( | ||
op, "expected integer or size_t/ssize_t/ptrdiff_t result type"); | ||
|
@@ -339,7 +351,7 @@ class CastConversion : public OpConversionPattern<ArithOp> { | |
} | ||
|
||
Type operandType = adaptor.getIn().getType(); | ||
if (!operandType || !(isa<IntegerType>(operandType) || | ||
if (!operandType || !(emitc::isIntegerOrOpaqueType(operandType) || | ||
emitc::isPointerWideType(operandType))) | ||
return rewriter.notifyMatchFailure( | ||
op, "expected integer or size_t/ssize_t/ptrdiff_t operand type"); | ||
|
@@ -433,16 +445,17 @@ class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> { | |
if (!newRetTy) | ||
return rewriter.notifyMatchFailure(uiBinOp, | ||
"converting result type failed"); | ||
if (!isa<IntegerType>(newRetTy)) { | ||
|
||
if (!emitc::isIntegerOrOpaqueType(newRetTy)) { | ||
return rewriter.notifyMatchFailure(uiBinOp, "expected integer type"); | ||
} | ||
Type unsignedType = | ||
adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true); | ||
if (!unsignedType) | ||
return rewriter.notifyMatchFailure(uiBinOp, | ||
"converting result type failed"); | ||
Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType); | ||
Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType); | ||
Value lhsAdapted = adaptValueType(adaptor.getLhs(), rewriter, unsignedType); | ||
Value rhsAdapted = adaptValueType(adaptor.getRhs(), rewriter, unsignedType); | ||
|
||
auto newDivOp = | ||
rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType, | ||
|
@@ -463,7 +476,8 @@ class IntegerOpConversion final : public OpConversionPattern<ArithOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type type = this->getTypeConverter()->convertType(op.getType()); | ||
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { | ||
if (!type || !(emitc::isIntegerOrOpaqueType(type) || | ||
emitc::isPointerWideType(type))) { | ||
return rewriter.notifyMatchFailure( | ||
op, "expected integer or size_t/ssize_t/ptrdiff_t type"); | ||
} | ||
|
@@ -506,7 +520,7 @@ class BitwiseOpConversion : public OpConversionPattern<ArithOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type type = this->getTypeConverter()->convertType(op.getType()); | ||
if (!isa_and_nonnull<IntegerType>(type)) { | ||
if (!type || !emitc::isIntegerOrOpaqueType(type)) { | ||
return rewriter.notifyMatchFailure( | ||
op, | ||
"expected integer type, vector/tensor support not yet implemented"); | ||
|
@@ -546,7 +560,9 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type type = this->getTypeConverter()->convertType(op.getType()); | ||
if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) { | ||
bool retIsOpaque = isa_and_nonnull<emitc::OpaqueType>(type); | ||
if (!type || (!retIsOpaque && !(isa<IntegerType>(type) || | ||
emitc::isPointerWideType(type)))) { | ||
return rewriter.notifyMatchFailure( | ||
op, "expected integer or size_t/ssize_t/ptrdiff_t type"); | ||
} | ||
|
@@ -572,21 +588,33 @@ class ShiftOpConversion : public OpConversionPattern<ArithOp> { | |
op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight}); | ||
width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight, | ||
sizeOfCall.getResult(0)); | ||
} else { | ||
} else if (!retIsOpaque) { | ||
width = rewriter.create<emitc::ConstantOp>( | ||
op.getLoc(), rhsType, | ||
rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth())); | ||
} else { | ||
width = rewriter.create<emitc::ConstantOp>( | ||
op.getLoc(), rhsType, | ||
emitc::OpaqueAttr::get(rhsType.getContext(), | ||
"opaque_shift_bitwidth")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If opaque types are used, the bitwidth, which is needed for the shiftOp, can't be determined. So the opaque attribute serves as a reference point for where to enter the bitwidth of the type later on. |
||
} | ||
|
||
Value excessCheck = rewriter.create<emitc::CmpOp>( | ||
op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width); | ||
|
||
// Any concrete value is a valid refinement of poison. | ||
Value poison = rewriter.create<emitc::ConstantOp>( | ||
op.getLoc(), arithmeticType, | ||
(isa<IntegerType>(arithmeticType) | ||
? rewriter.getIntegerAttr(arithmeticType, 0) | ||
: rewriter.getIndexAttr(0))); | ||
Value poison; | ||
if (retIsOpaque) { | ||
poison = rewriter.create<emitc::ConstantOp>( | ||
op.getLoc(), arithmeticType, | ||
emitc::OpaqueAttr::get(rhsType.getContext(), "opaque_shift_poison")); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, where is this defined and why is it needed? |
||
} else { | ||
poison = rewriter.create<emitc::ConstantOp>( | ||
op.getLoc(), arithmeticType, | ||
(isa<IntegerType>(arithmeticType) | ||
? rewriter.getIntegerAttr(arithmeticType, 0) | ||
: rewriter.getIndexAttr(0))); | ||
} | ||
|
||
emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>( | ||
op.getLoc(), arithmeticType, /*do_not_inline=*/false); | ||
|
@@ -655,27 +683,31 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
|
||
Type operandType = adaptor.getIn().getType(); | ||
if (!emitc::isSupportedFloatType(operandType)) | ||
if (!emitc::isFloatOrOpaqueType(operandType)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast source type"); | ||
|
||
Type dstType = this->getTypeConverter()->convertType(castOp.getType()); | ||
if (!dstType) | ||
return rewriter.notifyMatchFailure(castOp, "type conversion failed"); | ||
|
||
Type actualResultType = dstType; | ||
|
||
// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be | ||
// truncated to 0, whereas a boolean conversion would return true. | ||
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast destination type"); | ||
|
||
// Convert to unsigned if it's the "ui" variant | ||
// Signless is interpreted as signed, so no need to cast for "si" | ||
Type actualResultType = dstType; | ||
if (isa<arith::FPToUIOp>(castOp)) { | ||
actualResultType = | ||
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(), | ||
/*isSigned=*/false); | ||
bool dstIsOpaque = isa<emitc::OpaqueType>(dstType); | ||
if (!dstIsOpaque) { | ||
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast destination type"); | ||
|
||
// Convert to unsigned if it's the "ui" variant | ||
// Signless is interpreted as signed, so no need to cast for "si" | ||
if (isa<arith::FPToUIOp>(castOp)) { | ||
actualResultType = | ||
rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(), | ||
/*isSigned=*/false); | ||
} | ||
} | ||
|
||
Value result = rewriter.create<emitc::CastOp>( | ||
|
@@ -702,22 +734,24 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
// Vectors in particular are not supported | ||
Type operandType = adaptor.getIn().getType(); | ||
if (!emitc::isSupportedIntegerType(operandType)) | ||
bool opIsOpaque = isa<emitc::OpaqueType>(operandType); | ||
|
||
if (!(opIsOpaque || emitc::isSupportedIntegerType(operandType))) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast source type"); | ||
|
||
Type dstType = this->getTypeConverter()->convertType(castOp.getType()); | ||
if (!dstType) | ||
return rewriter.notifyMatchFailure(castOp, "type conversion failed"); | ||
|
||
if (!emitc::isSupportedFloatType(dstType)) | ||
if (!emitc::isFloatOrOpaqueType(dstType)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast destination type"); | ||
|
||
// Convert to unsigned if it's the "ui" variant | ||
// Signless is interpreted as signed, so no need to cast for "si" | ||
Type actualOperandType = operandType; | ||
if (isa<arith::UIToFPOp>(castOp)) { | ||
if (!opIsOpaque && isa<arith::UIToFPOp>(castOp)) { | ||
actualOperandType = | ||
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(), | ||
/*isSigned=*/false); | ||
|
@@ -745,7 +779,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> { | |
ConversionPatternRewriter &rewriter) const override { | ||
// Vectors in particular are not supported. | ||
Type operandType = adaptor.getIn().getType(); | ||
if (!emitc::isSupportedFloatType(operandType)) | ||
if (!emitc::isFloatOrOpaqueType(operandType)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast source type"); | ||
if (auto roundingModeOp = | ||
|
@@ -759,7 +793,7 @@ class FpCastOpConversion : public OpConversionPattern<CastOp> { | |
if (!dstType) | ||
return rewriter.notifyMatchFailure(castOp, "type conversion failed"); | ||
|
||
if (!emitc::isSupportedFloatType(dstType)) | ||
if (!emitc::isFloatOrOpaqueType(dstType)) | ||
return rewriter.notifyMatchFailure(castOp, | ||
"unsupported cast destination type"); | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,9 +30,42 @@ namespace { | |
struct ConvertArithToEmitC | ||
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> { | ||
void runOnOperation() override; | ||
|
||
/// Applies conversion to opaque types for f80 and i80 types, both unsupported | ||
/// in emitc. Used to test the pass with opaque types. | ||
void populateOpaqueTypeConversions(TypeConverter &converter); | ||
}; | ||
} // namespace | ||
|
||
void ConvertArithToEmitC::populateOpaqueTypeConversions( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't understand why these types should be unconditionally legalized, and why only for bitwidth 80? |
||
TypeConverter &converter) { | ||
converter.addConversion([](Type type) -> std::optional<Type> { | ||
if (type.isF80()) | ||
return emitc::OpaqueType::get(type.getContext(), "f80"); | ||
if (type.isInteger() && type.getIntOrFloatBitWidth() == 80) | ||
return emitc::OpaqueType::get(type.getContext(), "i80"); | ||
return type; | ||
}); | ||
|
||
converter.addTypeAttributeConversion( | ||
[](Type type, | ||
Attribute attrToConvert) -> TypeConverter::AttributeConversionResult { | ||
if (auto floatAttr = llvm::dyn_cast<FloatAttr>(attrToConvert)) { | ||
if (floatAttr.getType().isF80()) { | ||
return emitc::OpaqueAttr::get(type.getContext(), "f80"); | ||
} | ||
return attrToConvert; | ||
} | ||
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attrToConvert)) { | ||
if (intAttr.getType().isInteger() && | ||
intAttr.getType().getIntOrFloatBitWidth() == 80) { | ||
return emitc::OpaqueAttr::get(type.getContext(), "i80"); | ||
} | ||
} | ||
return attrToConvert; | ||
}); | ||
} | ||
|
||
void ConvertArithToEmitC::runOnOperation() { | ||
ConversionTarget target(getContext()); | ||
|
||
|
@@ -42,8 +75,8 @@ void ConvertArithToEmitC::runOnOperation() { | |
RewritePatternSet patterns(&getContext()); | ||
|
||
TypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
|
||
populateOpaqueTypeConversions(typeConverter); | ||
populateArithToEmitCPatterns(typeConverter, patterns); | ||
|
||
if (failed( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You could invert the condition and swap the branches.