Skip to content

[RFC][mlir] Conditional support for fast-math attributes. #125620

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2494,6 +2494,21 @@ def fir_CallOp : fir_Op<"call",
llvm::cast<mlir::SymbolRefAttr>(callee));
setOperand(0, llvm::cast<mlir::Value>(callee));
}

/// Always allow FastMathFlags on fir.call operations.
/// This is required to be able to propagate the call site's
/// FastMathFlags to the operations resulting from inlining
/// (if any) of a fir.call (see the SimplifyIntrinsics pass).
/// We could analyze the arguments' data types to see if there are
/// any floating point types, but this is unreliable. For example,
/// the runtime calls mostly take !fir.box<none> arguments,
/// and tracking them back to their definitions may not be easy.
/// TODO: this should be restricted to fir.runtime calls,
/// because FastMathFlags for user calls must come
/// from the function body, not the call site.
bool isArithFastMathApplicable() {
return true;
}
}];
}

Expand Down Expand Up @@ -2672,6 +2687,15 @@ def fir_CmpcOp : fir_Op<"cmpc",
}

static mlir::arith::CmpFPredicate getPredicateByName(llvm::StringRef name);

/// Always allow FastMathFlags on fir.cmpc.
/// It does not produce a floating point result, but
/// LLVM currently relies on fast-math flags attached
/// to floating point comparisons.
/// This override can be removed whenever LLVM stops doing so.
bool isArithFastMathApplicable() {
return true;
}
}];
}

Expand Down Expand Up @@ -2735,6 +2759,8 @@ def fir_ConvertOp : fir_SimpleOneResultOp<"convert", [NoMemoryEffect]> {
static bool isPointerCompatible(mlir::Type ty);
static bool canBeConverted(mlir::Type inType, mlir::Type outType);
static bool areVectorsCompatible(mlir::Type inTy, mlir::Type outTy);

// FIXME: fir.convert should support ArithFastMathInterface.
}];
let hasCanonicalizer = 1;
}
Expand Down
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ bool mayHaveAllocatableComponent(mlir::Type ty);
/// Scalar integer or a sequence of integers (via boxed array or expr).
bool isFortranIntegerScalarOrArrayObject(mlir::Type type);

/// Return true iff FastMathFlagsAttr is applicable
/// to the given HLFIR dialect operation that supports
/// ArithFastMathInterface.
bool isArithFastMathApplicable(mlir::Operation *op);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_HLFIR_HLFIRDIALECT_H
54 changes: 54 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/HLFIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,12 @@ def hlfir_MaxvalOp : hlfir_Op<"maxval", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.maxval.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
Expand Down Expand Up @@ -461,6 +467,12 @@ def hlfir_MinvalOp : hlfir_Op<"minval", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.minval.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
Expand All @@ -487,6 +499,12 @@ def hlfir_MinlocOp : hlfir_Op<"minloc", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.minloc.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
Expand All @@ -513,6 +531,12 @@ def hlfir_MaxlocOp : hlfir_Op<"maxloc", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.maxloc.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
Expand All @@ -539,6 +563,12 @@ def hlfir_ProductOp : hlfir_Op<"product", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.product.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_SetLengthOp : hlfir_Op<"set_length",
Expand Down Expand Up @@ -604,6 +634,12 @@ def hlfir_SumOp : hlfir_Op<"sum", [AttrSizedOperandSegments,
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.sum.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_DotProductOp : hlfir_Op<"dot_product",
Expand All @@ -628,6 +664,12 @@ def hlfir_DotProductOp : hlfir_Op<"dot_product",
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.dot_product.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_MatmulOp : hlfir_Op<"matmul",
Expand Down Expand Up @@ -655,6 +697,12 @@ def hlfir_MatmulOp : hlfir_Op<"matmul",
let hasCanonicalizeMethod = 1;

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this hlfir.matmul.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_TransposeOp : hlfir_Op<"transpose",
Expand Down Expand Up @@ -697,6 +745,12 @@ def hlfir_MatmulTransposeOp : hlfir_Op<"matmul_transpose",
}];

let hasVerifier = 1;

let extraClassDeclaration = [{
/// Returns whether FastMathFlags are applicable to this
/// hlfir.matmul_transpose.
/// The decision is delegated to hlfir::isArithFastMathApplicable(),
/// which is given the underlying mlir::Operation.
bool isArithFastMathApplicable() {
return hlfir::isArithFastMathApplicable(getOperation());
}
}];
}

