Skip to content

Commit b690e72

Browse files
anandoleecopybara-github
authored andcommitted
Nextgen Proto Pythonic API: Timestamp/Duration assignment, creation and calculation
Timestamp and Duration are now have more support with datetime and timedelta: - Allows assign python datetime to protobuf DateTime field in addition to current FromDatetime/ToDatetime (Note: will throw exceptions for the differences in supported ranges) - Allows assign python timedelta to protobuf Duration field in addition to current FromTimedelta/ToTimedelta - Calculation between Timestamp, Duration, datetime and timedelta will also be supported. example usage: from datetime import datetime, timedelta from event_pb2 import Event e = Event(start_time=datetime(year=2112, month=2, day=3), duration=timedelta(hours=10)) duration = timedelta(hours=10)) end_time = e.start_time + timedelta(hours=4) e.duration = end_time - e.start_time PiperOrigin-RevId: 640639168
1 parent a450c9c commit b690e72

File tree

7 files changed

+445
-31
lines changed

7 files changed

+445
-31
lines changed

python/google/protobuf/internal/descriptor_pool_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from google.protobuf.internal import no_package_pb2
3030
from google.protobuf.internal import testing_refleaks
3131

32+
from google.protobuf import duration_pb2
33+
from google.protobuf import timestamp_pb2
3234
from google.protobuf import unittest_features_pb2
3335
from google.protobuf import unittest_import_pb2
3436
from google.protobuf import unittest_import_public_pb2
@@ -435,6 +437,8 @@ def testAddSerializedFile(self):
435437
self.assertEqual(file2.name,
436438
'google/protobuf/internal/factory_test2.proto')
437439
self.testFindMessageTypeByName()
440+
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
441+
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
438442
file_json = self.pool.AddSerializedFile(
439443
more_messages_pb2.DESCRIPTOR.serialized_pb)
440444
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
@@ -542,12 +546,18 @@ def testComplexNesting(self):
542546
# that uses a DescriptorDatabase.
543547
# TODO: Fix python and cpp extension diff.
544548
return
549+
timestamp_desc = descriptor_pb2.FileDescriptorProto.FromString(
550+
timestamp_pb2.DESCRIPTOR.serialized_pb)
551+
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
552+
duration_pb2.DESCRIPTOR.serialized_pb)
545553
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
546554
more_messages_pb2.DESCRIPTOR.serialized_pb)
547555
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
548556
descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb)
549557
test2_desc = descriptor_pb2.FileDescriptorProto.FromString(
550558
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
559+
self.pool.Add(timestamp_desc)
560+
self.pool.Add(duration_desc)
551561
self.pool.Add(more_messages_desc)
552562
self.pool.Add(test1_desc)
553563
self.pool.Add(test2_desc)

python/google/protobuf/internal/more_messages.proto

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ syntax = "proto2";
1313

1414
package google.protobuf.internal;
1515

16+
import "google/protobuf/duration.proto";
17+
import "google/protobuf/timestamp.proto";
18+
1619
// A message where tag numbers are listed out of order, to allow us to test our
1720
// canonicalization of serialized output, which should always be in tag order.
1821
// We also mix in some extensions for extra fun.
@@ -348,3 +351,8 @@ message ConflictJsonName {
348351
optional int32 value = 1 [json_name = "old_value"];
349352
optional int32 new_value = 2 [json_name = "value"];
350353
}
354+
355+
message WKTMessage {
356+
optional Timestamp optional_timestamp = 1;
357+
optional Duration optional_duration = 2;
358+
}

python/google/protobuf/internal/python_message.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
__author__ = '[email protected] (Will Robinson)'
2929

30+
import datetime
3031
from io import BytesIO
3132
import struct
3233
import sys
@@ -536,13 +537,30 @@ def init(self, **kwargs):
536537
self._fields[field] = copy
537538
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
538539
copy = field._default_constructor(self)
539-
new_val = field_value
540-
if isinstance(field_value, dict):
540+
new_val = None
541+
if isinstance(field_value, message_mod.Message):
542+
new_val = field_value
543+
elif isinstance(field_value, dict):
541544
new_val = field.message_type._concrete_class(**field_value)
542-
try:
543-
copy.MergeFrom(new_val)
544-
except TypeError:
545-
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
545+
elif field.message_type.full_name == 'google.protobuf.Timestamp':
546+
copy.FromDatetime(field_value)
547+
elif field.message_type.full_name == 'google.protobuf.Duration':
548+
copy.FromTimedelta(field_value)
549+
else:
550+
raise TypeError(
551+
'Message field {0}.{1} must be initialized with a '
552+
'dict or instance of same class, got {2}.'.format(
553+
message_descriptor.name,
554+
field_name,
555+
type(field_value).__name__,
556+
)
557+
)
558+
559+
if new_val:
560+
try:
561+
copy.MergeFrom(new_val)
562+
except TypeError:
563+
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
546564
self._fields[field] = copy
547565
else:
548566
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
@@ -753,8 +771,17 @@ def getter(self):
753771
# We define a setter just so we can throw an exception with a more
754772
# helpful error message.
755773
def setter(self, new_value):
756-
raise AttributeError('Assignment not allowed to composite field '
757-
'"%s" in protocol message object.' % proto_field_name)
774+
if field.message_type.full_name == 'google.protobuf.Timestamp':
775+
getter(self)
776+
self._fields[field].FromDatetime(new_value)
777+
elif field.message_type.full_name == 'google.protobuf.Duration':
778+
getter(self)
779+
self._fields[field].FromTimedelta(new_value)
780+
else:
781+
raise AttributeError(
782+
'Assignment not allowed to composite field '
783+
'"%s" in protocol message object.' % proto_field_name
784+
)
758785

759786
# Add a property to encapsulate the getter.
760787
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name

python/google/protobuf/internal/well_known_types.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
import collections.abc
2222
import datetime
2323
import warnings
24-
2524
from google.protobuf.internal import field_mask
25+
from typing import Union
2626

2727
FieldMask = field_mask.FieldMask
2828

@@ -271,12 +271,35 @@ def FromDatetime(self, dt):
271271
# manipulated into a long value of seconds. During the conversion from
272272
# struct_time to long, the source date in UTC, and so it follows that the
273273
# correct transformation is calendar.timegm()
274-
seconds = calendar.timegm(dt.utctimetuple())
275-
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
274+
try:
275+
seconds = calendar.timegm(dt.utctimetuple())
276+
nanos = dt.microsecond * _NANOS_PER_MICROSECOND
277+
except AttributeError as e:
278+
raise AttributeError(
279+
'Fail to convert to Timestamp. Expected a datetime like '
280+
'object got {0} : {1}'.format(type(dt).__name__, e)
281+
) from e
276282
_CheckTimestampValid(seconds, nanos)
277283
self.seconds = seconds
278284
self.nanos = nanos
279285

286+
def __add__(self, value) -> datetime.datetime:
287+
if isinstance(value, Duration):
288+
return self.ToDatetime() + value.ToTimedelta()
289+
return self.ToDatetime() + value
290+
291+
__radd__ = __add__
292+
293+
def __sub__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
294+
if isinstance(value, Timestamp):
295+
return self.ToDatetime() - value.ToDatetime()
296+
elif isinstance(value, Duration):
297+
return self.ToDatetime() - value.ToTimedelta()
298+
return self.ToDatetime() - value
299+
300+
def __rsub__(self, dt) -> datetime.timedelta:
301+
return dt - self.ToDatetime()
302+
280303

281304
def _CheckTimestampValid(seconds, nanos):
282305
if seconds < _TIMESTAMP_SECONDS_MIN or seconds > _TIMESTAMP_SECONDS_MAX:
@@ -408,8 +431,16 @@ def ToTimedelta(self) -> datetime.timedelta:
408431

409432
def FromTimedelta(self, td):
410433
"""Converts timedelta to Duration."""
411-
self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY,
412-
td.microseconds * _NANOS_PER_MICROSECOND)
434+
try:
435+
self._NormalizeDuration(
436+
td.seconds + td.days * _SECONDS_PER_DAY,
437+
td.microseconds * _NANOS_PER_MICROSECOND,
438+
)
439+
except AttributeError as e:
440+
raise AttributeError(
441+
'Fail to convert to Duration. Expected a timedelta like '
442+
'object got {0}: {1}'.format(type(td).__name__, e)
443+
) from e
413444

414445
def _NormalizeDuration(self, seconds, nanos):
415446
"""Set Duration by seconds and nanos."""
@@ -420,6 +451,16 @@ def _NormalizeDuration(self, seconds, nanos):
420451
self.seconds = seconds
421452
self.nanos = nanos
422453

454+
def __add__(self, value) -> Union[datetime.datetime, datetime.timedelta]:
455+
if isinstance(value, Timestamp):
456+
return self.ToTimedelta() + value.ToDatetime()
457+
return self.ToTimedelta() + value
458+
459+
__radd__ = __add__
460+
461+
def __rsub__(self, dt) -> Union[datetime.datetime, datetime.timedelta]:
462+
return dt - self.ToTimedelta()
463+
423464

424465
def _CheckDurationValid(seconds, nanos):
425466
if seconds < -_DURATION_SECONDS_MAX or seconds > _DURATION_SECONDS_MAX:

0 commit comments

Comments
 (0)