Skip to content

Commit a78828e

Browse files
committed
WIP
1 parent ba8eab4 commit a78828e

File tree

8 files changed

+304
-2
lines changed

8 files changed

+304
-2
lines changed

python/pydantic_core/core_schema.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,6 +1437,30 @@ def uuid_schema(
14371437
)
14381438

14391439

1440+
class NestedModelSchema(TypedDict, total=False):
1441+
type: Required[Literal['nested-model']]
1442+
model: Required[Type[Any]]
1443+
# Should return `(CoreSchema, SchemaValidator, SchemaSerializer)` but this requires a forward ref
1444+
get_info: Required[Callable[[], Any]]
1445+
metadata: Any
1446+
serialization: SerSchema
1447+
1448+
def nested_model_schema(
1449+
*,
1450+
model: Type[Any],
1451+
get_info: Callable[[], Any],
1452+
metadata: Any = None,
1453+
serialization: SerSchema | None = None
1454+
) -> NestedModelSchema:
1455+
return _dict_not_none(
1456+
type='nested-model',
1457+
model=model,
1458+
get_info=get_info,
1459+
metadata=metadata,
1460+
serialization=serialization
1461+
)
1462+
1463+
14401464
class IncExSeqSerSchema(TypedDict, total=False):
14411465
type: Required[Literal['include-exclude-sequence']]
14421466
include: Set[int]
@@ -3866,6 +3890,7 @@ def definition_reference_schema(
38663890
DefinitionReferenceSchema,
38673891
UuidSchema,
38683892
ComplexSchema,
3893+
NestedModelSchema,
38693894
]
38703895
elif False:
38713896
CoreSchema: TypeAlias = Mapping[str, Any]
@@ -3922,6 +3947,7 @@ def definition_reference_schema(
39223947
'definition-ref',
39233948
'uuid',
39243949
'complex',
3950+
'nested-model',
39253951
]
39263952

39273953
CoreSchemaFieldType = Literal['model-field', 'dataclass-field', 'typed-dict-field', 'computed-field']

src/py_gc.rs

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use std::sync::Arc;
1+
use std::sync::{Arc, OnceLock};
22

33
use ahash::AHashMap;
44
use enum_dispatch::enum_dispatch;
@@ -58,6 +58,25 @@ impl<T: PyGcTraverse> PyGcTraverse for Option<T> {
5858
}
5959
}
6060

61+
impl<T: PyGcTraverse, E> PyGcTraverse for Result<T, E> {
62+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
63+
match self {
64+
Ok(v) => T::py_gc_traverse(v, visit),
65+
// FIXME(BoxyUwU): Lol
66+
Err(_) => Ok(()),
67+
}
68+
}
69+
}
70+
71+
impl<T: PyGcTraverse> PyGcTraverse for OnceLock<T> {
72+
fn py_gc_traverse(&self, visit: &PyVisit<'_>) -> Result<(), PyTraverseError> {
73+
match self.get() {
74+
Some(item) => T::py_gc_traverse(item, visit),
75+
None => Ok(()),
76+
}
77+
}
78+
}
79+
6180
/// A crude alternative to a "derive" macro to help with building PyGcTraverse implementations
6281
macro_rules! impl_py_gc_traverse {
6382
($name:ty { }) => {

src/serializers/shared.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ combined_serializer! {
143143
Recursive: super::type_serializers::definitions::DefinitionRefSerializer;
144144
Tuple: super::type_serializers::tuple::TupleSerializer;
145145
Complex: super::type_serializers::complex::ComplexSerializer;
146+
NestedModel: super::type_serializers::nested_model::NestedModelSerializer;
146147
}
147148
}
148149

@@ -254,6 +255,7 @@ impl PyGcTraverse for CombinedSerializer {
254255
CombinedSerializer::Tuple(inner) => inner.py_gc_traverse(visit),
255256
CombinedSerializer::Uuid(inner) => inner.py_gc_traverse(visit),
256257
CombinedSerializer::Complex(inner) => inner.py_gc_traverse(visit),
258+
CombinedSerializer::NestedModel(inner) => inner.py_gc_traverse(visit),
257259
}
258260
}
259261
}