def hlfir_CShiftOp
Expand Down
4 changes: 1 addition & 3 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,7 @@ mlir::Value fir::FirOpBuilder::genAbsentOp(mlir::Location loc,

void fir::FirOpBuilder::setCommonAttributes(mlir::Operation *op) const {
auto fmi = mlir::dyn_cast<mlir::arith::ArithFastMathInterface>(*op);
if (fmi) {
// TODO: use fmi.setFastMathFlagsAttr() after D137114 is merged.
// For now set the attribute by the name.
if (fmi && fmi.isArithFastMathApplicable()) {
llvm::StringRef arithFMFAttrName = fmi.getFastMathAttrName();
if (fastMathFlags != mlir::arith::FastMathFlags::none)
op->setAttr(arithFMFAttrName, mlir::arith::FastMathFlagsAttr::get(
Expand Down
9 changes: 7 additions & 2 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,15 @@ struct CallOpConversion : public fir::FIROpConversion<fir::CallOp> {
// Convert arith::FastMathFlagsAttr to LLVM::FastMathFlagsAttr.
mlir::arith::AttrConvertFastMathToLLVM<fir::CallOp, mlir::LLVM::CallOp>
attrConvert(call);
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
call, resultTys, adaptor.getOperands(),
auto llvmCall = rewriter.create<mlir::LLVM::CallOp>(
call.getLoc(), resultTys, adaptor.getOperands(),
addLLVMOpBundleAttrs(rewriter, attrConvert.getAttrs(),
adaptor.getOperands().size()));
auto fmi =
mlir::cast<mlir::LLVM::FastmathFlagsInterface>(llvmCall.getOperation());
if (!fmi.isFastmathApplicable())
llvmCall.setFastmathFlags(mlir::LLVM::FastmathFlags::none);
rewriter.replaceOp(call, llvmCall);
return mlir::success();
}
};
Expand Down
17 changes: 17 additions & 0 deletions flang/lib/Optimizer/HLFIR/IR/HLFIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,20 @@ bool hlfir::isFortranIntegerScalarOrArrayObject(mlir::Type type) {
mlir::Type elementType = getFortranElementType(unwrappedType);
return mlir::isa<mlir::IntegerType>(elementType);
}

/// Tells whether FastMathFlagsAttr is applicable to the given HLFIR
/// operation. The check looks at the Fortran element types of the
/// operation's results and then of its operands, and reports true as
/// soon as one of them is compatible with ArithFastMathInterface.
/// NOTE: as written, the function returns true for every operation,
/// because the final fallback also returns true.
bool hlfir::isArithFastMathApplicable(mlir::Operation *op) {
// Applicable if any result has a fast-math compatible element type.
if (llvm::any_of(op->getResults(), [](mlir::Value v) {
mlir::Type elementType = getFortranElementType(v.getType());
return mlir::arith::ArithFastMathInterface::isCompatibleType(
elementType);
}))
return true;
// Applicable if any operand has a fast-math compatible element type.
if (llvm::any_of(op->getOperands(), [](mlir::Value v) {
mlir::Type elementType = getFortranElementType(v.getType());
return mlir::arith::ArithFastMathInterface::isCompatibleType(
elementType);
}))
return true;

// TODO: this fallback makes both checks above redundant. The author
// intends to change it to `false` once the lowering tests are fixed.
return true;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change this to false after fixing the lowering tests.

}
2 changes: 1 addition & 1 deletion flang/test/Fir/CUDA/cuda-gpu-launch-func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<i1, dense<8> : ve
%45 = llvm.call @_FortranACUFDataTransferPtrPtr(%14, %25, %2, %11, %13, %5) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
gpu.launch_func @cuda_device_mod::@_QMmod1Psub1 blocks in (%7, %7, %7) threads in (%12, %7, %7) : i64 dynamic_shared_memory_size %11 args(%14 : !llvm.ptr)
%46 = llvm.call @_FortranACUFDataTransferPtrPtr(%25, %14, %2, %10, %13, %4) : (!llvm.ptr, !llvm.ptr, i64, i32, !llvm.ptr, i32) -> !llvm.struct<()>
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
%47 = llvm.call @_FortranAioBeginExternalListOutput(%9, %13, %8) : (i32, !llvm.ptr, i32) -> !llvm.ptr
%48 = llvm.mlir.constant(9 : i32) : i32
%49 = llvm.mlir.zero : !llvm.ptr
%50 = llvm.getelementptr %49[1] : (!llvm.ptr) -> !llvm.ptr, i32
Expand Down
6 changes: 3 additions & 3 deletions flang/test/Fir/tbaa.fir
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ module {
// CHECK: %[[VAL_6:.*]] = llvm.mlir.constant(-1 : i32) : i32
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @_QFEx : !llvm.ptr
// CHECK: %[[VAL_8:.*]] = llvm.mlir.addressof @_QQclX2E2F64756D6D792E66393000 : !llvm.ptr
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) {fastmathFlags = #llvm.fastmath<contract>} : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_10:.*]] = llvm.call @_FortranAioBeginExternalListOutput(%[[VAL_6]], %[[VAL_8]], %[[VAL_5]]) : (i32, !llvm.ptr, i32) -> !llvm.ptr
// CHECK: %[[VAL_11:.*]] = llvm.mlir.constant(64 : i32) : i32
// CHECK: "llvm.intr.memcpy"(%[[VAL_3]], %[[VAL_7]], %[[VAL_11]]) <{isVolatile = false, tbaa = [#[[$BOXT]]]}>
// CHECK: %[[VAL_12:.*]] = llvm.getelementptr %[[VAL_3]][0, 7, %[[VAL_4]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
Expand Down Expand Up @@ -188,8 +188,8 @@ module {
// CHECK: %[[VAL_59:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_58]][7, 0, 2] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: %[[VAL_61:.*]] = llvm.insertvalue %[[VAL_52]], %[[VAL_59]][0] : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>
// CHECK: llvm.store %[[VAL_61]], %[[VAL_1]] {tbaa = [#[[$BOXT]]]} : !llvm.struct<(ptr, i64, i32, i8, i8, i8, i8, array<1 x array<3 x i64>>, ptr, array<1 x i64>)>, !llvm.ptr
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr, !llvm.ptr) -> i1
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) {fastmathFlags = #llvm.fastmath<contract>} : (!llvm.ptr) -> i32
// CHECK: %[[VAL_63:.*]] = llvm.call @_FortranAioOutputDescriptor(%[[VAL_10]], %[[VAL_1]]) : (!llvm.ptr, !llvm.ptr) -> i1
// CHECK: %[[VAL_64:.*]] = llvm.call @_FortranAioEndIoStatement(%[[VAL_10]]) : (!llvm.ptr) -> i32
// CHECK: llvm.return
// CHECK: }
// CHECK: llvm.func @_FortranAioBeginExternalListOutput(i32, !llvm.ptr, i32) -> !llvm.ptr attributes {fir.io, fir.runtime, sym_visibility = "private"}
Expand Down
11 changes: 11 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,17 @@ def Arith_CmpFOp : Arith_CompareOp<"cmpf",
let hasCanonicalizer = 1;
let assemblyFormat = [{ $predicate `,` $lhs `,` $rhs (`fastmath` `` $fastmath^)?
attr-dict `:` type($lhs)}];

let extraClassDeclaration = [{
/// Always allow FastMathFlags on arith.cmpf.
/// It does not produce a floating point result, but
/// LLVM currently relies on fast-math flags attached
/// to floating point comparisons.
/// This override can be removed whenever LLVM stops doing so.
bool isArithFastMathApplicable() {
return true;
}
}];
}

//===----------------------------------------------------------------------===//
Expand Down
72 changes: 52 additions & 20 deletions mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,63 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {

let cppNamespace = "::mlir::arith";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns a FastMathFlagsAttr attribute for the operation",
/*returnType=*/ "FastMathFlagsAttr",
/*methodName=*/ "getFastMathFlagsAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
let methods =
[InterfaceMethod<
/*desc=*/"Returns a FastMathFlagsAttr attribute for the operation",
/*returnType=*/"FastMathFlagsAttr",
/*methodName=*/"getFastMathFlagsAttr",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
ConcreteOp op = cast<ConcreteOp>(this->getOperation());
return op.getFastmathAttr();
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the FastMathFlagsAttr attribute
}]>,
StaticInterfaceMethod<
/*desc=*/[{Returns the name of the FastMathFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getFastMathAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
/*returnType=*/"StringRef",
/*methodName=*/"getFastMathAttrName",
/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
return "fastmath";
}]
>
}]>,
InterfaceMethod<
/*desc=*/[{Returns true iff FastMathFlagsAttr attribute
is applicable to the operation that supports
ArithFastMathInterface. If it returns false,
then the FastMathFlagsAttr of the operation
must be nullptr or have 'none' value}],
/*returnType=*/"bool",
/*methodName=*/"isArithFastMathApplicable",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I think of this as a sort of "verifier" for the fast-math flags of the given operation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its intention is to tell whether fast-math flags are applicable. It is used in the verifier code below, but it may also be used by the passes/builders that create new operations supporting ArithFastMathInterface; e.g., see its usage in the FIRBuilder.cpp file above.

/*args=*/(ins),
/*methodBody=*/[{}],
/*defaultImpl=*/[{
return ::mlir::cast<::mlir::arith::ArithFastMathInterface>(this->getOperation()).isApplicableImpl();
}]>];

];
let extraClassDeclaration = [{
/// Returns true iff the given type is a floating point type
/// or contains one.
static bool isCompatibleType(::mlir::Type);

/// Default implementation of isArithFastMathApplicable().
/// It returns true iff any of the results of the operation
/// has a type that is compatible with fast-math.
bool isApplicableImpl();
}];

let verify = [{
// Reject an operation that carries fast-math flags other than 'none'
// while isArithFastMathApplicable() reports that they do not apply.
// A missing attribute or a 'none' value is always accepted.
auto fmi = ::mlir::cast<::mlir::arith::ArithFastMathInterface>($_op);
auto attr = fmi.getFastMathFlagsAttr();
if (attr && attr.getValue() != ::mlir::arith::FastMathFlags::none &&
!fmi.isArithFastMathApplicable())
return $_op->emitOpError()
<< "has flag(s) `" << stringifyEnum(attr.getValue())
<< "`, but fast-math flags are not applicable "
"(`isArithFastMathApplicable()` returns false)";
return ::mlir::success();
}];
}

def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
Expand Down
Loading