Skip to content

Commit e17821c

Browse files
anandoleecopybara-github
authored andcommitted
Nextgen Proto Pythonic API: Struct/ListValue assignment and creation
Python dict is now able to be assigned (by create and copy, not reference) and compared with the Protobuf Struct field. Python list is now able to be assigned (by create and copy, not reference) and compared with the Protobuf ListValue field. example usage: dictionary = {'key1': 5.0, 'key2': {'subkey': 11.0, 'k': False},} list_value = [6, 'seven', True, False, None, dictionary] msg = more_messages_pb2.WKTMessage( optional_struct=dictionary, optional_list_value=list_value ) self.assertEqual(msg.optional_struct, dictionary) self.assertEqual(msg.optional_list_value, list_value) PiperOrigin-RevId: 646099987
1 parent 0302c4c commit e17821c

File tree

7 files changed

+256
-103
lines changed

7 files changed

+256
-103
lines changed

python/google/protobuf/internal/descriptor_pool_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from google.protobuf.internal import testing_refleaks
3131

3232
from google.protobuf import duration_pb2
33+
from google.protobuf import struct_pb2
3334
from google.protobuf import timestamp_pb2
3435
from google.protobuf import unittest_features_pb2
3536
from google.protobuf import unittest_import_pb2
@@ -439,6 +440,7 @@ def testAddSerializedFile(self):
439440
self.testFindMessageTypeByName()
440441
self.pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb)
441442
self.pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb)
443+
self.pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb)
442444
file_json = self.pool.AddSerializedFile(
443445
more_messages_pb2.DESCRIPTOR.serialized_pb)
444446
field = file_json.message_types_by_name['class'].fields_by_name['int_field']
@@ -550,6 +552,9 @@ def testComplexNesting(self):
550552
timestamp_pb2.DESCRIPTOR.serialized_pb)
551553
duration_desc = descriptor_pb2.FileDescriptorProto.FromString(
552554
duration_pb2.DESCRIPTOR.serialized_pb)
555+
struct_desc = descriptor_pb2.FileDescriptorProto.FromString(
556+
struct_pb2.DESCRIPTOR.serialized_pb
557+
)
553558
more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString(
554559
more_messages_pb2.DESCRIPTOR.serialized_pb)
555560
test1_desc = descriptor_pb2.FileDescriptorProto.FromString(
@@ -558,6 +563,7 @@ def testComplexNesting(self):
558563
descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb)
559564
self.pool.Add(timestamp_desc)
560565
self.pool.Add(duration_desc)
566+
self.pool.Add(struct_desc)
561567
self.pool.Add(more_messages_desc)
562568
self.pool.Add(test1_desc)
563569
self.pool.Add(test2_desc)

python/google/protobuf/internal/more_messages.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ syntax = "proto2";
1414
package google.protobuf.internal;
1515

1616
import "google/protobuf/duration.proto";
17+
import "google/protobuf/struct.proto";
1718
import "google/protobuf/timestamp.proto";
1819

