Skip to content

Commit 5f59b72

Browse files
committed
Revert "[mlir][arith] Add overflow flags support to arith ops (#77211)"
Temporarily reverting as it broke python bindings This reverts commit a7262d2.
1 parent 5afc4f3 commit 5f59b72

File tree

11 files changed

+73
-321
lines changed

11 files changed

+73
-321
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,14 @@
1818

1919
namespace mlir {
2020
namespace arith {
21-
/// Maps arithmetic fastmath enum values to LLVM enum values.
21+
// Map arithmetic fastmath enum values to LLVMIR enum values.
2222
LLVM::FastmathFlags
2323
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
2424

25-
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
26-
/// attribute.
25+
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
2726
LLVM::FastmathFlagsAttr
2827
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
2928

30-
/// Maps arithmetic overflow enum values to LLVM enum values.
31-
LLVM::IntegerOverflowFlags
32-
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
33-
34-
/// Creates an LLVM overflow attribute from a given arithmetic overflow
35-
/// attribute.
36-
LLVM::IntegerOverflowFlagsAttr
37-
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
38-
3929
// Attribute converter that populates a NamedAttrList by removing the fastmath
4030
// attribute from the source operation attributes, and replacing it with an
4131
// equivalent LLVM fastmath attribute.
@@ -46,46 +36,19 @@ class AttrConvertFastMathToLLVM {
4636
// Copy the source attributes.
4737
convertedAttr = NamedAttrList{srcOp->getAttrs()};
4838
// Get the name of the arith fastmath attribute.
49-
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
39+
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
5040
// Remove the source fastmath attribute.
51-
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
41+
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
5242
convertedAttr.erase(arithFMFAttrName));
5343
if (arithFMFAttr) {
54-
StringRef targetAttrName = TargetOp::getFastmathAttrName();
44+
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
5545
convertedAttr.set(targetAttrName,
5646
convertArithFastMathAttrToLLVM(arithFMFAttr));
5747
}
5848
}
5949

6050
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
6151

62-
private:
63-
NamedAttrList convertedAttr;
64-
};
65-
66-
// Attribute converter that populates a NamedAttrList by removing the overflow
67-
// attribute from the source operation attributes, and replacing it with an
68-
// equivalent LLVM overflow attribute.
69-
template <typename SourceOp, typename TargetOp>
70-
class AttrConvertOverflowToLLVM {
71-
public:
72-
AttrConvertOverflowToLLVM(SourceOp srcOp) {
73-
// Copy the source attributes.
74-
convertedAttr = NamedAttrList{srcOp->getAttrs()};
75-
// Get the name of the arith overflow attribute.
76-
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
77-
// Remove the source overflow attribute.
78-
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
79-
convertedAttr.erase(arithAttrName));
80-
if (arithAttr) {
81-
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
82-
convertedAttr.set(targetAttrName,
83-
convertArithOveflowAttrToLLVM(arithAttr));
84-
}
85-
}
86-
87-
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
88-
8952
private:
9053
NamedAttrList convertedAttr;
9154
};

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -133,27 +133,4 @@ def Arith_FastMathAttr :
133133
let assemblyFormat = "`<` $value `>`";
134134
}
135135

136-
//===----------------------------------------------------------------------===//
137-
// IntegerOverflowFlags
138-
//===----------------------------------------------------------------------===//
139-
140-
def IOFnone : I32BitEnumAttrCaseNone<"none">;
141-
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
142-
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
143-
144-
def IntegerOverflowFlags : I32BitEnumAttr<
145-
"IntegerOverflowFlags",
146-
"Integer overflow arith flags",
147-
[IOFnone, IOFnsw, IOFnuw]> {
148-
let separator = ", ";
149-
let cppNamespace = "::mlir::arith";
150-
let genSpecializedAttr = 0;
151-
let printBitEnumPrimaryGroups = 1;
152-
}
153-
154-
def Arith_IntegerOverflowAttr :
155-
EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
156-
let assemblyFormat = "`<` $value `>`";
157-
}
158-
159136
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 20 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,6 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
137137
let results = (outs BoolLikeOfAnyRank:$result);
138138
}
139139

140-
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
141-
Arith_BinaryOp<mnemonic, traits #
142-
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
143-
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
144-
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
145-
DefaultValuedAttr<
146-
Arith_IntegerOverflowAttr,
147-
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
148-
Results<(outs SignlessIntegerLike:$result)> {
149-
150-
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
151-
attr-dict `:` type($result) }];
152-
}
153-
154140
//===----------------------------------------------------------------------===//
155141
// ConstantOp
156142
//===----------------------------------------------------------------------===//
@@ -206,7 +192,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
206192
// AddIOp
207193
//===----------------------------------------------------------------------===//
208194