src/serializers/type_serializers/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ pub mod json_or_python;
1616
pub mod list;
1717
pub mod literal;
1818
pub mod model;
19+
pub mod nested_model;
1920
pub mod nullable;
2021
pub mod other;
2122
pub mod set_frozenset;
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
use std::{borrow::Cow, sync::OnceLock};
2+
3+
use pyo3::{
4+
intern,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
7+
};
8+
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
serializers::{
12+
shared::{BuildSerializer, TypeSerializer},
13+
CombinedSerializer, Extra,
14+
},
15+
SchemaSerializer,
16+
};
17+
18+
#[derive(Debug)]
19+
pub struct NestedModelSerializer {
20+
model: Py<PyType>,
21+
name: String,
22+
get_serializer: Py<PyAny>,
23+
serializer: OnceLock<PyResult<Py<SchemaSerializer>>>,
24+
}
25+
26+
impl_py_gc_traverse!(NestedModelSerializer {
27+
model,
28+
get_serializer,
29+
serializer
30+
});
31+
32+
impl BuildSerializer for NestedModelSerializer {
33+
const EXPECTED_TYPE: &'static str = "nested-model";
34+
35+
fn build(
36+
schema: &Bound<'_, PyDict>,
37+
_config: Option<&Bound<'_, PyDict>>,
38+
_definitions: &mut DefinitionsBuilder<CombinedSerializer>,
39+
) -> PyResult<CombinedSerializer> {
40+
let py = schema.py();
41+
42+
let get_serializer = schema
43+
.get_item(intern!(py, "get_info"))?
44+
.expect("Invalid core schema for `nested-model` type, no `get_info`")
45+
.unbind();
46+
47+
let model = schema
48+
.get_item(intern!(py, "model"))?
49+
.expect("Invalid core schema for `nested-model` type, no `model`")
50+
.downcast::<PyType>()
51+
.expect("Invalid core schema for `nested-model` type, not a `PyType`")
52+
.clone();
53+
54+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
55+
56+
Ok(CombinedSerializer::NestedModel(NestedModelSerializer {
57+
model: model.clone().unbind(),
58+
name,
59+
get_serializer,
60+
serializer: OnceLock::new(),
61+
}))
62+
}
63+
}
64+
65+
impl NestedModelSerializer {
66+
fn nested_serializer<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaSerializer>> {
67+
self.serializer
68+
.get_or_init(|| {
69+
Ok(self
70+
.get_serializer
71+
.bind(py)
72+
.call((), None)?
73+
.downcast::<PyTuple>()?
74+
.get_item(2)?
75+
.downcast::<SchemaSerializer>()?
76+
.clone()
77+
.unbind())
78+
})
79+
.as_ref()
80+
.map_err(|e| e.clone_ref(py))
81+
}
82+
}
83+
84+
impl TypeSerializer for NestedModelSerializer {
85+
fn to_python(
86+
&self,
87+
value: &Bound<'_, PyAny>,
88+
include: Option<&Bound<'_, PyAny>>,
89+
exclude: Option<&Bound<'_, PyAny>>,
90+
mut extra: &Extra,
91+
) -> PyResult<PyObject> {
92+
let mut guard = extra.recursion_guard(value, self.model.as_ptr() as usize)?;
93+
94+
self.nested_serializer(value.py())?
95+
.bind(value.py())
96+
.get()
97+
.serializer
98+
.to_python(value, include, exclude, guard.state())
99+
}
100+
101+
fn json_key<'a>(&self, key: &'a Bound<'_, PyAny>, extra: &Extra) -> PyResult<Cow<'a, str>> {
102+
self.nested_serializer(key.py())?
103+
.bind(key.py())
104+
.get()
105+
.serializer
106+
.json_key(key, extra)
107+
}
108+
109+
fn serde_serialize<S: serde::ser::Serializer>(
110+
&self,
111+
value: &Bound<'_, PyAny>,
112+
serializer: S,
113+
include: Option<&Bound<'_, PyAny>>,
114+
exclude: Option<&Bound<'_, PyAny>>,
115+
mut extra: &Extra,
116+
) -> Result<S::Ok, S::Error> {
117+
use super::py_err_se_err;
118+
119+
let mut guard = extra
120+
.recursion_guard(value, self.model.as_ptr() as usize)
121+
.map_err(py_err_se_err)?;
122+
123+
self.nested_serializer(value.py())
124+
// FIXME(BoxyUwU): Don't unwrap this
125+
.unwrap()
126+
.bind(value.py())
127+
.get()
128+
.serializer
129+
.serde_serialize(value, serializer, include, exclude, guard.state())
130+
}
131+
132+
fn get_name(&self) -> &str {
133+
&self.name
134+
}
135+
}

src/validators/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ mod list;
4949
mod literal;
5050
mod model;
5151
mod model_fields;
52+
mod nested_model;
5253
mod none;
5354
mod nullable;
5455
mod set;
@@ -584,6 +585,7 @@ pub fn build_validator(
584585
definitions::DefinitionRefValidator,
585586
definitions::DefinitionsValidatorBuilder,
586587
complex::ComplexValidator,
588+
nested_model::NestedModelValidator,
587589
)
588590
}
589591

