Skip to content

Commit 4548bff

Browse files
[mlir][Parser] Deduplicate floating-point parsing functionality (#116172)
The following functionality is duplicated in multiple places: trying to parse an APFloat from a floating point literal or an integer in hexadecimal representation (bit pattern). Move it to a common helper function. NFC apart from the slightly changed error messages. (We now print the exact same error messages regardless of whether the float is parsed standalone or inside of a tensor literal, etc.)
1 parent 63b926a commit 4548bff

File tree

5 files changed

+69
-103
lines changed

5 files changed

+69
-103
lines changed

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -287,34 +287,13 @@ class AsmParserImpl : public BaseT {
287287
APFloat &result) override {
288288
bool isNegative = parser.consumeIf(Token::minus);
289289
Token curTok = parser.getToken();
290-
SMLoc loc = curTok.getLoc();
291-
292-
// Check for a floating point value.
293-
if (curTok.is(Token::floatliteral)) {
294-
auto val = curTok.getFloatingPointValue();
295-
if (!val)
296-
return emitError(loc, "floating point value too large");
297-
parser.consumeToken(Token::floatliteral);
298-
result = APFloat(isNegative ? -*val : *val);
299-
bool losesInfo;
300-
result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
301-
return success();
302-
}
303-
304-
// Check for a hexadecimal float value.
305-
if (curTok.is(Token::integer)) {
306-
std::optional<APFloat> apResult;
307-
if (failed(parser.parseFloatFromIntegerLiteral(
308-
apResult, curTok, isNegative, semantics,
309-
APFloat::semanticsSizeInBits(semantics))))
310-
return failure();
311-
312-
result = *apResult;
313-
parser.consumeToken(Token::integer);
314-
return success();
315-
}
316-
317-
return emitError(loc, "expected floating point literal");
290+
std::optional<APFloat> apResult;
291+
if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative,
292+
semantics)))
293+
return failure();
294+
parser.consumeToken();
295+
result = *apResult;
296+
return success();
318297
}
319298

320299
/// Parse a floating point value from the stream.

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 13 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
424424
if (auto floatType = dyn_cast<FloatType>(type)) {
425425
std::optional<APFloat> result;
426426
if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
427-
floatType.getFloatSemantics(),
428-
floatType.getWidth())))
427+
floatType.getFloatSemantics())))
429428
return Attribute();
430429
return FloatAttr::get(floatType, *result);
431430
}
@@ -658,36 +657,11 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
658657
for (const auto &signAndToken : storage) {
659658
bool isNegative = signAndToken.first;
660659
const Token &token = signAndToken.second;
661-
662-
// Handle hexadecimal float literals.
663-
if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
664-
std::optional<APFloat> result;
665-
if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
666-
eltTy.getFloatSemantics(),
667-
eltTy.getWidth())))
668-
return failure();
669-
670-
floatValues.push_back(*result);
671-
continue;
672-
}
673-
674-
// Check to see if any decimal integers or booleans were parsed.
675-
if (!token.is(Token::floatliteral))
676-
return p.emitError()
677-
<< "expected floating-point elements, but parsed integer";
678-
679-
// Build the float values from tokens.
680-
auto val = token.getFloatingPointValue();
681-
if (!val)
682-
return p.emitError("floating point value too large for attribute");
683-
684-
APFloat apVal(isNegative ? -*val : *val);
685-
if (!eltTy.isF64()) {
686-
bool unused;
687-
apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
688-
&unused);
689-
}
690-
floatValues.push_back(apVal);
660+
std::optional<APFloat> result;
661+
if (failed(p.parseFloatFromLiteral(result, token, isNegative,
662+
eltTy.getFloatSemantics())))
663+
return failure();
664+
floatValues.push_back(*result);
691665
}
692666
return success();
693667
}
@@ -905,32 +879,14 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {
905879

906880
ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
907881
bool isNegative = p.consumeIf(Token::minus);
908-
909882
Token token = p.getToken();
910-
std::optional<APFloat> result;
911-
auto floatType = cast<FloatType>(type);
912-
if (p.consumeIf(Token::integer)) {
913-
// Parse an integer literal as a float.
914-
if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
915-
floatType.getFloatSemantics(),
916-
floatType.getWidth()))
917-
return failure();
918-
} else if (p.consumeIf(Token::floatliteral)) {
919-
// Parse a floating point literal.
920-
std::optional<double> val = token.getFloatingPointValue();
921-
if (!val)
922-
return failure();
923-
result = APFloat(isNegative ? -*val : *val);
924-
if (!type.isF64()) {
925-
bool unused;
926-
result->convert(floatType.getFloatSemantics(),
927-
APFloat::rmNearestTiesToEven, &unused);
928-
}
929-
} else {
930-
return p.emitError("expected integer or floating point literal");
931-
}
932-
933-
append(result->bitcastToAPInt());
883+
std::optional<APFloat> fromIntLit;
884+
if (failed(
885+
p.parseFloatFromLiteral(fromIntLit, token, isNegative,
886+
cast<FloatType>(type).getFloatSemantics())))
887+
return failure();
888+
p.consumeToken();
889+
append(fromIntLit->bitcastToAPInt());
934890
return success();
935891
}
936892

