Skip to content

Commit f651080

Browse files
Fixed printing of nan floats/doubles in Python.
The second assert in _upb_EncodeRoundTripFloat is raised if val is a nan. This fix just returns the output of first spnprintf. I am not sure how changes to this repo are made so feel free to ignore this CL. To test this, you could 1. Define a proto with a float field message Test { float val = 1; } 2. In a python script, import the library and then set the val to nan and try to print it. proto = Test(val=float('nan')) print(proto) This will cause a coredump due to assertion error: assert.h assertion failed at third_party/upb/upb/lex/round_trip.c:46 in void _upb_EncodeRoundTripFloat(float, char *, size_t): strtof(buf, NULL) == val Added the corresponding change to double too PiperOrigin-RevId: 637127851
1 parent ee98ba2 commit f651080

File tree

4 files changed

+65
-0
lines changed

4 files changed

+65
-0
lines changed

python/google/protobuf/internal/message_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,13 +382,23 @@ def testFloatPrinting(self, message_module):
382382
message.optional_float = 2.0
383383
self.assertEqual(str(message), 'optional_float: 2.0\n')
384384

385+
def testFloatNanPrinting(self, message_module):
386+
message = message_module.TestAllTypes()
387+
message.optional_float = float('nan')
388+
self.assertEqual(str(message), 'optional_float: nan\n')
389+
385390
def testHighPrecisionFloatPrinting(self, message_module):
386391
msg = message_module.TestAllTypes()
387392
msg.optional_float = 0.12345678912345678
388393
old_float = msg.optional_float
389394
msg.ParseFromString(msg.SerializeToString())
390395
self.assertEqual(old_float, msg.optional_float)
391396

397+
def testDoubleNanPrinting(self, message_module):
398+
message = message_module.TestAllTypes()
399+
message.optional_double = float('nan')
400+
self.assertEqual(str(message), 'optional_double: nan\n')
401+
392402
def testHighPrecisionDoublePrinting(self, message_module):
393403
msg = message_module.TestAllTypes()
394404
msg.optional_double = 0.12345678912345678

upb/lex/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ cc_test(
4141
],
4242
)
4343

44+
cc_test(
45+
name = "round_trip_test",
46+
srcs = ["round_trip_test.cc"],
47+
deps = [
48+
":lex",
49+
"@com_google_googletest//:gtest",
50+
"@com_google_googletest//:gtest_main",
51+
],
52+
)
53+
4454
# begin:github_only
4555
filegroup(
4656
name = "source_files",

upb/lex/round_trip.c

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
#include "upb/lex/round_trip.h"
99

1010
#include <float.h>
11+
#include <math.h>
12+
#include <stdio.h>
1113
#include <stdlib.h>
1214

1315
// Must be last.
@@ -28,6 +30,10 @@ static void upb_FixLocale(char* p) {
2830

2931
void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {
3032
assert(size >= kUpb_RoundTripBufferSize);
33+
if (isnan(val)) {
34+
snprintf(buf, size, "%s", "nan");
35+
return;
36+
}
3137
snprintf(buf, size, "%.*g", DBL_DIG, val);
3238
if (strtod(buf, NULL) != val) {
3339
snprintf(buf, size, "%.*g", DBL_DIG + 2, val);
@@ -38,6 +44,10 @@ void _upb_EncodeRoundTripDouble(double val, char* buf, size_t size) {
3844

3945
void _upb_EncodeRoundTripFloat(float val, char* buf, size_t size) {
4046
assert(size >= kUpb_RoundTripBufferSize);
47+
if (isnan(val)) {
48+
snprintf(buf, size, "%s", "nan");
49+
return;
50+
}
4151
snprintf(buf, size, "%.*g", FLT_DIG, val);
4252
if (strtof(buf, NULL) != val) {
4353
snprintf(buf, size, "%.*g", FLT_DIG + 3, val);

upb/lex/round_trip_test.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "upb/lex/round_trip.h"
2+
3+
#include <math.h>
4+
5+
#include <gtest/gtest.h>
6+
7+
namespace {
8+
9+
TEST(RoundTripTest, Double) {
10+
char buf[32];
11+
12+
_upb_EncodeRoundTripDouble(0.123456789, buf, sizeof(buf));
13+
EXPECT_STREQ(buf, "0.123456789");
14+
15+
_upb_EncodeRoundTripDouble(0.0, buf, sizeof(buf));
16+
EXPECT_STREQ(buf, "0");
17+
18+
_upb_EncodeRoundTripDouble(nan(""), buf, sizeof(buf));
19+
EXPECT_STREQ(buf, "nan");
20+
}
21+
22+
TEST(RoundTripTest, Float) {
23+
char buf[32];
24+
25+
_upb_EncodeRoundTripFloat(0.123456, buf, sizeof(buf));
26+
EXPECT_STREQ(buf, "0.123456");
27+
28+
_upb_EncodeRoundTripFloat(0.0, buf, sizeof(buf));
29+
EXPECT_STREQ(buf, "0");
30+
31+
_upb_EncodeRoundTripFloat(nan(""), buf, sizeof(buf));
32+
EXPECT_STREQ(buf, "nan");
33+
}
34+
35+
} // namespace

0 commit comments

Comments
 (0)