Skip to content

Introduce a schema variant to reuse Validators, Serializers and CoreSchema #1414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions python/pydantic_core/core_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,30 @@ def uuid_schema(
)


class NestedSchema(TypedDict, total=False):
    # Schema that points at the schema, validator and serializer of `cls`
    # rather than carrying its own copy of them.
    type: Required[Literal['nested']]
    # The class whose validator and serializer are reused.
    cls: Required[Type[Any]]
    # Zero-argument callable that is invoked on first use by the Rust side.
    # Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref;
    # the validator is read from index 1 of the tuple and the serializer from index 2.
    get_info: Required[Callable[[], Any]]
    metadata: Dict[str, Any]
    serialization: SerSchema

def nested_schema(
    *,
    cls: Type[Any],
    get_info: Callable[[], Any],
    metadata: Dict[str, Any] | None = None,
    serialization: SerSchema | None = None,
) -> NestedSchema:
    """
    Build a `nested` schema, which reuses the validator and serializer of `cls`.

    Args:
        cls: The class whose validator and serializer should be reused.
        get_info: Zero-argument callable expected to return a
            `(CoreSchema, SchemaValidator, SchemaSerializer)` tuple for `cls`.
        metadata: Optional metadata to attach to the schema.
        serialization: Optional custom serialization schema.

    Returns:
        The schema dict; `metadata` and `serialization` are passed through
        `_dict_not_none` together with the required keys.
    """
    return _dict_not_none(
        type='nested',
        cls=cls,
        get_info=get_info,
        metadata=metadata,
        serialization=serialization,
    )


class IncExSeqSerSchema(TypedDict, total=False):
type: Required[Literal['include-exclude-sequence']]
include: Set[int]
Expand Down Expand Up @@ -3866,6 +3890,7 @@ def definition_reference_schema(
DefinitionReferenceSchema,
UuidSchema,
ComplexSchema,
NestedSchema,
]
elif False:
CoreSchema: TypeAlias = Mapping[str, Any]
Expand Down Expand Up @@ -3922,6 +3947,7 @@ def definition_reference_schema(
'definition-ref',
'uuid',
'complex',
'nested',
]

CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']
Expand Down
21 changes: 20 additions & 1 deletion src/py_gc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use ahash::AHashMap;
use enum_dispatch::enum_dispatch;
Expand Down Expand Up @@ -58,6 +58,25 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
}
}

impl<T: PyGcTraverse, E> PyGcTraverse for Result<T, E> {
    /// Traverses the `Ok` payload; an `Err` is treated as holding nothing to visit.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        // FIXME(BoxyUwU): `E` carries no `PyGcTraverse` bound, so whatever an error value
        // references is skipped here rather than reported to the garbage collector.
        if let Ok(value) = self {
            return value.py_gc_traverse(visit);
        }
        Ok(())
    }
}

impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
    /// Traverses the stored value if the lock has been initialised; an empty lock
    /// has nothing to visit.
    fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
        if let Some(stored) = self.get() {
            return stored.py_gc_traverse(visit);
        }
        Ok(())
    }
}

/// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations
macro_rules! impl_py_gc_traverse {
($name:ty { }) => {
Expand Down
2 changes: 2 additions & 0 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ combined_serializer! {
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
Tuple: super::type_serializers::tuple::TupleSerializer;
Complex: super::type_serializers::complex::ComplexSerializer;
Nested: super::type_serializers::nested::NestedSerializer;
}
}

Expand Down Expand Up @@ -254,6 +255,7 @@ impl PyGcTraverse for CombinedSerializer {
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
CombinedSerializer::Nested(inner) => inner.py_gc_traverse(visit),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/serializers/type_serializers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub mod json_or_python;
pub mod list;
pub mod literal;
pub mod model;
pub mod nested;
pub mod nullable;
pub mod other;
pub mod set_frozenset;
Expand Down
135 changes: 135 additions & 0 deletions src/serializers/type_serializers/nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
use std::{borrow::Cow, sync::OnceLock};

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{
definitions::DefinitionsBuilder,
serializers::{
shared::{BuildSerializer, TypeSerializer},
CombinedSerializer, Extra,
},
SchemaSerializer,
};

/// Serializer for the `nested` core schema type: it forwards all work to the
/// `SchemaSerializer` obtained from the schema's `get_info` callable.
#[derive(Debug)]
pub struct NestedSerializer {
    /// The class from the schema's `cls` key; its pointer is used as the recursion-guard id.
    model: Py<PyType>,
    /// `__name__` of `model`, returned by `get_name`.
    name: String,
    /// The schema's `get_info` callable, called with no arguments on first use.
    get_serializer: Py<PyAny>,
    /// Cache for the outcome of calling `get_serializer`: either the serializer found at
    /// index 2 of the returned tuple, or the error raised while obtaining it.
    serializer: OnceLock<PyResult<Py<SchemaSerializer>>>,
}

// Generates the `PyGcTraverse` impl for `NestedSerializer` from the fields listed here;
// `name` is a plain `String` and is not listed.
impl_py_gc_traverse!(NestedSerializer {
    model,
    get_serializer,
    serializer
});

impl BuildSerializer for NestedSerializer {
const EXPECTED_TYPE: &'static str = "nested";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
) -> PyResult<CombinedSerializer> {
let py = schema.py();

let get_serializer = schema
.get_item(intern!(py, "get_info"))?
.expect("Invalid core schema for `nested` type, no `get_info`")
.unbind();

let model = schema
.get_item(intern!(py, "cls"))?
.expect("Invalid core schema for `nested` type, no `model`")
.downcast::<PyType>()
.expect("Invalid core schema for `nested` type, not a `PyType`")
.clone();

let name = model.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedSerializer::Nested(NestedSerializer {
model: model.clone().unbind(),
name,
get_serializer,
serializer: OnceLock::new(),
}))
}
}

