Skip to content

Commit 56eca25

Browse files
committed
Fix choices in ChoiceField to support IntEnum
Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case. [https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum) rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field.
1 parent 38a74b4 commit 56eca25

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

rest_framework/fields.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
from collections import OrderedDict
1010
from collections.abc import Mapping
11+
from enum import Enum
1112

1213
from django.conf import settings
1314
from django.core.exceptions import ObjectDoesNotExist
@@ -1397,7 +1398,8 @@ def __init__(self, choices, **kwargs):
13971398
def to_internal_value(self, data):
13981399
if data == '' and self.allow_blank:
13991400
return ''
1400-
1401+
if isinstance(data, Enum) and str(data) != str(data.value):
1402+
data = data.value
14011403
try:
14021404
return self.choice_strings_to_values[str(data)]
14031405
except KeyError:
@@ -1406,6 +1408,8 @@ def to_internal_value(self, data):
14061408
def to_representation(self, value):
14071409
if value in ('', None):
14081410
return value
1411+
if isinstance(value, Enum) and str(value) != str(value.value):
1412+
value = value.value
14091413
return self.choice_strings_to_values.get(str(value), value)
14101414

14111415
def iter_options(self):
@@ -1429,7 +1433,7 @@ def _set_choices(self, choices):
14291433
# Allows us to deal with eg. integer choices while supporting either
14301434
# integer or string input, but still get the correct datatype out.
14311435
self.choice_strings_to_values = {
1432-
str(key): key for key in self.choices
1436+
str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices
14331437
}
14341438

14351439
choices = property(_get_choices, _set_choices)
@@ -1815,6 +1819,7 @@ class HiddenField(Field):
18151819
constraint on a pair of fields, as we need some way to include the date in
18161820
the validated data.
18171821
"""
1822+
18181823
def __init__(self, **kwargs):
18191824
assert 'default' in kwargs, 'default is a required argument.'
18201825
kwargs['write_only'] = True
@@ -1844,6 +1849,7 @@ class ExampleSerializer(Serializer):
18441849
def get_extra_info(self, obj):
18451850
return ... # Calculate some data to return.
18461851
"""
1852+
18471853
def __init__(self, method_name=None, **kwargs):
18481854
self.method_name = method_name
18491855
kwargs['source'] = '*'

tests/test_fields.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class TestEmpty:
131131
"""
132132
Tests for `required`, `allow_null`, `allow_blank`, `default`.
133133
"""
134+
134135
def test_required(self):
135136
"""
136137
By default a field must be included in the input.
@@ -657,6 +658,7 @@ class FieldValues:
657658
"""
658659
Base class for testing valid and invalid input values.
659660
"""
661+
660662
def test_valid_inputs(self, *args):
661663
"""
662664
Ensure that valid values return the expected validated data.
@@ -1824,6 +1826,31 @@ def test_edit_choices(self):
18241826
field.run_validation(2)
18251827
assert exc_info.value.detail == ['"2" is not a valid choice.']
18261828

1829+
def test_enum_choices(self):
1830+
from enum import IntEnum, auto
1831+
1832+
class ChoiceCase(IntEnum):
1833+
first = auto()
1834+
second = auto()
1835+
# Enum validate
1836+
choices = [
1837+
(ChoiceCase.first, "1"),
1838+
(ChoiceCase.second, "2")
1839+
]
1840+
field = serializers.ChoiceField(choices=choices)
1841+
assert field.run_validation(1) == 1
1842+
assert field.run_validation(ChoiceCase.first) == 1
1843+
assert field.run_validation("1") == 1
1844+
# Enum.value validate
1845+
choices = [
1846+
(ChoiceCase.first.value, "1"),
1847+
(ChoiceCase.second.value, "2")
1848+
]
1849+
field = serializers.ChoiceField(choices=choices)
1850+
assert field.run_validation(1) == 1
1851+
assert field.run_validation(ChoiceCase.first) == 1
1852+
assert field.run_validation("1") == 1
1853+
18271854

18281855
class TestChoiceFieldWithType(FieldValues):
18291856
"""

0 commit comments

Comments
 (0)