mlir/lib/AsmParser/Parser.cpp

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -347,34 +347,58 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
347347
return success();
348348
}
349349

350+
ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
351+
const Token &tok, bool isNegative,
352+
const llvm::fltSemantics &semantics) {
353+
// Check for a floating point value.
354+
if (tok.is(Token::floatliteral)) {
355+
auto val = tok.getFloatingPointValue();
356+
if (!val)
357+
return emitError(tok.getLoc()) << "floating point value too large";
358+
359+
result.emplace(isNegative ? -*val : *val);
360+
bool unused;
361+
result->convert(semantics, APFloat::rmNearestTiesToEven, &unused);
362+
return success();
363+
}
364+
365+
// Check for a hexadecimal float value.
366+
if (tok.is(Token::integer))
367+
return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics);
368+
369+
return emitError(tok.getLoc()) << "expected floating point literal";
370+
}
371+
350372
/// Parse a floating point value from an integer literal token.
351-
ParseResult Parser::parseFloatFromIntegerLiteral(
352-
std::optional<APFloat> &result, const Token &tok, bool isNegative,
353-
const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
354-
SMLoc loc = tok.getLoc();
373+
ParseResult
374+
Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
375+
const Token &tok, bool isNegative,
376+
const llvm::fltSemantics &semantics) {
355377
StringRef spelling = tok.getSpelling();
356378
bool isHex = spelling.size() > 1 && spelling[1] == 'x';
357379
if (!isHex) {
358-
return emitError(loc, "unexpected decimal integer literal for a "
359-
"floating point value")
380+
return emitError(tok.getLoc(), "unexpected decimal integer literal for a "
381+
"floating point value")
360382
.attachNote()
361383
<< "add a trailing dot to make the literal a float";
362384
}
363385
if (isNegative) {
364-
return emitError(loc, "hexadecimal float literal should not have a "
365-
"leading minus");
386+
return emitError(tok.getLoc(),
387+
"hexadecimal float literal should not have a "
388+
"leading minus");
366389
}
367390

368391
APInt intValue;
369392
tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
370-
if (intValue.getActiveBits() > typeSizeInBits)
371-
return emitError(loc, "hexadecimal float constant out of range for type");
393+
auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
394+
if (intValue.getActiveBits() > typeSizeInBits) {
395+
return emitError(tok.getLoc(),
396+
"hexadecimal float constant out of range for type");
397+
}
372398

373399
APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
374400
intValue.getRawData());
375-
376401
result.emplace(semantics, truncatedValue);
377-
378402
return success();
379403
}
380404

mlir/lib/AsmParser/Parser.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
namespace mlir {
1818
namespace detail {
19+
1920
//===----------------------------------------------------------------------===//
2021
// Parser
2122
//===----------------------------------------------------------------------===//
@@ -151,11 +152,15 @@ class Parser {
151152
/// Parse an optional integer value only in decimal format from the stream.
152153
OptionalParseResult parseOptionalDecimalInteger(APInt &result);
153154

155+
/// Parse a floating point value from a literal.
156+
ParseResult parseFloatFromLiteral(std::optional<APFloat> &result,
157+
const Token &tok, bool isNegative,
158+
const llvm::fltSemantics &semantics);
159+
154160
/// Parse a floating point value from an integer literal token.
155161
ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
156162
const Token &tok, bool isNegative,
157-
const llvm::fltSemantics &semantics,
158-
size_t typeSizeInBits);
163+
const llvm::fltSemantics &semantics);
159164

160165
/// Returns true if the current token corresponds to a keyword.
161166
bool isCurrentTokenAKeyword() const {

mlir/test/IR/invalid-builtin-attributes.mlir

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
4545
// -----
4646

4747
func.func @elementsattr_floattype2() -> () {
48-
// expected-error@+1 {{expected floating-point elements, but parsed integer}}
48+
// expected-error@below {{unexpected decimal integer literal for a floating point value}}
49+
// expected-note@below {{add a trailing dot to make the literal a float}}
4950
"foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
5051
}
5152

@@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
138139
// -----
139140

140141
func.func @float_in_bool_tensor() {
141-
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
142+
// expected-error@below {{expected integer elements, but parsed floating-point}}
142143
"foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
143144
}
144145

145146
// -----
146147

147148
func.func @decimal_int_in_float_tensor() {
148-
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
149+
// expected-error@below {{unexpected decimal integer literal for a floating point value}}
150+
// expected-note@below {{add a trailing dot to make the literal a float}}
149151
"foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
150152
}
151153

152154
// -----
153155

154156
func.func @bool_in_float_tensor() {
155-
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
157+
// expected-error @+1 {{expected floating point literal}}
156158
"foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
157159
}
158160

0 commit comments

Comments
 (0)