impl NestedSerializer {
    /// Returns the serializer of the nested class, resolving it through `get_info` on
    /// first use and caching the outcome (success or error) afterwards.
    ///
    /// The Python callback is deliberately run *before* touching the `OnceLock`'s
    /// initialisation path: running Python code inside `get_or_init` would hold the lock
    /// while the interpreter may switch threads or re-enter this method, either of which
    /// can deadlock. The trade-off is that concurrent first uses may each call `get_info`;
    /// only the first result to be stored is kept.
    ///
    /// # Errors
    ///
    /// Returns a clone of the cached error if `get_info` raised, did not return a tuple,
    /// or its element at index 2 is not a `SchemaSerializer`.
    fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaSerializer>> {
        let cached = match self.serializer.get() {
            Some(cached) => cached,
            None => {
                let resolved = self.resolve_serializer(py);
                // The closure is trivial, so the lock is never held across Python code.
                self.serializer.get_or_init(|| resolved)
            }
        };
        cached.as_ref().map_err(|err| err.clone_ref(py))
    }

    /// Calls `get_info()` and extracts the `SchemaSerializer` at index 2 of the result.
    fn resolve_serializer(&self, py: Python<'_>) -> PyResult<Py<SchemaSerializer>> {
        Ok(self
            .get_serializer
            .bind(py)
            .call((), None)?
            .downcast::<PyTuple>()?
            .get_item(2)?
            .downcast::<SchemaSerializer>()?
            .clone()
            .unbind())
    }
}

impl TypeSerializer for NestedSerializer {
    /// Serializes `value` to a Python object by delegating to the nested serializer,
    /// under a recursion guard keyed on `value` and the pointer of `self.model`.
    ///
    /// # Errors
    ///
    /// Propagates recursion-guard errors, failures to resolve the nested serializer and
    /// errors from the nested serializer itself.
    fn to_python(
        &self,
        value: &Bound<'_, PyAny>,
        include: Option<&Bound<'_, PyAny>>,
        exclude: Option<&Bound<'_, PyAny>>,
        mut extra: &Extra,
    ) -> PyResult<PyObject> {
        let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;

        self.nested_serializer(value.py())?
            .bind(value.py())
            .get()
            .serializer
            .to_python(value, include, exclude, guard.state())
    }

    /// Produces the JSON key for `key` by delegating to the nested serializer.
    /// No recursion guard is taken on this path.
    fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
        self.nested_serializer(key.py())?
            .bind(key.py())
            .get()
            .serializer
            .json_key(key, extra)
    }

    /// Serializes `value` through serde by delegating to the nested serializer, under the
    /// same recursion guard as `to_python`.
    ///
    /// # Errors
    ///
    /// Recursion-guard errors and failures to resolve the nested serializer are converted
    /// into serde errors with `py_err_se_err` instead of panicking.
    fn serde_serialize<S: serde::ser::Serializer>(
        &self,
        value: &Bound<'_, PyAny>,
        serializer: S,
        include: Option<&Bound<'_, PyAny>>,
        exclude: Option<&Bound<'_, PyAny>>,
        mut extra: &Extra,
    ) -> Result<S::Ok, S::Error> {
        use super::py_err_se_err;

        let mut guard = extra
            .recursion_guard(value, self.model.as_ptr() as usize)
            .map_err(py_err_se_err)?;

        // Resolving the nested serializer runs Python code and can fail; surface that as
        // a serialization error rather than unwrapping.
        let nested = self.nested_serializer(value.py()).map_err(py_err_se_err)?;

        nested
            .bind(value.py())
            .get()
            .serializer
            .serde_serialize(value, serializer, include, exclude, guard.state())
    }

    /// Returns the `__name__` of the nested class captured at build time.
    fn get_name(&self) -> &str {
        &self.name
    }
}
4 changes: 4 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod list;
mod literal;
mod model;
mod model_fields;
mod nested;
mod none;
mod nullable;
mod set;
Expand Down Expand Up @@ -584,6 +585,7 @@ pub fn build_validator(
definitions::DefinitionRefValidator,
definitions::DefinitionsValidatorBuilder,
complex::ComplexValidator,
nested::NestedValidator,
)
}