209-
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
195+
def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
210196
let summary = "integer addition operation";
211197
let description = [{
212198
Performs N-bit addition on the operands. The operands are interpreted as
@@ -217,23 +203,16 @@ def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
217203

218204
The `addi` operation takes two operands and returns one result, each of
219205
these is required to be the same type. This type may be an integer scalar type,
220-
a vector whose element type is integer, or a tensor of integers.
221-
222-
This op supports `nuw`/`nsw` overflow flags which stands stand for
223-
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
224-
`nsw` flags are present, and an unsigned/signed overflow occurs
225-
(respectively), the result is poison.
206+
a vector whose element type is integer, or a tensor of integers. It has no
207+
standard attributes.
226208

227209
Example:
228210

229211
```mlir
230212
// Scalar addition.
231213
%a = arith.addi %b, %c : i64
232214

233-
// Scalar addition with overflow flags.
234-
%a = arith.addi %b, %c overflow<nsw, nuw> : i64
235-
236-
// SIMD vector element-wise addition.
215+
// SIMD vector element-wise addition, e.g. for Intel SSE.
237216
%f = arith.addi %g, %h : vector<4xi32>
238217

239218
// Tensor element-wise addition.
@@ -299,41 +278,21 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
299278
// SubIOp
300279
//===----------------------------------------------------------------------===//
301280

302-
def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
281+
def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
303282
let summary = [{
304283
Integer subtraction operation.
305284
}];
306285
let description = [{
307-
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
308-
bitvectors. The result is represented by a bitvector containing the mathematical
309-
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
310-
integers use a two's complement representation, this operation is applicable on
286+
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
287+
bitvectors. The result is represented by a bitvector containing the mathematical
288+
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
289+
integers use a two's complement representation, this operation is applicable on
311290
both signed and unsigned integer operands.
312291

313292
The `subi` operation takes two operands and returns one result, each of
314-
these is required to be the same type. This type may be an integer scalar type,
315-
a vector whose element type is integer, or a tensor of integers.
316-
317-
This op supports `nuw`/`nsw` overflow flags which stands stand for
318-
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
319-
`nsw` flags are present, and an unsigned/signed overflow occurs
320-
(respectively), the result is poison.
321-
322-
Example:
323-
324-
```mlir
325-
// Scalar subtraction.
326-
%a = arith.subi %b, %c : i64
327-
328-
// Scalar subtraction with overflow flags.
329-
%a = arith.subi %b, %c overflow<nsw, nuw> : i64
330-
331-
// SIMD vector element-wise subtraction.
332-
%f = arith.subi %g, %h : vector<4xi32>
333-
334-
// Tensor element-wise subtraction.
335-
%x = arith.subi %y, %z : tensor<4x?xi8>
336-
```
293+
these is required to be the same type. This type may be an integer scalar type,
294+
a vector whose element type is integer, or a tensor of integers. It has no
295+
standard attributes.
337296
}];
338297
let hasFolder = 1;
339298
let hasCanonicalizer = 1;
@@ -343,41 +302,21 @@ def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
343302
// MulIOp
344303
//===----------------------------------------------------------------------===//
345304

346-
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
305+
def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
347306
let summary = [{
348307
Integer multiplication operation.
349308
}];
350309
let description = [{
351-
Performs N-bit multiplication on the operands. The operands are interpreted as
352-
unsigned bitvectors. The result is represented by a bitvector containing the
353-
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
354-
Because `arith` integers use a two's complement representation, this operation is
310+
Performs N-bit multiplication on the operands. The operands are interpreted as
311+
unsigned bitvectors. The result is represented by a bitvector containing the
312+
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
313+
Because `arith` integers use a two's complement representation, this operation is
355314
applicable on both signed and unsigned integer operands.
356315

357316
The `muli` operation takes two operands and returns one result, each of
358-
these is required to be the same type. This type may be an integer scalar type,
359-
a vector whose element type is integer, or a tensor of integers.
360-
361-
This op supports `nuw`/`nsw` overflow flags which stands stand for
362-
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
363-
`nsw` flags are present, and an unsigned/signed overflow occurs
364-
(respectively), the result is poison.
365-
366-
Example:
367-
368-
```mlir
369-
// Scalar multiplication.
370-
%a = arith.muli %b, %c : i64
371-
372-
// Scalar multiplication with overflow flags.
373-
%a = arith.muli %b, %c overflow<nsw, nuw> : i64
374-
375-
// SIMD vector element-wise multiplication.
376-
%f = arith.muli %g, %h : vector<4xi32>
377-
378-
// Tensor element-wise multiplication.
379-
%x = arith.muli %y, %z : tensor<4x?xi8>
380-
```
317+
these is required to be the same type. This type may be an integer scalar type,
318+
a vector whose element type is integer, or a tensor of integers. It has no
319+
standard attributes.
381320
}];
382321
let hasFolder = 1;
383322
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -49,61 +49,4 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
4949
];
5050
}
5151

52-
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
53-
let description = [{
54-
Access to op integer overflow flags.
55-
}];
56-
57-
let cppNamespace = "::mlir::arith";
58-
59-
let methods = [
60-
InterfaceMethod<
61-
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
62-
/*returnType=*/ "IntegerOverflowFlagsAttr",
63-
/*methodName=*/ "getOverflowAttr",
64-
/*args=*/ (ins),
65-
/*methodBody=*/ [{}],
66-
/*defaultImpl=*/ [{
67-
auto op = cast<ConcreteOp>(this->getOperation());
68-
return op.getOverflowFlagsAttr();
69-
}]
70-
>,
71-
InterfaceMethod<
72-
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
73-
/*returnType=*/ "bool",
74-
/*methodName=*/ "hasNoUnsignedWrap",
75-
/*args=*/ (ins),
76-
/*methodBody=*/ [{}],
77-
/*defaultImpl=*/ [{
78-
auto op = cast<ConcreteOp>(this->getOperation());
79-
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
80-
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
81-
}]
82-
>,
83-
InterfaceMethod<
84-
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
85-
/*returnType=*/ "bool",
86-
/*methodName=*/ "hasNoSignedWrap",
87-
/*args=*/ (ins),
88-
/*methodBody=*/ [{}],
89-
/*defaultImpl=*/ [{
90-
auto op = cast<ConcreteOp>(this->getOperation());
91-
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
92-
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
93-
}]
94-
>,
95-
StaticInterfaceMethod<
96-
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
97-
for the operation}],
98-
/*returnType=*/ "StringRef",
99-
/*methodName=*/ "getIntegerOverflowAttrName",
100-
/*args=*/ (ins),
101-
/*methodBody=*/ [{}],
102-
/*defaultImpl=*/ [{
103-
return "overflowFlags";
104-
}]
105-
>
106-
];
107-
}
108-
10952
#endif // ARITH_OPS_INTERFACES

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
using namespace mlir;
1212

13+
// Map arithmetic fastmath enum values to LLVMIR enum values.
1314
LLVM::FastmathFlags
1415
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
1516
LLVM::FastmathFlags llvmFMF{};
@@ -21,37 +22,17 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
2122
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
2223
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
2324
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
24-
for (auto [arithFlag, llvmFlag] : flags) {
25-
if (bitEnumContainsAny(arithFMF, arithFlag))
26-
llvmFMF = llvmFMF | llvmFlag;
25+
for (auto fmfMap : flags) {
26+
if (bitEnumContainsAny(arithFMF, fmfMap.first))
27+
llvmFMF = llvmFMF | fmfMap.second;
2728
}
2829
return llvmFMF;
2930
}
3031

32+
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
3133
LLVM::FastmathFlagsAttr
3234
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
3335
arith::FastMathFlags arithFMF = fmfAttr.getValue();
3436
return LLVM::FastmathFlagsAttr::get(
3537
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
3638
}
37-
38-
LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
39-
arith::IntegerOverflowFlags arithFlags) {
40-
LLVM::IntegerOverflowFlags llvmFlags{};
41-
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
42-
flags[] = {
43-
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
44-
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
45-
for (auto [arithFlag, llvmFlag] : flags) {
46-
if (bitEnumContainsAny(arithFlags, arithFlag))
47-
llvmFlags = llvmFlags | llvmFlag;
48-
}
49-
return llvmFlags;
50-
}
51-
52-
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
53-
arith::IntegerOverflowFlagsAttr flagsAttr) {
54-
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
55-
return LLVM::IntegerOverflowFlagsAttr::get(
56-
flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
57-
}

0 commit comments

Comments
 (0)