Skip to content

Draft: Reusing schema validators and serializers #1614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/serializers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::fmt::Debug;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyTuple, PyType};
Expand Down Expand Up @@ -37,7 +38,7 @@ pub enum WarningsArg {
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaSerializer {
serializer: CombinedSerializer,
serializer: Arc<CombinedSerializer>,
definitions: Definitions<CombinedSerializer>,
expected_json_size: AtomicUsize,
config: SerializationConfig,
Expand Down Expand Up @@ -92,7 +93,7 @@ impl SchemaSerializer {
let mut definitions_builder = DefinitionsBuilder::new();
let serializer = CombinedSerializer::build(schema.downcast()?, config, &mut definitions_builder)?;
Ok(Self {
serializer,
serializer: Arc::new(serializer),
definitions: definitions_builder.finish()?,
expected_json_size: AtomicUsize::new(1024),
config: SerializationConfig::from_config(config)?,
Expand Down
19 changes: 17 additions & 2 deletions src/serializers/type_serializers/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::borrow::Cow;
use std::sync::Arc;

use pyo3::intern;
use pyo3::prelude::*;
Expand All @@ -18,6 +19,7 @@ use crate::definitions::DefinitionsBuilder;
use crate::serializers::errors::PydanticSerializationUnexpectedValue;
use crate::serializers::extra::DuckTypingSerMode;
use crate::tools::SchemaDict;
use crate::SchemaSerializer;

const ROOT_FIELD: &str = "root";

Expand Down Expand Up @@ -76,7 +78,7 @@ impl BuildSerializer for ModelFieldsBuilder {
#[derive(Debug)]
pub struct ModelSerializer {
class: Py<PyType>,
serializer: Box<CombinedSerializer>,
serializer: Arc<CombinedSerializer>,
has_extra: bool,
root_model: bool,
name: String,
Expand All @@ -97,8 +99,21 @@ impl BuildSerializer for ModelSerializer {

let class: Py<PyType> = schema.get_as_req(intern!(py, "cls"))?;
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
let serializer = Box::new(CombinedSerializer::build(&sub_schema, config.as_ref(), definitions)?);

let bound_class = class.bind(py);
let serializer = if bound_class.getattr("__pydantic_complete__")?.extract::<bool>()? {
if let Ok(prebuilt_serializer) = bound_class.getattr("__pydantic_serializer__") {
let schema_serializer: PyRef<SchemaSerializer> = prebuilt_serializer.extract()?;
schema_serializer.serializer.clone()
} else {
Arc::new(CombinedSerializer::build(&sub_schema, config.as_ref(), definitions)?)
}
} else {
Arc::new(CombinedSerializer::build(&sub_schema, config.as_ref(), definitions)?)
};

let root_model = schema.get_as(intern!(py, "root_model"))?.unwrap_or(false);

let name = class.bind(py).getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
Expand Down
7 changes: 4 additions & 3 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::fmt::Debug;
use std::sync::Arc;

use enum_dispatch::enum_dispatch;
use jiter::{PartialMode, StringCacheMode};
Expand Down Expand Up @@ -105,7 +106,7 @@ impl PySome {
#[pyclass(module = "pydantic_core._pydantic_core", frozen)]
#[derive(Debug)]
pub struct SchemaValidator {
validator: CombinedValidator,
validator: Arc<CombinedValidator>,
definitions: Definitions<CombinedValidator>,
// References to the Python schema and config objects are saved to enable
// reconstructing the object for cloudpickle support (see `__reduce__`).
Expand Down Expand Up @@ -146,7 +147,7 @@ impl SchemaValidator {
.get_as(intern!(py, "cache_strings"))?
.unwrap_or(StringCacheMode::All);
Ok(Self {
validator,
validator: Arc::new(validator),
definitions,
py_schema,
py_config,
Expand Down Expand Up @@ -455,7 +456,7 @@ impl<'py> SelfValidator<'py> {
};
let definitions = definitions_builder.finish()?;
Ok(SchemaValidator {
validator,
validator: Arc::new(validator),
definitions,
py_schema: py.None(),
py_config: None,
Expand Down
20 changes: 16 additions & 4 deletions src/validators/model.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::ptr::null_mut;
use std::sync::Arc;

use pyo3::exceptions::PyTypeError;
use pyo3::types::{PyDict, PySet, PyString, PyTuple, PyType};
Expand All @@ -8,7 +9,7 @@ use pyo3::{intern, prelude::*};
use super::function::convert_err;
use super::validation_state::Exactness;
use super::{
build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator,
build_validator, BuildValidator, CombinedValidator, DefinitionsBuilder, Extra, ValidationState, Validator, SchemaValidator
};
use crate::build_tools::py_schema_err;
use crate::build_tools::schema_or_config_same;
Expand Down Expand Up @@ -53,7 +54,7 @@ impl Revalidate {
#[derive(Debug)]
pub struct ModelValidator {
revalidate: Revalidate,
validator: Box<CombinedValidator>,
validator: Arc<CombinedValidator>,
class: Py<PyType>,
generic_origin: Option<Py<PyType>>,
post_init: Option<Py<PyString>>,
Expand All @@ -79,7 +80,18 @@ impl BuildValidator for ModelValidator {
let class: Bound<'_, PyType> = schema.get_as_req(intern!(py, "cls"))?;
let generic_origin: Option<Bound<'_, PyType>> = schema.get_as(intern!(py, "generic_origin"))?;
let sub_schema = schema.get_as_req(intern!(py, "schema"))?;
let validator = build_validator(&sub_schema, config.as_ref(), definitions)?;

let validator = if class.getattr("__pydantic_complete__")?.extract::<bool>()? {
if let Ok(prebuilt_validator) = class.getattr("__pydantic_validator__") {
let schema_validator: PyRef<SchemaValidator> = prebuilt_validator.extract()?;
schema_validator.validator.clone()
} else {
Arc::new(build_validator(&sub_schema, config.as_ref(), definitions)?)
}
} else {
Arc::new(build_validator(&sub_schema, config.as_ref(), definitions)?)
};

let name = class.getattr(intern!(py, "__name__"))?.extract()?;

Ok(Self {
Expand All @@ -93,7 +105,7 @@ impl BuildValidator for ModelValidator {
.map(|s| s.to_str())
.transpose()?,
)?,
validator: Box::new(validator),
validator,
class: class.into(),
generic_origin: generic_origin.map(std::convert::Into::into),
post_init: schema.get_as(intern!(py, "post_init"))?,
Expand Down
Loading