1920
// A message where tag numbers are listed out of order, to allow us to test our
@@ -355,4 +356,6 @@ message ConflictJsonName {
355356
message WKTMessage {
356357
optional Timestamp optional_timestamp = 1;
357358
optional Duration optional_duration = 2;
359+
optional Struct optional_struct = 3;
360+
optional ListValue optional_list_value = 4;
358361
}

python/google/protobuf/internal/python_message.py

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@
5151

5252
_FieldDescriptor = descriptor_mod.FieldDescriptor
5353
_AnyFullTypeName = 'google.protobuf.Any'
54+
_StructFullTypeName = 'google.protobuf.Struct'
55+
_ListValueFullTypeName = 'google.protobuf.ListValue'
5456
_ExtensionDict = extension_dict._ExtensionDict
5557

5658
class GeneratedProtocolMessageType(type):
@@ -515,37 +517,47 @@ def init(self, **kwargs):
515517
# field=None is the same as no field at all.
516518
continue
517519
if field.label == _FieldDescriptor.LABEL_REPEATED:
518-
copy = field._default_constructor(self)
520+
field_copy = field._default_constructor(self)
519521
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
520522
if _IsMapField(field):
521523
if _IsMessageMapField(field):
522524
for key in field_value:
523-
copy[key].MergeFrom(field_value[key])
525+
field_copy[key].MergeFrom(field_value[key])
524526
else:
525-
copy.update(field_value)
527+
field_copy.update(field_value)
526528
else:
527529
for val in field_value:
528530
if isinstance(val, dict):
529-
copy.add(**val)
531+
field_copy.add(**val)
530532
else:
531-
copy.add().MergeFrom(val)
533+
field_copy.add().MergeFrom(val)
532534
else: # Scalar
533535
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
534536
field_value = [_GetIntegerEnumValue(field.enum_type, val)
535537
for val in field_value]
536-
copy.extend(field_value)
537-
self._fields[field] = copy
538+
field_copy.extend(field_value)
539+
self._fields[field] = field_copy
538540
elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
539-
copy = field._default_constructor(self)
541+
field_copy = field._default_constructor(self)
540542
new_val = None
541543
if isinstance(field_value, message_mod.Message):
542544
new_val = field_value
543545
elif isinstance(field_value, dict):
544-
new_val = field.message_type._concrete_class(**field_value)
545-
elif field.message_type.full_name == 'google.protobuf.Timestamp':
546-
copy.FromDatetime(field_value)
547-
elif field.message_type.full_name == 'google.protobuf.Duration':
548-
copy.FromTimedelta(field_value)
546+
if field.message_type.full_name == _StructFullTypeName:
547+
field_copy.Clear()
548+
if len(field_value) == 1 and 'fields' in field_value:
549+
try:
550+
field_copy.update(field_value)
551+
except:
552+
# Fall back to init normal message field
553+
field_copy.Clear()
554+
new_val = field.message_type._concrete_class(**field_value)
555+
else:
556+
field_copy.update(field_value)
557+
else:
558+
new_val = field.message_type._concrete_class(**field_value)
559+
elif hasattr(field_copy, '_internal_assign'):
560+
field_copy._internal_assign(field_value)
549561
else:
550562
raise TypeError(
551563
'Message field {0}.{1} must be initialized with a '
@@ -558,10 +570,10 @@ def init(self, **kwargs):
558570

559571
if new_val:
560572
try:
561-
copy.MergeFrom(new_val)
573+
field_copy.MergeFrom(new_val)
562574
except TypeError:
563575
_ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name)
564-
self._fields[field] = copy
576+
self._fields[field] = field_copy
565577
else:
566578
if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM:
567579
field_value = _GetIntegerEnumValue(field.enum_type, field_value)
@@ -777,6 +789,14 @@ def setter(self, new_value):
777789
elif field.message_type.full_name == 'google.protobuf.Duration':
778790
getter(self)
779791
self._fields[field].FromTimedelta(new_value)
792+
elif field.message_type.full_name == _StructFullTypeName:
793+
getter(self)
794+
self._fields[field].Clear()
795+
self._fields[field].update(new_value)
796+
elif field.message_type.full_name == _ListValueFullTypeName:
797+
getter(self)
798+
self._fields[field].Clear()
799+
self._fields[field].extend(new_value)
780800
else:
781801
raise AttributeError(
782802
'Assignment not allowed to composite field '
@@ -978,6 +998,15 @@ def _InternalUnpackAny(msg):
978998
def _AddEqualsMethod(message_descriptor, cls):
979999
"""Helper for _AddMessageMethods()."""
9801000
def __eq__(self, other):
1001+
if self.DESCRIPTOR.full_name == _ListValueFullTypeName and isinstance(
1002+
other, list
1003+
):
1004+
return self._internal_compare(other)
1005+
if self.DESCRIPTOR.full_name == _StructFullTypeName and isinstance(
1006+
other, dict
1007+
):
1008+
return self._internal_compare(other)
1009+
9811010
if (not isinstance(other, message_mod.Message) or
9821011
other.DESCRIPTOR != self.DESCRIPTOR):
9831012
return NotImplemented

python/google/protobuf/internal/well_known_types.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ def FromDatetime(self, dt):
283283
self.seconds = seconds
284284
self.nanos = nanos
285285

286+
def _internal_assign(self, dt):
287+
self.FromDatetime(dt)
288+
286289
def __add__(self, value) -> datetime.datetime:
287290
if isinstance(value, Duration):
288291
return self.ToDatetime() + value.ToTimedelta()
@@ -442,6 +445,9 @@ def FromTimedelta(self, td):
442445
'object got {0}: {1}'.format(type(td).__name__, e)
443446
) from e
444447

448+
def _internal_assign(self, td):
449+
self.FromTimedelta(td)
450+
445451
def _NormalizeDuration(self, seconds, nanos):
446452
"""Set Duration by seconds and nanos."""
447453
# Force nanos to be negative if the duration is negative.
@@ -550,6 +556,24 @@ def __len__(self):
550556
def __iter__(self):
551557
return iter(self.fields)
552558

559+
def _internal_assign(self, dictionary):
560+
self.Clear()
561+
self.update(dictionary)
562+
563+
def _internal_compare(self, other):
564+
size = len(self)
565+
if size != len(other):
566+
return False
567+
for key, value in self.items():
568+
if key not in other:
569+
return False
570+
if isinstance(other[key], (dict, list)):
571+
if not value._internal_compare(other[key]):
572+
return False
573+
elif value != other[key]:
574+
return False
575+
return True
576+
553577
def keys(self): # pylint: disable=invalid-name
554578
return self.fields.keys()
555579

@@ -605,6 +629,22 @@ def __setitem__(self, index, value):
605629
def __delitem__(self, key):
606630
del self.values[key]
607631

632+
def _internal_assign(self, elem_seq):
633+
self.Clear()
634+
self.extend(elem_seq)
635+
636+
def _internal_compare(self, other):
637+
size = len(self)
638+
if size != len(other):
639+
return False
640+
for i in range(size):
641+
if isinstance(other[i], (dict, list)):
642+
if not self[i]._internal_compare(other[i]):
643+
return False
644+
elif self[i] != other[i]:
645+
return False
646+
return True
647+
608648
def items(self):
609649
for i in range(len(self)):
610650
yield self[i]

python/google/protobuf/internal/well_known_types_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -838,6 +838,73 @@ def testStructAssignment(self):
838838
s2['x'] = s1['x']
839839
self.assertEqual(s1['x'], s2['x'])
840840

841+
dictionary = {
842+
'key1': 5.0,
843+
'key2': 'abc',
844+
'key3': {'subkey': 11.0, 'k': False},
845+
}
846+
msg = more_messages_pb2.WKTMessage()
847+
msg.optional_struct = dictionary
848+
self.assertEqual(msg.optional_struct, dictionary)
849+
850+
# Tests assign is not merge
851+
dictionary2 = {
852+
'key4': {'subkey': 11.0, 'k': True},
853+
}
854+
msg.optional_struct = dictionary2
855+
self.assertEqual(msg.optional_struct, dictionary2)
856+
857+
# Tests assign empty
858+
msg2 = more_messages_pb2.WKTMessage()
859+
self.assertNotIn('optional_struct', msg2)
860+
msg2.optional_struct = {}
861+
self.assertIn('optional_struct', msg2)
862+
self.assertEqual(msg2.optional_struct, {})
863+
864+
def testListValueAssignment(self):
865+
list_value = [6, 'seven', True, False, None, {}]
866+
msg = more_messages_pb2.WKTMessage()
867+
msg.optional_list_value = list_value
868+
self.assertEqual(msg.optional_list_value, list_value)
869+
870+
def testStructConstruction(self):
871+
dictionary = {
872+
'key1': 5.0,
873+
'key2': 'abc',
874+
'key3': {'subkey': 11.0, 'k': False},
875+
}
876+
list_value = [6, 'seven', True, False, None, dictionary]
877+
msg = more_messages_pb2.WKTMessage(
878+
optional_struct=dictionary, optional_list_value=list_value
879+
)
880+
self.assertEqual(len(msg.optional_struct), len(dictionary))
881+
self.assertEqual(msg.optional_struct, dictionary)
882+
self.assertEqual(len(msg.optional_list_value), len(list_value))
883+
self.assertEqual(msg.optional_list_value, list_value)
884+
885+
msg2 = more_messages_pb2.WKTMessage(
886+
optional_struct={}, optional_list_value=[]
887+
)
888+
self.assertIn('optional_struct', msg2)
889+
self.assertIn('optional_list_value', msg2)
890+
self.assertEqual(msg2.optional_struct, {})
891+
self.assertEqual(msg2.optional_list_value, [])
892+
893+
def testSpecialStructConstruct(self):
894+
dictionary = {'key1': 6.0}
895+
msg = more_messages_pb2.WKTMessage(optional_struct=dictionary)
896+
self.assertEqual(msg.optional_struct, dictionary)
897+
898+
dictionary2 = {'fields': 7.0}
899+
msg2 = more_messages_pb2.WKTMessage(optional_struct=dictionary2)
900+
self.assertEqual(msg2.optional_struct, dictionary2)
901+
902+
# Construct Struct as normal message
903+
value_msg = struct_pb2.Value(number_value=5.0)
904+
dictionary3 = {'fields': {'key1': value_msg}}
905+
msg3 = more_messages_pb2.WKTMessage(optional_struct=dictionary3)
906+
self.assertEqual(msg3.optional_struct, {'key1': 5.0})
907+
841908
def testMergeFrom(self):
842909
struct = struct_pb2.Struct()
843910
struct_class = struct.__class__

0 commit comments

Comments
 (0)