Skip to content

Commit 8999b29

Browse files
committed
Fix choices in ChoiceField to support IntEnum
Python support Enum in version 3.4, but changed __str__ to int.__str__ until version 3.11 to better support the replacement of existing constants use-case. [https://docs.python.org/3/library/enum.html#enum.IntEnum](https://docs.python.org/3/library/enum.html#enum.IntEnum) rest_frame work support Python 3.6+, this commit will support the Enum in choices of Field.
1 parent 4f7e9ed commit 8999b29

File tree

2 files changed

+13
-16
lines changed

2 files changed

+13
-16
lines changed

rest_framework/fields.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import uuid
1010
from collections.abc import Mapping
11+
from enum import Enum
1112

1213
from django.conf import settings
1314
from django.core.exceptions import ObjectDoesNotExist
@@ -17,7 +18,6 @@
1718
MinValueValidator, ProhibitNullCharactersValidator, RegexValidator,
1819
URLValidator, ip_address_validators
1920
)
20-
from django.db.models import IntegerChoices, TextChoices
2121
from django.forms import FilePathField as DjangoFilePathField
2222
from django.forms import ImageField as DjangoImageField
2323
from django.utils import timezone
@@ -1401,11 +1401,8 @@ def __init__(self, choices, **kwargs):
14011401
def to_internal_value(self, data):
14021402
if data == '' and self.allow_blank:
14031403
return ''
1404-
1405-
if isinstance(data, (IntegerChoices, TextChoices)) and str(data) != \
1406-
str(data.value):
1404+
if isinstance(data, Enum) and str(data) != str(data.value):
14071405
data = data.value
1408-
14091406
try:
14101407
return self.choice_strings_to_values[str(data)]
14111408
except KeyError:
@@ -1414,11 +1411,8 @@ def to_internal_value(self, data):
14141411
def to_representation(self, value):
14151412
if value in ('', None):
14161413
return value
1417-
1418-
if isinstance(value, (IntegerChoices, TextChoices)) and str(value) != \
1419-
str(value.value):
1414+
if isinstance(value, Enum) and str(value) != str(value.value):
14201415
value = value.value
1421-
14221416
return self.choice_strings_to_values.get(str(value), value)
14231417

14241418
def iter_options(self):
@@ -1442,8 +1436,7 @@ def _set_choices(self, choices):
14421436
# Allows us to deal with eg. integer choices while supporting either
14431437
# integer or string input, but still get the correct datatype out.
14441438
self.choice_strings_to_values = {
1445-
str(key.value) if isinstance(key, (IntegerChoices, TextChoices))
1446-
and str(key) != str(key.value) else str(key): key for key in self.choices
1439+
str(key.value) if isinstance(key, Enum) and str(key) != str(key.value) else str(key): key for key in self.choices
14471440
}
14481441

14491442
choices = property(_get_choices, _set_choices)
@@ -1829,6 +1822,7 @@ class HiddenField(Field):
18291822
constraint on a pair of fields, as we need some way to include the date in
18301823
the validated data.
18311824
"""
1825+
18321826
def __init__(self, **kwargs):
18331827
assert 'default' in kwargs, 'default is a required argument.'
18341828
kwargs['write_only'] = True
@@ -1858,6 +1852,7 @@ class ExampleSerializer(Serializer):
18581852
def get_extra_info(self, obj):
18591853
return ... # Calculate some data to return.
18601854
"""
1855+
18611856
def __init__(self, method_name=None, **kwargs):
18621857
self.method_name = method_name
18631858
kwargs['source'] = '*'

tests/test_fields.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class TestEmpty:
138138
"""
139139
Tests for `required`, `allow_null`, `allow_blank`, `default`.
140140
"""
141+
141142
def test_required(self):
142143
"""
143144
By default a field must be included in the input.
@@ -664,6 +665,7 @@ class FieldValues:
664665
"""
665666
Base class for testing valid and invalid input values.
666667
"""
668+
667669
def test_valid_inputs(self, *args):
668670
"""
669671
Ensure that valid values return the expected validated data.
@@ -1875,26 +1877,26 @@ def test_edit_choices(self):
18751877
field.run_validation(2)
18761878
assert exc_info.value.detail == ['"2" is not a valid choice.']
18771879

1878-
def test_integer_choices(self):
1879-
class ChoiceCase(IntegerChoices):
1880+
def test_enum_choices(self):
1881+
from enum import IntEnum, auto
1882+
1883+
class ChoiceCase(IntEnum):
18801884
first = auto()
18811885
second = auto()
18821886
# Enum validate
18831887
choices = [
18841888
(ChoiceCase.first, "1"),
18851889
(ChoiceCase.second, "2")
18861890
]
1887-
18881891
field = serializers.ChoiceField(choices=choices)
18891892
assert field.run_validation(1) == 1
18901893
assert field.run_validation(ChoiceCase.first) == 1
18911894
assert field.run_validation("1") == 1
1892-
1895+
# Enum.value validate
18931896
choices = [
18941897
(ChoiceCase.first.value, "1"),
18951898
(ChoiceCase.second.value, "2")
18961899
]
1897-
18981900
field = serializers.ChoiceField(choices=choices)
18991901
assert field.run_validation(1) == 1
19001902
assert field.run_validation(ChoiceCase.first) == 1

0 commit comments

Comments
 (0)