Skip to content

Skip reusing wrap validators / serializers for prebuilt variants #1660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/common/prebuilt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub fn get_prebuilt<T>(
type_: &str,
schema: &Bound<'_, PyDict>,
prebuilt_attr_name: &str,
extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<T>,
extractor: impl FnOnce(Bound<'_, PyAny>) -> PyResult<Option<T>>,
) -> PyResult<Option<T>> {
let py = schema.py();

Expand Down Expand Up @@ -40,5 +40,5 @@ pub fn get_prebuilt<T>(

// Retrieve the prebuilt validator / serializer if available
let prebuilt: Bound<'_, PyAny> = class_dict.get_item(prebuilt_attr_name)?;
extractor(prebuilt).map(Some)
extractor(prebuilt)
}
8 changes: 5 additions & 3 deletions src/serializers/prebuilt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ pub struct PrebuiltSerializer {
impl PrebuiltSerializer {
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedSerializer>> {
get_prebuilt(type_, schema, "__pydantic_serializer__", |py_any| {
py_any
.extract::<Py<SchemaSerializer>>()
.map(|schema_serializer| Self { schema_serializer }.into())
let schema_serializer = py_any.extract::<Py<SchemaSerializer>>()?;
if matches!(schema_serializer.get().serializer, CombinedSerializer::FunctionWrap(_)) {
return Ok(None);
}
Comment on lines +21 to +23
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather we check against disallowed variants here rather than all allowed variants - it keeps things cleaner.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we go further and check that the schema's type matches the type of the serializer?

e.g. maybe

match &schema_serializer.get().serializer {
    CombinedSerializer::Model(_) => // check schema[type] is model
    ... same for other expected types
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so, because I think it would be acceptable to have a PlainSerializer for a model, for example.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we perhaps write a test to demonstrate that case? (We can use the repr to show that a prebuilt serializer is used.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, happy to!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done via 6efdfc9

Ok(Some(Self { schema_serializer }.into()))
})
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/validators/prebuilt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@ pub struct PrebuiltValidator {
impl PrebuiltValidator {
pub fn try_get_from_schema(type_: &str, schema: &Bound<'_, PyDict>) -> PyResult<Option<CombinedValidator>> {
get_prebuilt(type_, schema, "__pydantic_validator__", |py_any| {
py_any
.extract::<Py<SchemaValidator>>()
.map(|schema_validator| Self { schema_validator }.into())
let schema_validator = py_any.extract::<Py<SchemaValidator>>()?;
if matches!(schema_validator.get().validator, CombinedValidator::FunctionWrap(_)) {
return Ok(None);
}
Ok(Some(Self { schema_validator }.into()))
})
}
}
Expand Down
237 changes: 237 additions & 0 deletions tests/test_prebuilt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Union

from pydantic_core import SchemaSerializer, SchemaValidator, core_schema


Expand Down Expand Up @@ -46,3 +48,238 @@ class OuterModel:
result = outer_validator.validate_python({'inner': {'x': 1}})
assert result.inner.x == 1
assert outer_serializer.to_python(result) == {'inner': {'x': 1}}


def test_prebuilt_not_used_for_wrap_serializer_functions() -> None:
    """A wrap-function serializer stored on a class is not picked up as a prebuilt serializer.

    ``InnerModel`` carries a ``__pydantic_serializer__`` built from a schema with a wrap
    serializer function. An outer schema that nests a plain ``InnerModel`` schema must
    serialize the inner value with plain ``str`` handling, not via the wrap function.
    """

    class InnerModel:
        x: str

        def __init__(self, x: str) -> None:
            self.x = x

    def serialize_inner(v: InnerModel, serializer) -> Union[dict[str, str], str]:
        # Mutate the instance before delegating so the wrap function's effect is observable.
        v.x += ' modified'
        return serializer(v)

    def plain_str_fields():
        # Build a fresh fields schema on every call: a single `x` field with default str handling.
        return core_schema.model_fields_schema(
            {'x': core_schema.model_field(schema=core_schema.str_schema())},
        )

    wrapped_inner_schema = core_schema.model_schema(
        InnerModel,
        schema=plain_str_fields(),
        serialization=core_schema.wrap_serializer_function_ser_schema(serialize_inner),
    )

    # Build the candidate first, then mark the class complete and attach the candidate to it.
    prebuilt_candidate = SchemaSerializer(wrapped_inner_schema)
    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_serializer__ = prebuilt_candidate  # pyright: ignore[reportAttributeAccessIssue]

    class OuterModel:
        inner: InnerModel

        def __init__(self, inner: InnerModel) -> None:
            self.inner = inner

    # The nested InnerModel schema deliberately has no custom serialization, so any
    # ' modified' suffix in the outer output could only come from the prebuilt serializer.
    nested_inner_schema = core_schema.model_schema(InnerModel, schema=plain_str_fields())
    outer_schema = core_schema.model_schema(
        OuterModel,
        schema=core_schema.model_fields_schema(
            {'inner': core_schema.model_field(schema=nested_inner_schema)},
        ),
    )

    inner_serializer = SchemaSerializer(wrapped_inner_schema)
    outer_serializer = SchemaSerializer(outer_schema)

    # Serializing through the inner schema directly applies the wrap function.
    standalone_inner = InnerModel(x='hello')
    assert inner_serializer.to_python(standalone_inner) == {'x': 'hello modified'}

    # Serializing through the outer schema does not: the value comes out unmodified.
    wrapped_in_outer = OuterModel(inner=InnerModel(x='hello'))
    assert outer_serializer.to_python(wrapped_in_outer) == {'inner': {'x': 'hello'}}


def test_prebuilt_not_used_for_wrap_validator_functions() -> None:
    """A wrap-function validator stored on a class is not picked up as a prebuilt validator.

    ``InnerModel`` carries a ``__pydantic_validator__`` built from a wrap validator function.
    An outer schema that nests a plain ``InnerModel`` schema must validate the inner value
    with plain ``str`` handling, not via the wrap function.
    """

    class InnerModel:
        x: str

        def __init__(self, x: str) -> None:
            self.x = x

    def validate_inner(data, validator) -> InnerModel:
        # Rewrite the input before delegating so the wrap function's effect is observable.
        data['x'] += ' modified'
        return validator(data)

    def plain_str_fields():
        # Build a fresh fields schema on every call: a single `x` field with default str handling.
        return core_schema.model_fields_schema(
            {'x': core_schema.model_field(schema=core_schema.str_schema())},
        )

    wrapped_inner_schema = core_schema.no_info_wrap_validator_function(
        validate_inner,
        core_schema.model_schema(InnerModel, schema=plain_str_fields()),
    )

    # Build the candidate first, then mark the class complete and attach the candidate to it.
    prebuilt_candidate = SchemaValidator(wrapped_inner_schema)
    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_validator__ = prebuilt_candidate  # pyright: ignore[reportAttributeAccessIssue]

    class OuterModel:
        inner: InnerModel

        def __init__(self, inner: InnerModel) -> None:
            self.inner = inner

    # The nested InnerModel schema deliberately has no custom validation, so any
    # ' modified' suffix in the outer result could only come from the prebuilt validator.
    nested_inner_schema = core_schema.model_schema(InnerModel, schema=plain_str_fields())
    outer_schema = core_schema.model_schema(
        OuterModel,
        schema=core_schema.model_fields_schema(
            {'inner': core_schema.model_field(schema=nested_inner_schema)},
        ),
    )

    inner_validator = SchemaValidator(wrapped_inner_schema)
    outer_validator = SchemaValidator(outer_schema)

    # Validating through the inner schema directly applies the wrap function.
    validated_inner = inner_validator.validate_python({'x': 'hello'})
    assert validated_inner.x == 'hello modified'

    # Validating through the outer schema does not: the value comes out unmodified.
    validated_outer = outer_validator.validate_python({'inner': {'x': 'hello'}})
    assert validated_outer.inner.x == 'hello'


def test_reuse_plain_serializer_ok() -> None:
    """A plain-function serializer stored on a class is reused as a prebuilt serializer.

    ``InnerModel`` carries a ``__pydantic_serializer__`` built from a schema with a plain
    serializer function. An outer schema that nests a plain ``InnerModel`` schema is expected
    to serialize the inner value through that stored serializer.
    """

    class InnerModel:
        x: str

        def __init__(self, x: str) -> None:
            self.x = x

    def serialize_inner(v: InnerModel) -> str:
        return v.x + ' modified'

    def plain_str_fields():
        # Build a fresh fields schema on every call: a single `x` field with default str handling.
        return core_schema.model_fields_schema(
            {'x': core_schema.model_field(schema=core_schema.str_schema())},
        )

    plain_inner_schema = core_schema.model_schema(
        InnerModel,
        schema=plain_str_fields(),
        serialization=core_schema.plain_serializer_function_ser_schema(serialize_inner),
    )

    # Build the candidate first, then mark the class complete and attach the candidate to it.
    prebuilt_candidate = SchemaSerializer(plain_inner_schema)
    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_serializer__ = prebuilt_candidate  # pyright: ignore[reportAttributeAccessIssue]

    class OuterModel:
        inner: InnerModel

        def __init__(self, inner: InnerModel) -> None:
            self.inner = inner

    # The nested InnerModel schema deliberately has no custom serialization, so a
    # ' modified' suffix in the outer output shows the prebuilt serializer was used.
    nested_inner_schema = core_schema.model_schema(InnerModel, schema=plain_str_fields())
    outer_schema = core_schema.model_schema(
        OuterModel,
        schema=core_schema.model_fields_schema(
            {'inner': core_schema.model_field(schema=nested_inner_schema)},
        ),
    )

    inner_serializer = SchemaSerializer(plain_inner_schema)
    outer_serializer = SchemaSerializer(outer_schema)

    # Serializing through the inner schema directly applies the plain function.
    standalone_inner = InnerModel(x='hello')
    assert inner_serializer.to_python(standalone_inner) == 'hello modified'
    assert 'FunctionPlainSerializer' in repr(inner_serializer)

    # The outer schema applies it too: a plain serializer is an accepted prebuilt candidate.
    wrapped_in_outer = OuterModel(inner=InnerModel(x='hello'))
    assert outer_serializer.to_python(wrapped_in_outer) == {'inner': 'hello modified'}
    assert 'PrebuiltSerializer' in repr(outer_serializer)


def test_reuse_plain_validator_ok() -> None:
    """A plain-function validator stored on a class is reused as a prebuilt validator.

    ``InnerModel`` carries a ``__pydantic_validator__`` built from a plain validator function.
    An outer schema that nests a plain ``InnerModel`` schema is expected to validate the
    inner value through that stored validator.
    """

    class InnerModel:
        x: str

        def __init__(self, x: str) -> None:
            self.x = x

    def validate_inner(data) -> InnerModel:
        # Rewrite the input and build the instance by hand; no inner validator is involved.
        data['x'] += ' modified'
        return InnerModel(**data)

    plain_inner_schema = core_schema.no_info_plain_validator_function(validate_inner)

    # Build the candidate first, then mark the class complete and attach the candidate to it.
    prebuilt_candidate = SchemaValidator(plain_inner_schema)
    InnerModel.__pydantic_complete__ = True  # pyright: ignore[reportAttributeAccessIssue]
    InnerModel.__pydantic_validator__ = prebuilt_candidate  # pyright: ignore[reportAttributeAccessIssue]

    class OuterModel:
        inner: InnerModel

        def __init__(self, inner: InnerModel) -> None:
            self.inner = inner

    # The nested InnerModel schema deliberately has no custom validation, so a
    # ' modified' suffix in the outer result shows the prebuilt validator was used.
    nested_inner_schema = core_schema.model_schema(
        InnerModel,
        schema=core_schema.model_fields_schema(
            {'x': core_schema.model_field(schema=core_schema.str_schema())},
        ),
    )
    outer_schema = core_schema.model_schema(
        OuterModel,
        schema=core_schema.model_fields_schema(
            {'inner': core_schema.model_field(schema=nested_inner_schema)},
        ),
    )

    inner_validator = SchemaValidator(plain_inner_schema)
    outer_validator = SchemaValidator(outer_schema)

    # Validating through the inner schema directly applies the plain function.
    validated_inner = inner_validator.validate_python({'x': 'hello'})
    assert validated_inner.x == 'hello modified'
    assert 'FunctionPlainValidator' in repr(inner_validator)

    # The outer schema applies it too: a plain validator is an accepted prebuilt candidate.
    validated_outer = outer_validator.validate_python({'inner': {'x': 'hello'}})
    assert validated_outer.inner.x == 'hello modified'
    assert 'PrebuiltValidator' in repr(outer_validator)
Loading