Expand Down Expand Up @@ -738,6 +740,8 @@ pub enum CombinedValidator {
// input dependent
JsonOrPython(json_or_python::JsonOrPython),
Complex(complex::ComplexValidator),
// Schema for reusing an existing validator
Nested(nested::NestedValidator),
}

/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,
Expand Down
2 changes: 1 addition & 1 deletion src/validators/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {

let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
let name = class.getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
Expand Down
115 changes: 115 additions & 0 deletions src/validators/nested.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::sync::OnceLock;

use pyo3::{
intern,
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType},
Bound, Py, PyAny, PyObject, PyResult, Python,
};

use crate::{
definitions::DefinitionsBuilder,
errors::{ErrorTypeDefaults, ValError, ValResult},
input::Input,
recursion_guard::RecursionGuard,
};

use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};

/// Validator for the `nested` core schema type: it forwards validation to the
/// `SchemaValidator` obtained from the schema's `get_info` callable.
#[derive(Debug)]
pub struct NestedValidator {
    /// The class from the schema's `cls` key; its pointer is used as the recursion-guard id.
    cls: Py<PyType>,
    /// `__name__` of `cls`, returned by `get_name`.
    name: String,
    /// The schema's `get_info` callable, called with no arguments on first use.
    get_validator: Py<PyAny>,
    /// Cache for the outcome of calling `get_validator`: either the validator found at
    /// index 1 of the returned tuple, or the error raised while obtaining it.
    validator: OnceLock<PyResult<Py<SchemaValidator>>>,
}

// Generates the `PyGcTraverse` impl for `NestedValidator` from the fields listed here;
// `name` is a plain `String` and is not listed.
impl_py_gc_traverse!(NestedValidator {
    cls,
    get_validator,
    validator
});

impl BuildValidator for NestedValidator {
const EXPECTED_TYPE: &'static str = "nested";

fn build(
schema: &Bound<'_, PyDict>,
_config: Option<&Bound<'_, PyDict>>,
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
) -> PyResult<super::CombinedValidator> {
let py = schema.py();

let get_validator = schema.get_item(intern!(py, "get_info"))?.unwrap().unbind();

let cls = schema
.get_item(intern!(py, "cls"))?
.unwrap()
.downcast::<PyType>()?
.clone();

let name = cls.getattr(intern!(py, "__name__"))?.extract()?;

Ok(CombinedValidator::Nested(NestedValidator {
cls: cls.clone().unbind(),
name,
get_validator: get_validator,
validator: OnceLock::new(),
}))
}
}

impl NestedValidator {
    /// Returns the validator of the nested class, resolving it through `get_info` on
    /// first use and caching the outcome (success or error) afterwards.
    ///
    /// The Python callback is deliberately run *before* touching the `OnceLock`'s
    /// initialisation path: running Python code inside `get_or_init` would hold the lock
    /// while the interpreter may switch threads or re-enter this method, either of which
    /// can deadlock. The trade-off is that concurrent first uses may each call `get_info`;
    /// only the first result to be stored is kept.
    ///
    /// # Errors
    ///
    /// Returns a clone of the cached error if `get_info` raised, did not return a tuple,
    /// or its element at index 1 is not a `SchemaValidator`.
    fn nested_validator<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaValidator>> {
        let cached = match self.validator.get() {
            Some(cached) => cached,
            None => {
                let resolved = self.resolve_validator(py);
                // The closure is trivial, so the lock is never held across Python code.
                self.validator.get_or_init(|| resolved)
            }
        };
        cached.as_ref().map_err(|err| err.clone_ref(py))
    }

    /// Calls `get_info()` and extracts the `SchemaValidator` at index 1 of the result.
    fn resolve_validator(&self, py: Python<'_>) -> PyResult<Py<SchemaValidator>> {
        Ok(self
            .get_validator
            .bind(py)
            .call((), None)?
            .downcast::<PyTuple>()?
            .get_item(1)?
            .downcast::<SchemaValidator>()?
            .clone()
            .unbind())
    }
}

impl Validator for NestedValidator {
fn validate<'py>(
&self,
py: Python<'py>,
input: &(impl Input<'py> + ?Sized),
state: &mut ValidationState<'_, 'py>,
) -> ValResult<PyObject> {
let Some(id) = input.as_python().map(py_identity) else {
return self
.nested_validator(py)?
.bind(py)
.get()
.validator
.validate(py, input, state);
};

// Python objects can be cyclic, so need recursion guard
let Ok(mut guard) = RecursionGuard::new(state, id, self.cls.as_ptr() as usize) else {
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
};

self.nested_validator(py)?
.bind(py)
.get()
.validator
.validate(py, input, guard.state())
}

fn get_name(&self) -> &str {
&self.name
}
}

/// Identity of a Python object: the address of the underlying object, as a `usize`.
fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
    let address = obj.as_ptr();
    address as usize
}
Loading