@@ -738,6 +740,8 @@ pub enum CombinedValidator {
738740
// input dependent
739741
JsonOrPython(json_or_python::JsonOrPython),
740742
Complex(complex::ComplexValidator),
743+
// Schema for a model inside of another schema
744+
NestedModel(nested_model::NestedModelValidator),
741745
}
742746

743747
/// This trait must be implemented by all validators, it allows various validators to be accessed consistently,

src/validators/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ impl BuildValidator for ModelValidator {
7777

7878
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
7979
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
80-
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;
80+
let validator: CombinedValidator = build_validator(&sub_schema, config.as_ref(), definitions)?;
8181
let name = class.getattr(intern!(py, "__name__"))?.extract()?;
8282

8383
Ok(Self {

src/validators/nested_model.rs

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
use std::sync::OnceLock;
2+
3+
use pyo3::{
4+
intern,
5+
types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyTupleMethods, PyType},
6+
Bound, Py, PyAny, PyObject, PyResult, Python,
7+
};
8+
9+
use crate::{
10+
definitions::DefinitionsBuilder,
11+
errors::{ErrorTypeDefaults, ValError, ValResult},
12+
input::Input,
13+
recursion_guard::RecursionGuard,
14+
};
15+
16+
use super::{BuildValidator, CombinedValidator, SchemaValidator, ValidationState, Validator};
17+
18+
#[derive(Debug)]
19+
pub struct NestedModelValidator {
20+
model: Py<PyType>,
21+
name: String,
22+
get_validator: Py<PyAny>,
23+
validator: OnceLock<PyResult<Py<SchemaValidator>>>,
24+
}
25+
26+
impl_py_gc_traverse!(NestedModelValidator {
27+
model,
28+
get_validator,
29+
validator
30+
});
31+
32+
impl BuildValidator for NestedModelValidator {
33+
const EXPECTED_TYPE: &'static str = "nested-model";
34+
35+
fn build(
36+
schema: &Bound<'_, PyDict>,
37+
_config: Option<&Bound<'_, PyDict>>,
38+
_definitions: &mut DefinitionsBuilder<super::CombinedValidator>,
39+
) -> PyResult<super::CombinedValidator> {
40+
let py = schema.py();
41+
42+
let get_validator = schema.get_item(intern!(py, "get_info"))?.unwrap().unbind();
43+
44+
let model = schema
45+
.get_item(intern!(py, "model"))?
46+
.unwrap()
47+
.downcast::<PyType>()?
48+
.clone();
49+
50+
let name = model.getattr(intern!(py, "__name__"))?.extract()?;
51+
52+
Ok(CombinedValidator::NestedModel(NestedModelValidator {
53+
model: model.clone().unbind(),
54+
name,
55+
get_validator: get_validator,
56+
validator: OnceLock::new(),
57+
}))
58+
}
59+
}
60+
61+
impl NestedModelValidator {
62+
fn nested_validator<'py>(&self, py: Python<'py>) -> PyResult<&Py<SchemaValidator>> {
63+
self.validator
64+
.get_or_init(|| {
65+
Ok(self
66+
.get_validator
67+
.bind(py)
68+
.call((), None)?
69+
.downcast::<PyTuple>()?
70+
.get_item(1)?
71+
.downcast::<SchemaValidator>()?
72+
.clone()
73+
.unbind())
74+
})
75+
.as_ref()
76+
.map_err(|e| e.clone_ref(py))
77+
}
78+
}
79+
80+
impl Validator for NestedModelValidator {
81+
fn validate<'py>(
82+
&self,
83+
py: Python<'py>,
84+
input: &(impl Input<'py> + ?Sized),
85+
state: &mut ValidationState<'_, 'py>,
86+
) -> ValResult<PyObject> {
87+
let Some(id) = input.as_python().map(py_identity) else {
88+
return self
89+
.nested_validator(py)?
90+
.bind(py)
91+
.get()
92+
.validator
93+
.validate(py, input, state);
94+
};
95+
96+
// Python objects can be cyclic, so need recursion guard
97+
let Ok(mut guard) = RecursionGuard::new(state, id, self.model.as_ptr() as usize) else {
98+
return Err(ValError::new(ErrorTypeDefaults::RecursionLoop, input));
99+
};
100+
101+
self.nested_validator(py)?
102+
.bind(py)
103+
.get()
104+
.validator
105+
.validate(py, input, guard.state())
106+
}
107+
108+
fn get_name(&self) -> &str {
109+
&self.name
110+
}
111+
}
112+
113+
fn py_identity(obj: &Bound<'_, PyAny>) -> usize {
114+
obj.as_ptr() as usize
115+
}

0 commit comments

Comments
 (0)