diff --git a/python/pydantic_core/__init__.py b/python/pydantic_core/__init__.py index 5b2655c91..711b7feb6 100644 --- a/python/pydantic_core/__init__.py +++ b/python/pydantic_core/__init__.py @@ -21,6 +21,8 @@ TzInfo, Url, ValidationError, + WalkCoreSchema, + WalkCoreSchemaFilterBuilder, __version__, from_json, to_json, @@ -67,6 +69,8 @@ 'from_json', 'to_jsonable_python', 'validate_core_schema', + 'WalkCoreSchema', + 'WalkCoreSchemaFilterBuilder', ] diff --git a/python/pydantic_core/_pydantic_core.pyi b/python/pydantic_core/_pydantic_core.pyi index 382a6c804..a0272df12 100644 --- a/python/pydantic_core/_pydantic_core.pyi +++ b/python/pydantic_core/_pydantic_core.pyi @@ -5,7 +5,7 @@ import sys from typing import Any, Callable, Generic, Optional, Type, TypeVar from pydantic_core import ErrorDetails, ErrorTypeInfo, InitErrorDetails, MultiHostHost -from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType +from pydantic_core.core_schema import CoreConfig, CoreSchema, ErrorType, SerSchema if sys.version_info < (3, 8): from typing_extensions import final @@ -46,6 +46,8 @@ __all__ = [ 'list_all_errors', 'TzInfo', 'validate_core_schema', + 'WalkCoreSchema', + 'WalkCoreSchemaFilterBuilder', ] __version__: str build_profile: str @@ -864,3 +866,27 @@ def validate_core_schema(schema: CoreSchema, *, strict: bool | None = None) -> C We may also remove this function altogether, do not rely on it being present if you are using pydantic-core directly. """ + +class _WalkCoreSchemaFilter(Generic[_T]): + pass + +@final +class WalkCoreSchemaFilterBuilder(Generic[_T]): + def __and__(self, other: WalkCoreSchemaFilterBuilder) -> WalkCoreSchemaFilterBuilder: ... + def __or__(self, other: WalkCoreSchemaFilterBuilder) -> WalkCoreSchemaFilterBuilder: ... + @staticmethod + def has_ref() -> WalkCoreSchemaFilterBuilder: ... + @staticmethod + def has_type(type: str) -> WalkCoreSchemaFilterBuilder: ... 
+ @staticmethod + def predicate(predicate: Callable[[_T], bool]) -> WalkCoreSchemaFilterBuilder: ... + def build(self, func: Callable[[_T, Callable[[_T], _T]], _T]) -> _WalkCoreSchemaFilter[_T]: ... + +@final +class WalkCoreSchema: + def __init__( + self, + visit_core_schema: _WalkCoreSchemaFilter[CoreSchema] | None = None, + visit_ser_schema: _WalkCoreSchemaFilter[SerSchema] | None = None, + ) -> None: ... + def walk(self, schema: CoreSchema) -> CoreSchema: ... diff --git a/src/lib.rs b/src/lib.rs index de4a6d9bd..9a1c1855a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,7 @@ mod serializers; mod tools; mod url; mod validators; +mod walk_core_schema; // required for benchmarks pub use self::input::TzInfo; @@ -111,5 +112,7 @@ fn _pydantic_core(py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(to_jsonable_python, m)?)?; m.add_function(wrap_pyfunction!(list_all_errors, m)?)?; m.add_function(wrap_pyfunction!(validate_core_schema, m)?)?; + m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/src/walk_core_schema.rs b/src/walk_core_schema.rs new file mode 100644 index 000000000..4ed32d3e7 --- /dev/null +++ b/src/walk_core_schema.rs @@ -0,0 +1,761 @@ +use pyo3::exceptions::PyTypeError; +use pyo3::intern; +use pyo3::prelude::*; +use pyo3::types::PyTuple; +use pyo3::types::{PyDict, PyList, PyString}; + +#[pyclass(subclass, module = "pydantic_core._pydantic_core")] +#[derive(Debug, Clone)] +pub struct WalkCoreSchema { + visit_core_schema: Option, + visit_ser_schema: Option, +} + +#[pymethods] +impl WalkCoreSchema { + #[new] + #[pyo3(signature = (visit_core_schema = None, visit_ser_schema = None))] + fn new(visit_core_schema: Option, visit_ser_schema: Option) -> Self { + WalkCoreSchema { + visit_core_schema, + visit_ser_schema, + } + } + + fn walk<'s>(&self, py: Python<'s>, schema: &'s PyDict) -> PyResult> { + match &self.visit_core_schema { + Some(visit_core_schema) => { + if visit_core_schema.matches(py, schema)? 
{ + let call_next = self + .clone() + .into_py(py) + .getattr(py, intern!(py, "_walk_core_schema"))? + .clone(); + visit_core_schema.call(py, schema.copy()?, call_next)?.extract(py) + } else { + self._walk_core_schema(py, schema) + } + } + None => self._walk_core_schema(py, schema), + } + } + + fn _walk_core_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + // TODO: can we remove this copy by keeping track of when we hit a filter + // (schemas con only get modified if they hit a filter) + let schema = schema.copy()?; + let schema_type: &str = schema.get_item("type")?.unwrap().extract()?; + match schema_type { + "any" => self._handle_any_schema(py, schema), + "none" => self._handle_none_schema(py, schema), + "bool" => self._handle_bool_schema(py, schema), + "int" => self._handle_int_schema(py, schema), + "float" => self._handle_float_schema(py, schema), + "decimal" => self._handle_decimal_schema(py, schema), + "str" => self._handle_string_schema(py, schema), + "bytes" => self._handle_bytes_schema(py, schema), + "date" => self._handle_date_schema(py, schema), + "time" => self._handle_time_schema(py, schema), + "datetime" => self._handle_datetime_schema(py, schema), + "timedelta" => self._handle_timedelta_schema(py, schema), + "literal" => self._handle_literal_schema(py, schema), + "is-instance" => self._handle_is_instance_schema(py, schema), + "is-subclass" => self._handle_is_subclass_schema(py, schema), + "callable" => self._handle_callable_schema(py, schema), + "list" => self._handle_list_schema(py, schema), + "tuple-positional" => self._handle_tuple_positional_schema(py, schema), + "tuple-variable" => self._handle_tuple_variable_schema(py, schema), + "set" => self._handle_set_schema(py, schema), + "frozenset" => self._handle_frozenset_schema(py, schema), + "generator" => self._handle_generator_schema(py, schema), + "dict" => self._handle_dict_schema(py, schema), + "function-after" => self._handle_after_validator_function_schema(py, schema), + 
"function-before" => self._handle_before_validator_function_schema(py, schema), + "function-wrap" => self._handle_wrap_validator_function_schema(py, schema), + "function-plain" => self._handle_plain_validator_function_schema(py, schema), + "default" => self._handle_with_default_schema(py, schema), + "nullable" => self._handle_nullable_schema(py, schema), + "union" => self._handle_union_schema(py, schema), + "tagged-union" => self._handle_tagged_union_schema(py, schema), + "chain" => self._handle_chain_schema(py, schema), + "lax-or-strict" => self._handle_lax_or_strict_schema(py, schema), + "json-or-python" => self._handle_json_or_python_schema(py, schema), + "typed-dict" => self._handle_typed_dict_schema(py, schema), + "model-fields" => self._handle_model_fields_schema(py, schema), + "model-field" => self._handle_model_field_schema(py, schema), + "model" => self._handle_model_schema(py, schema), + "dataclass-args" => self._handle_dataclass_args_schema(py, schema), + "dataclass" => self._handle_dataclass_schema(py, schema), + "arguments" => self._handle_arguments_schema(py, schema), + "call" => self._handle_call_schema(py, schema), + "custom-error" => self._handle_custom_error_schema(py, schema), + "json" => self._handle_json_schema(py, schema), + "url" => self._handle_url_schema(py, schema), + "multi-host-url" => self._handle_multi_host_url_schema(py, schema), + "definitions" => self._handle_definitions_schema(py, schema), + "definition-ref" => self._handle_definition_reference_schema(py, schema), + "uuid" => self._handle_uuid_schema(py, schema), + _ => Err(PyTypeError::new_err(format!("Unknown schema type: {schema_type}"))), + } + } + + fn _walk_ser_schema(&self, py: Python, ser_schema: &PyDict) -> PyResult> { + // TODO: can we remove this copy by keeping track of when we hit a filter + // (schemas con only get modified if they hit a filter) + let ser_schema = ser_schema.copy()?; + let schema_type: &str = ser_schema.get_item("type")?.unwrap().extract()?; + match 
schema_type { + "none" | "int" | "bool" | "float" | "str" | "bytes" | "bytearray" | "list" | "tuple" | "set" + | "frozenset" | "generator" | "dict" | "datetime" | "date" | "time" | "timedelta" | "url" + | "multi-host-url" | "json" | "uuid" => self._handle_simple_ser_schema(py, ser_schema), + "function-plain" => self._handle_plain_serializer_function_ser_schema(py, ser_schema), + "function-wrap" => self._handle_wrap_serializer_function_ser_schema(py, ser_schema), + "format" => self._handle_format_ser_schema(py, ser_schema), + "to-string" => self._handle_to_string_ser_schema(py, ser_schema), + "model" => self._handle_model_ser_schema(py, ser_schema), + _ => Err(PyTypeError::new_err(format!("Unknown ser schema type: {schema_type}"))), + } + } + + // Check if there is a "serialization" key and if so handle it + // and replace the result + fn _check_ser_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let serialization_key = intern!(py, "serialization"); + let ser_schema: Option<&PyDict> = invert(schema.get_item(serialization_key)?.map(pyo3::PyAny::extract))?; + if let Some(ser_schema) = ser_schema { + if let Some(visit_ser_schema) = &self.visit_ser_schema { + if visit_ser_schema.matches(py, ser_schema)? { + let call_next = self + .clone() + .into_py(py) + .getattr(py, intern!(py, "_walk_ser_schema"))? 
+ .clone(); + let new_ser_schema = visit_ser_schema.call(py, ser_schema, call_next)?; + schema.set_item(serialization_key, new_ser_schema)?; + } else { + let new_ser_schema = self._walk_ser_schema(py, ser_schema)?; + schema.set_item(serialization_key, new_ser_schema)?; + } + } else { + let new_ser_schema = self._walk_ser_schema(py, ser_schema)?; + schema.set_item(serialization_key, new_ser_schema)?; + } + } + Ok(schema.into()) + } + + fn _handle_any_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_none_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_bool_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_int_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_float_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_decimal_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_string_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_bytes_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_date_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_time_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_datetime_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_timedelta_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_literal_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } 
+ + fn _handle_is_instance_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_is_subclass_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_callable_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_list_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self._handle_items_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_tuple_positional_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let items_schema_key = intern!(py, "items_schema"); + let items_schema: Option<&PyList> = invert(schema.get_item(items_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(items_schema) = items_schema { + let new_items_schema = items_schema + .iter() + .map(|item_schema| self.walk(py, item_schema.extract()?)) + .collect::>>>()?; + schema.set_item(items_schema_key, new_items_schema)?; + } + let schema = self._handle_extras_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_tuple_variable_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self._handle_items_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_set_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self._handle_items_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_frozenset_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self._handle_items_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_generator_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self._handle_items_schema(py, schema)?.into_ref(py).extract()?; + self._check_ser_schema(py, schema) + } + + fn 
_handle_dict_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let keys_schema: Option<&PyDict> = invert(schema.get_item("keys_schema")?.map(pyo3::PyAny::extract))?; + if let Some(keys_schema) = keys_schema { + let new_keys_schema = self.walk(py, keys_schema)?; + schema.set_item("keys_schema", new_keys_schema)?; + } + let values_schema: Option<&PyDict> = invert(schema.get_item("values_schema")?.map(pyo3::PyAny::extract))?; + if let Some(values_schema) = values_schema { + let new_values_schema = self.walk(py, values_schema)?; + schema.set_item("values_schema", new_values_schema)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_after_validator_function_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_before_validator_function_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_wrap_validator_function_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_plain_validator_function_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_with_default_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_nullable_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? 
+ .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_union_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let choices_key = intern!(py, "choices"); + let choices: Option<&PyList> = invert(schema.get_item(choices_key)?.map(pyo3::PyAny::extract))?; + if let Some(choices) = choices { + let new_choices = choices + .iter() + .map(|choice| match choice.extract::<(&PyDict, &PyString)>() { + Ok(choice) => { + let (schema, tag) = choice; + let schema = self.walk(py, schema)?; + Ok(PyTuple::new(py, [schema.into_py(py), tag.into_py(py)]).into()) + } + Err(_) => Ok(self.walk(py, choice.extract()?)?.into()), + }) + .collect::>>()?; + schema.set_item(choices_key, new_choices)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_tagged_union_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let choices_key = intern!(py, "choices"); + let choices: Option<&PyDict> = invert(schema.get_item(choices_key)?.map(pyo3::PyAny::extract))?; + if let Some(choices) = choices { + let new_choices = choices.iter().map(|(k, v)| { + let new_v = self.walk(py, v.extract()?); + Ok((k, new_v?)) + }); + schema.set_item(choices_key, py_dict_from_iterator(py, new_choices)?)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_chain_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let steps_key = intern!(py, "steps"); + let steps: Option<&PyList> = invert(schema.get_item(steps_key)?.map(pyo3::PyAny::extract))?; + if let Some(steps) = steps { + let new_steps = steps + .iter() + .map(|step| self.walk(py, step.extract()?)) + .collect::>>>()?; + schema.set_item(steps_key, new_steps)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_lax_or_strict_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let lax_schema_key = intern!(py, "lax_schema"); + let lax_schema: Option<&PyDict> = invert(schema.get_item(lax_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(lax_schema) = lax_schema { + let 
new_lax_schema = self.walk(py, lax_schema)?; + schema.set_item(lax_schema_key, new_lax_schema)?; + } + let strict_schema_key = intern!(py, "strict_schema"); + let strict_schema: Option<&PyDict> = invert(schema.get_item(strict_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(strict_schema) = strict_schema { + let new_strict_schema = self.walk(py, strict_schema)?; + schema.set_item(strict_schema_key, new_strict_schema)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_json_or_python_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let json_schema_key = intern!(py, "json_schema"); + let json_schema: Option<&PyDict> = invert(schema.get_item(json_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(json_schema) = json_schema { + let new_json_schema = self.walk(py, json_schema)?; + schema.set_item(json_schema_key, new_json_schema)?; + } + let python_schema_key = intern!(py, "python_schema"); + let python_schema: Option<&PyDict> = invert(schema.get_item(python_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(python_schema) = python_schema { + let new_python_schema = self.walk(py, python_schema)?; + schema.set_item(python_schema_key, new_python_schema)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_typed_dict_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let fields_key = intern!(py, "fields"); + let fields: Option<&PyDict> = invert(schema.get_item(fields_key)?.map(pyo3::PyAny::extract))?; + if let Some(fields) = fields { + let new_fields = fields.iter().map(|(k, v)| { + let typed_dict_field: &PyDict = v.extract()?; + let schema: &PyDict = typed_dict_field + .get_item("schema") + .ok() + .flatten() + .ok_or_else(|| PyTypeError::new_err("Missing schema in TypedDictField"))? 
+ .extract()?; + let new_schema = self.walk(py, schema)?; + typed_dict_field.set_item("schema", new_schema)?; + Ok((k, v)) + }); + schema.set_item(fields_key, py_dict_from_iterator(py, new_fields)?)?; + } + let schema = self._handle_extras_schema(py, schema)?.into_ref(py).extract()?; + let schema = self + ._handle_computed_fields_schema(py, schema)? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_model_field_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_model_fields_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let fields_key = intern!(py, "fields"); + let fields: Option<&PyDict> = invert(schema.get_item(fields_key)?.map(pyo3::PyAny::extract))?; + if let Some(fields) = fields { + let new_fields = fields.iter().map(|(k, v)| { + let new_v = self.walk(py, v.extract()?)?; + Ok((k, new_v)) + }); + schema.set_item(fields_key, py_dict_from_iterator(py, new_fields)?)?; + } + let schema = self._handle_extras_schema(py, schema)?.into_ref(py).extract()?; + let schema = self + ._handle_computed_fields_schema(py, schema)? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_model_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? 
+ .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_dataclass_args_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let fields_key = intern!(py, "fields"); + let schema_key = intern!(py, "schema"); + let fields: Option<&PyList> = invert(schema.get_item(fields_key)?.map(pyo3::PyAny::extract))?; + if let Some(fields) = fields { + for v in fields { + let dataclass_field: &PyDict = v.extract()?; + let dataclass_field_schema: &PyDict = + invert(dataclass_field.get_item(schema_key)?.map(pyo3::PyAny::extract))?.unwrap(); + let new_dataclass_field_schema = self.walk(py, dataclass_field_schema)?; + dataclass_field.set_item(schema_key, new_dataclass_field_schema)?; + } + } + let schema = self + ._handle_computed_fields_schema(py, schema)? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_dataclass_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_arguments_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let arguments_schema_key = intern!(py, "arguments_schema"); + let arguments_schema: Option<&PyList> = + invert(schema.get_item(arguments_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(arguments_schema) = arguments_schema { + for argument_parameter in arguments_schema { + let argument_parameter: &PyDict = argument_parameter.extract()?; + let argument_schema: &PyDict = argument_parameter + .get_item("schema") + .ok() + .flatten() + .ok_or_else(|| PyTypeError::new_err("Missing schema in ArgumentParameter"))? 
+ .extract()?; + let new_argument_schema = self.walk(py, argument_schema)?; + argument_parameter.set_item("schema", new_argument_schema)?; + } + } + self._check_ser_schema(py, schema) + } + + fn _handle_call_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let arguments_schema_key = intern!(py, "arguments_schema"); + let return_schema_key = intern!(py, "return_schema"); + let arguments_schema: Option<&PyDict> = + invert(schema.get_item(arguments_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(arguments_schema) = arguments_schema { + let new_arguments_schema = self.walk(py, arguments_schema)?; + schema.set_item(arguments_schema_key, new_arguments_schema)?; + } + let return_schema: Option<&PyDict> = invert(schema.get_item(return_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(return_schema) = return_schema { + let new_return_schema = self.walk(py, return_schema)?; + schema.set_item(return_schema_key, new_return_schema)?; + } + self._check_ser_schema(py, schema) + } + + fn _handle_custom_error_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_json_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? 
+ .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_url_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_multi_host_url_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_definitions_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let definitions_key = intern!(py, "definitions"); + let definitions: Option<&PyList> = invert(schema.get_item(definitions_key)?.map(pyo3::PyAny::extract))?; + if let Some(definitions) = definitions { + let new_definitions = definitions + .iter() + .map(|definition| self.walk(py, definition.extract()?)) + .collect::>>>()?; + schema.set_item(definitions_key, new_definitions)?; + } + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? + .into_ref(py) + .extract()?; + self._check_ser_schema(py, schema) + } + + fn _handle_definition_reference_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_uuid_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._check_ser_schema(py, schema) + } + + fn _handle_simple_ser_schema(&self, _py: Python, schema: &PyDict) -> PyResult> { + Ok(schema.into()) + } + + fn _handle_plain_serializer_function_ser_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._handle_inner_schema(py, schema, intern!(py, "return_schema")) + } + + fn _handle_wrap_serializer_function_ser_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let schema = self + ._handle_inner_schema(py, schema, intern!(py, "schema"))? 
+ .into_ref(py) + .extract()?; + self._handle_inner_schema(py, schema, intern!(py, "return_schema")) + } + + fn _handle_format_ser_schema(&self, _py: Python, schema: &PyDict) -> PyResult> { + Ok(schema.into()) + } + + fn _handle_to_string_ser_schema(&self, _py: Python, schema: &PyDict) -> PyResult> { + Ok(schema.into()) + } + + fn _handle_model_ser_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + self._handle_inner_schema(py, schema, intern!(py, "schema")) + } + + // Handle a dict where there may be a `"schema": CoreSchema` key + fn _handle_inner_schema(&self, py: Python, schema: &PyDict, schema_key: &PyString) -> PyResult> { + let inner_schema: Option<&PyDict> = invert(schema.get_item(schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(inner_schema) = inner_schema { + let new_inner_schema = self.walk(py, inner_schema)?; + schema.set_item(schema_key, new_inner_schema)?; + } + Ok(schema.into()) + } + + // Handle a dict where there may be a `"items_schema": CoreSchema` key + fn _handle_items_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let items_schema_key = intern!(py, "items_schema"); + let items_schema: Option<&PyDict> = invert(schema.get_item(items_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(items_schema) = items_schema { + let new_items_schema = self.walk(py, items_schema)?; + schema.set_item(items_schema_key, new_items_schema)?; + } + Ok(schema.into()) + } + + fn _handle_computed_fields_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let computed_fields_key = intern!(py, "computed_fields"); + let computed_fields: Option<&PyList> = invert(schema.get_item(computed_fields_key)?.map(pyo3::PyAny::extract))?; + if let Some(computed_fields) = computed_fields { + let schema_key = intern!(py, "schema"); + let return_schema_key = intern!(py, "return_schema"); + let new_computed_fields = computed_fields + .iter() + .map(|computed_field| { + let computed_field_schema: &PyDict = computed_field.extract()?; + let 
computed_field_schema = self + ._handle_inner_schema(py, computed_field_schema, schema_key)? + .into_ref(py); + let computed_field_schema = + self._handle_inner_schema(py, computed_field_schema, return_schema_key)?; + Ok(computed_field_schema) + }) + .collect::>>>()?; + schema.set_item(computed_fields_key, new_computed_fields)?; + } + Ok(schema.into()) + } + + fn _handle_extras_schema(&self, py: Python, schema: &PyDict) -> PyResult> { + let extras_schema_key = intern!(py, "extras_schema"); + let extras_schema: Option<&PyDict> = invert(schema.get_item(extras_schema_key)?.map(pyo3::PyAny::extract))?; + if let Some(extras_schema) = extras_schema { + let new_extras_schema = self.walk(py, extras_schema)?; + schema.set_item(extras_schema_key, new_extras_schema)?; + } + Ok(schema.into()) + } +} + +fn invert(x: Option>) -> Result, E> { + x.map_or(Ok(None), |v| v.map(Some)) +} + +#[derive(Debug, Clone)] +enum Filter { + HasRef, + HasType { type_: String }, + Python { predicate: PyObject }, + And { left: Box, right: Box }, + Or { left: Box, right: Box }, +} + +impl Filter { + fn matches(&self, py: Python, schema: &PyDict) -> PyResult { + match self { + Filter::HasRef => { + let ref_ = schema.get_item("ref")?; + Ok(ref_.is_some()) + } + Filter::HasType { type_ } => { + if let Some(schema_type) = invert(schema.get_item("type")?.map(pyo3::PyAny::extract::<&str>))? { + Ok(schema_type == type_) + } else { + Ok(false) + } + } + Filter::Python { predicate } => { + let result: bool = predicate.call1(py, (schema,))?.extract(py)?; + Ok(result) + } + Filter::And { left, right } => Ok(left.matches(py, schema)? && right.matches(py, schema)?), + Filter::Or { left, right } => Ok(left.matches(py, schema)? 
|| right.matches(py, schema)?), + } + } +} + +#[derive(Debug, Clone)] +#[pyclass(module = "pydantic_core._pydantic_core")] +pub struct WalkCoreSchemaFilterBuilder { + filter: Filter, +} + +#[pymethods] +impl WalkCoreSchemaFilterBuilder { + #[staticmethod] + fn has_ref() -> Self { + WalkCoreSchemaFilterBuilder { filter: Filter::HasRef } + } + + #[staticmethod] + #[pyo3(text_signature = "(type)")] + fn has_type(type_: String) -> Self { + WalkCoreSchemaFilterBuilder { + filter: Filter::HasType { type_ }, + } + } + + #[staticmethod] + fn predicate(predicate: PyObject) -> Self { + WalkCoreSchemaFilterBuilder { + filter: Filter::Python { predicate }, + } + } + + fn __and__(&self, other: WalkCoreSchemaFilterBuilder) -> Self { + WalkCoreSchemaFilterBuilder { + filter: Filter::And { + left: Box::new(self.filter.clone()), + right: Box::new(other.filter.clone()), + }, + } + } + + fn __or__(&self, other: WalkCoreSchemaFilterBuilder) -> Self { + WalkCoreSchemaFilterBuilder { + filter: Filter::Or { + left: Box::new(self.filter.clone()), + right: Box::new(other.filter.clone()), + }, + } + } + + fn build(&self, func: PyObject) -> FilterCallable { + FilterCallable { + filter: self.filter.clone(), + func, + } + } +} + +#[derive(Debug, Clone)] +#[pyclass(module = "pydantic_core._pydantic_core")] +struct FilterCallable { + filter: Filter, + func: PyObject, +} + +impl FilterCallable { + fn matches(&self, py: Python, schema: &PyDict) -> PyResult { + self.filter.matches(py, schema) + } + + fn call(&self, py: Python, schema: &PyDict, call_next: PyObject) -> PyResult> { + self.func.call1(py, (schema, call_next))?.extract(py) + } +} + +fn py_dict_from_iterator( + py: Python, + iterator: impl IntoIterator>, +) -> PyResult> { + let dict = PyDict::new(py); + for item in iterator { + let (k, v) = item?; + dict.set_item(k, v)?; + } + Ok(dict.into()) +} diff --git a/tests/test_walk_core_schema.py b/tests/test_walk_core_schema.py new file mode 100644 index 000000000..5c2c8f96f --- /dev/null +++ 
"""Tests for ``WalkCoreSchema`` and ``WalkCoreSchemaFilterBuilder``.

Covers traversal order over core and serialization schemas, the built-in
filters (``has_ref``, ``has_type``, ``predicate``) and their ``&`` / ``|``
combinations, and handlers that replace schemas or skip ``call_next``.
"""
from __future__ import annotations

from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Generic, TypeVar, Union

from pydantic_core import WalkCoreSchema, WalkCoreSchemaFilterBuilder
from pydantic_core import core_schema as cs
from pydantic_core.core_schema import CoreSchema, SerSchema

CoreSchemaCallNext = Callable[[CoreSchema], CoreSchema]
SerSchemaCallNext = Callable[[SerSchema], SerSchema]

CallableF = TypeVar('CallableF', bound=Callable[..., Any])


class NamedFunction(Generic[CallableF]):
    """Callable wrapper with a readable repr and identity-based equality.

    Two wrappers are equal when they wrap the very same function object.
    ``TrackingHandler`` compares a schema against a deep copy of itself, and
    the schemas built below embed these wrappers, so equality must survive
    copying the wrapper.
    """

    def __init__(self, func: CallableF) -> None:
        self.func = func

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.func(*args, **kwargs)

    def __repr__(self) -> str:
        return f'NamedFunction({self.func.__name__})'

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, NamedFunction):
            # Let Python try the reflected comparison; it falls back to
            # identity (i.e. False) when the other side has no opinion either.
            return NotImplemented
        return self.func is other.func  # type: ignore

    def __hash__(self) -> int:
        # Defining __eq__ alone sets __hash__ to None. Equality is identity of
        # the wrapped function, so hash on that same identity.
        return hash(id(self.func))


class SimpleRepr(type):
    """Metaclass whose classes repr as their bare ``__name__``."""

    def __repr__(cls) -> str:
        return cls.__name__


class NamedClass(metaclass=SimpleRepr):
    """Placeholder class passed as ``cls`` to the schemas under test."""


def _plain_ser_func(x: Any) -> str:
    """Plain serializer stub; always returns ``'abc'``."""
    return 'abc'


plain_ser_func = NamedFunction(_plain_ser_func)


def _wrap_ser_func(x: Any, handler: cs.SerializerFunctionWrapHandler) -> Any:
    """Wrap serializer stub; delegates straight to ``handler``."""
    return handler(x)


wrap_ser_func = NamedFunction(_wrap_ser_func)


def _no_info_val_func(x: Any) -> Any:
    """Validator stub; returns its input unchanged."""
    return x


no_info_val_func = NamedFunction(_no_info_val_func)


def _no_info_wrap_val_func(x: Any, handler: cs.ValidatorFunctionWrapHandler) -> Any:
    """Wrap validator stub; delegates straight to ``handler``."""
    return handler(x)


no_info_wrap_val_func = NamedFunction(_no_info_wrap_val_func)


SchemaT = TypeVar('SchemaT', bound=Union[CoreSchema, SerSchema])


@dataclass
class TrackingHandler:
    """Visitor that records the path of schema types it is called with.

    Each call pushes ``schema['type']`` on ``stack``, appends the
    ``' -> '``-joined stack to ``called``, and then delegates to
    ``call_next``. It asserts that ``call_next`` returns a schema equal to a
    deep copy taken before the call.
    """

    # One entry per invocation: the ' -> '-joined stack at the time of the call.
    called: list[str] = field(default_factory=list)
    # Types of the schemas whose handler invocation is currently in progress.
    stack: list[str] = field(default_factory=list)

    def __call__(self, schema: SchemaT, call_next: Callable[[SchemaT], SchemaT]) -> SchemaT:
        self.stack.append(schema['type'])
        self.called.append(' -> '.join(self.stack))
        # Snapshot first so the comparison below is against the pre-walk state.
        old = deepcopy(schema)
        try:
            new = call_next(schema)
            assert new == old
            return new
        except Exception as e:
            # Diagnostic output (shown by pytest on failure) identifying which
            # schema type was being walked when the error surfaced.
            print(e)
            print(schema['type'])
            raise
        finally:
            self.stack.pop()


def test_walk_core_schema_before() -> None:
    """Walk a union of many schema kinds with match-everything filters.

    The expected list pins the order in which the handler is invoked for each
    core schema and serialization schema.
    """
    handler = TrackingHandler()

    def match_any_predicate(schema: CoreSchema | SerSchema) -> bool:
        return True

    walk = WalkCoreSchema(
        visit_core_schema=WalkCoreSchemaFilterBuilder.predicate(match_any_predicate).build(handler),
        visit_ser_schema=WalkCoreSchemaFilterBuilder.predicate(match_any_predicate).build(handler),
    )

    schema = cs.union_schema(
        [
            cs.any_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
            cs.none_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
            cs.bool_schema(serialization=cs.simple_ser_schema('bool')),
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            cs.float_schema(serialization=cs.simple_ser_schema('float')),
            cs.decimal_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
            cs.str_schema(serialization=cs.simple_ser_schema('str')),
            cs.bytes_schema(serialization=cs.simple_ser_schema('bytes')),
            cs.date_schema(serialization=cs.simple_ser_schema('date')),
            cs.time_schema(serialization=cs.simple_ser_schema('time')),
            cs.datetime_schema(serialization=cs.simple_ser_schema('datetime')),
            cs.timedelta_schema(serialization=cs.simple_ser_schema('timedelta')),
            cs.literal_schema(
                expected=[1, 2, 3],
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.is_instance_schema(
                cls=NamedClass,
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.is_subclass_schema(
                cls=NamedClass,
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.callable_schema(
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.list_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.tuple_positional_schema(
                [cs.int_schema(serialization=cs.simple_ser_schema('int'))],
                extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.tuple_variable_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.set_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.frozenset_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.generator_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.dict_schema(
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.no_info_after_validator_function(
                no_info_val_func,
                cs.int_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.no_info_before_validator_function(
                no_info_val_func,
                cs.int_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.no_info_wrap_validator_function(
                no_info_wrap_val_func,
                cs.int_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.no_info_plain_validator_function(
                no_info_val_func,
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.with_default_schema(
                cs.int_schema(),
                default=1,
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.nullable_schema(
                cs.int_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.union_schema(
                [
                    cs.int_schema(),
                    cs.str_schema(),
                ],
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.tagged_union_schema(
                {
                    'a': cs.int_schema(),
                    'b': cs.str_schema(),
                },
                'type',
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.chain_schema(
                [
                    cs.int_schema(),
                    cs.str_schema(),
                ],
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.lax_or_strict_schema(
                cs.int_schema(),
                cs.str_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.json_or_python_schema(
                cs.int_schema(),
                cs.str_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.typed_dict_schema(
                {'a': cs.typed_dict_field(cs.int_schema())},
                computed_fields=[
                    cs.computed_field(
                        'b',
                        cs.int_schema(),
                    )
                ],
                extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.model_schema(
                NamedClass,
                cs.model_fields_schema(
                    {'a': cs.model_field(cs.int_schema())},
                    extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
                    computed_fields=[
                        cs.computed_field(
                            'b',
                            cs.int_schema(),
                        )
                    ],
                ),
            ),
            cs.dataclass_schema(
                NamedClass,
                cs.dataclass_args_schema(
                    'Model',
                    [cs.dataclass_field('a', cs.int_schema())],
                    computed_fields=[
                        cs.computed_field(
                            'b',
                            cs.int_schema(),
                        )
                    ],
                ),
                ['a'],
            ),
            cs.call_schema(
                cs.arguments_schema(
                    [cs.arguments_parameter('x', cs.int_schema())],
                    serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
                ),
                no_info_val_func,
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.custom_error_schema(
                cs.int_schema(),
                custom_error_type='CustomError',
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.json_schema(
                cs.int_schema(),
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.url_schema(
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.multi_host_url_schema(
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.definitions_schema(
                cs.int_schema(),
                [
                    cs.int_schema(ref='#/definitions/int'),
                ],
            ),
            cs.definition_reference_schema(
                '#/definitions/int',
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.uuid_schema(
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
        ]
    )

    walk.walk(schema)

    # insert_assert(handler.called)
    assert handler.called == [
        'union',
        'union -> any',
        'union -> any -> function-plain',
        'union -> none',
        'union -> none -> function-plain',
        'union -> bool',
        'union -> bool -> bool',
        'union -> int',
        'union -> int -> int',
        'union -> float',
        'union -> float -> float',
        'union -> decimal',
        'union -> decimal -> function-plain',
        'union -> str',
        'union -> str -> str',
        'union -> bytes',
        'union -> bytes -> bytes',
        'union -> date',
        'union -> date -> date',
        'union -> time',
        'union -> time -> time',
        'union -> datetime',
        'union -> datetime -> datetime',
        'union -> timedelta',
        'union -> timedelta -> timedelta',
        'union -> literal',
        'union -> literal -> function-plain',
        'union -> is-instance',
        'union -> is-instance -> function-plain',
        'union -> is-subclass',
        'union -> is-subclass -> function-plain',
        'union -> callable',
        'union -> callable -> function-plain',
        'union -> list',
        'union -> list -> int',
        'union -> list -> int -> int',
        'union -> list -> function-plain',
        'union -> tuple-positional',
        'union -> tuple-positional -> int',
        'union -> tuple-positional -> int -> int',
        'union -> tuple-positional -> int',
        'union -> tuple-positional -> int -> int',
        'union -> tuple-positional -> function-plain',
        'union -> tuple-variable',
        'union -> tuple-variable -> int',
        'union -> tuple-variable -> int -> int',
        'union -> tuple-variable -> function-plain',
        'union -> set',
        'union -> set -> int',
        'union -> set -> int -> int',
        'union -> set -> function-plain',
        'union -> frozenset',
        'union -> frozenset -> int',
        'union -> frozenset -> int -> int',
        'union -> frozenset -> function-plain',
        'union -> generator',
        'union -> generator -> int',
        'union -> generator -> int -> int',
        'union -> generator -> function-plain',
        'union -> dict',
        'union -> dict -> int',
        'union -> dict -> int -> int',
        'union -> dict -> int',
        'union -> dict -> int -> int',
        'union -> dict -> function-plain',
        'union -> function-after',
        'union -> function-after -> int',
        'union -> function-after -> function-plain',
        'union -> function-before',
        'union -> function-before -> int',
        'union -> function-before -> function-plain',
        'union -> function-wrap',
        'union -> function-wrap -> int',
        'union -> function-wrap -> function-plain',
        'union -> function-plain',
        'union -> function-plain -> function-plain',
        'union -> default',
        'union -> default -> int',
        'union -> default -> function-plain',
        'union -> nullable',
        'union -> nullable -> int',
        'union -> nullable -> function-plain',
        'union -> union',
        'union -> union -> int',
        'union -> union -> str',
        'union -> union -> function-plain',
        'union -> tagged-union',
        'union -> tagged-union -> int',
        'union -> tagged-union -> str',
        'union -> tagged-union -> function-plain',
        'union -> chain',
        'union -> chain -> int',
        'union -> chain -> str',
        'union -> chain -> function-plain',
        'union -> lax-or-strict',
        'union -> lax-or-strict -> int',
        'union -> lax-or-strict -> str',
        'union -> lax-or-strict -> function-plain',
        'union -> json-or-python',
        'union -> json-or-python -> int',
        'union -> json-or-python -> str',
        'union -> json-or-python -> function-plain',
        'union -> typed-dict',
        'union -> typed-dict -> int',
        'union -> typed-dict -> int',
        'union -> typed-dict -> int -> int',
        'union -> typed-dict -> int',
        'union -> typed-dict -> function-plain',
        'union -> model',
        'union -> model -> model-fields',
        'union -> model -> model-fields -> model-field',
        'union -> model -> model-fields -> model-field -> int',
        'union -> model -> model-fields -> int',
        'union -> model -> model-fields -> int -> int',
        'union -> model -> model-fields -> int',
        'union -> dataclass',
        'union -> dataclass -> dataclass-args',
        'union -> dataclass -> dataclass-args -> int',
        'union -> dataclass -> dataclass-args -> int',
        'union -> call',
        'union -> call -> arguments',
        'union -> call -> arguments -> int',
        'union -> call -> arguments -> function-plain',
        'union -> call -> function-plain',
        'union -> custom-error',
        'union -> custom-error -> int',
        'union -> custom-error -> function-plain',
        'union -> json',
        'union -> json -> int',
        'union -> json -> function-plain',
        'union -> url',
        'union -> url -> function-plain',
        'union -> multi-host-url',
        'union -> multi-host-url -> function-plain',
        'union -> definitions',
        'union -> definitions -> int',
        'union -> definitions -> int',
        'union -> definition-ref',
        'union -> definition-ref -> function-plain',
        'union -> uuid',
        'union -> uuid -> function-plain',
    ]


def test_filter_has_ref() -> None:
    """``has_ref`` invokes the handler only for the three schemas given a ``ref``."""
    handler = TrackingHandler()

    walk = WalkCoreSchema(
        visit_core_schema=WalkCoreSchemaFilterBuilder.has_ref().build(handler),
        visit_ser_schema=WalkCoreSchemaFilterBuilder.has_ref().build(handler),
    )

    schema = cs.chain_schema(
        [
            cs.int_schema(ref='int'),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.str_schema(ref='str'),
                ),
            ),
            cs.list_schema(
                cs.float_schema(ref='float'),
            ),
        ]
    )

    walk.walk(schema)

    # insert_assert(handler.called)
    assert handler.called == ['int', 'str', 'float']


def test_filter_type() -> None:
    """``has_type`` selects by ``type``, independently for core and ser schemas."""
    handler = TrackingHandler()

    walk = WalkCoreSchema(
        visit_core_schema=WalkCoreSchemaFilterBuilder.has_type('float').build(handler),
        visit_ser_schema=WalkCoreSchemaFilterBuilder.has_type('function-wrap').build(handler),
    )

    schema = cs.chain_schema(
        [
            cs.int_schema(),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.str_schema(),
                ),
            ),
            cs.list_schema(
                cs.float_schema(),
            ),
        ]
    )

    walk.walk(schema)

    # insert_assert(handler.called)
    assert handler.called == ['function-wrap', 'float']


def test_filter_and() -> None:
    """``&`` requires both filters to match.

    Only the float schema that also has a ref, and only the wrap serializer
    whose inner schema type is ``'str'``, reach the handler.
    """
    handler = TrackingHandler()

    walk = WalkCoreSchema(
        visit_core_schema=(WalkCoreSchemaFilterBuilder.has_type('float') & WalkCoreSchemaFilterBuilder.has_ref()).build(
            handler
        ),
        visit_ser_schema=(
            WalkCoreSchemaFilterBuilder.has_type('function-wrap')
            & WalkCoreSchemaFilterBuilder.predicate(lambda s: s['schema']['type'] == 'str')
        ).build(handler),
    )

    schema = cs.chain_schema(
        [
            cs.int_schema(ref='int'),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.int_schema(),
                ),
            ),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.str_schema(),
                ),
            ),
            cs.str_schema(
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            cs.float_schema(),
            cs.list_schema(
                cs.float_schema(),
            ),
            cs.list_schema(
                cs.float_schema(ref='float'),
            ),
        ]
    )

    walk.walk(schema)

    # insert_assert(handler.called)
    assert handler.called == ['function-wrap', 'float']


def test_filter_or() -> None:
    """``|`` matches when either filter matches."""
    handler = TrackingHandler()

    walk = WalkCoreSchema(
        visit_core_schema=(
            WalkCoreSchemaFilterBuilder.has_type('float')
            | WalkCoreSchemaFilterBuilder.predicate(lambda s: s.get('ref') == 'int')
        ).build(handler),
        visit_ser_schema=(
            WalkCoreSchemaFilterBuilder.has_type('function-wrap')
            | WalkCoreSchemaFilterBuilder.predicate(lambda s: s.get('type', '') == 'str')
        ).build(handler),
    )

    schema = cs.chain_schema(
        [
            cs.int_schema(ref='int'),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.int_schema(),
                ),
            ),
            cs.str_schema(
                serialization=cs.wrap_serializer_function_ser_schema(
                    wrap_ser_func,
                    schema=cs.str_schema(),
                ),
            ),
            cs.str_schema(
                serialization=cs.simple_ser_schema('str'),
            ),
            cs.bool_schema(),
            cs.list_schema(
                cs.float_schema(),
            ),
            cs.list_schema(
                cs.float_schema(ref='float'),
            ),
        ]
    )

    walk.walk(schema)

    # insert_assert(handler.called)
    assert handler.called == [
        'int',
        'function-wrap',
        'function-wrap',
        'str',
        'float',
        'float',
    ]


def test_edit_core_schema() -> None:
    """Handlers may return a replacement schema instead of calling ``call_next``.

    Every ``int`` core schema becomes ``float_schema()`` and every ``int``
    serialization schema becomes ``simple_ser_schema('float')``.
    """

    def replace_with_float_schema(schema: CoreSchema, call_next: Callable[[CoreSchema], CoreSchema]) -> CoreSchema:
        return cs.float_schema()

    def replace_with_float_serializer(schema: SerSchema, call_next: Callable[[SerSchema], SerSchema]) -> SerSchema:
        return cs.simple_ser_schema('float')

    walk = WalkCoreSchema(
        visit_core_schema=WalkCoreSchemaFilterBuilder.has_type('int').build(replace_with_float_schema),
        visit_ser_schema=WalkCoreSchemaFilterBuilder.has_type('int').build(replace_with_float_serializer),
    )

    schema = cs.chain_schema(
        [
            cs.bool_schema(),
            cs.int_schema(),
            cs.bool_schema(
                serialization=cs.simple_ser_schema('int'),
            ),
            cs.int_schema(
                serialization=cs.simple_ser_schema('int'),
            ),
        ]
    )

    schema = walk.walk(schema)

    # insert_assert(schema)
    # The last step is replaced wholesale by the core-schema handler, so its
    # original 'serialization' entry does not appear in the result.
    assert schema == {
        'type': 'chain',
        'steps': [
            {'type': 'bool'},
            {'type': 'float'},
            {'type': 'bool', 'serialization': {'type': 'float'}},
            {'type': 'float'},
        ],
    }


def test_skip_call_next() -> None:
    """Returning without calling ``call_next`` leaves a schema's children unvisited.

    The ``int`` inside the list stays untouched because the list handler
    returns early, while the ``int`` inside the set is replaced by ``float``.
    """

    def return_if_list(schema: CoreSchema, call_next: CoreSchemaCallNext) -> CoreSchema:
        if schema['type'] == 'list':
            return schema
        if schema['type'] == 'int':
            return cs.float_schema()
        return call_next(schema)

    walk = WalkCoreSchema(
        visit_core_schema=(
            WalkCoreSchemaFilterBuilder.has_type('list')
            | WalkCoreSchemaFilterBuilder.has_type('set')
            | WalkCoreSchemaFilterBuilder.has_type('int')
        ).build(return_if_list),
    )

    schema = cs.chain_schema(
        [
            cs.list_schema(
                cs.int_schema(),
            ),
            cs.set_schema(
                cs.int_schema(),
            ),
            cs.float_schema(),
        ]
    )

    schema = walk.walk(schema)

    # insert_assert(schema)
    assert schema == {
        'type': 'chain',
        'steps': [
            {'type': 'list', 'items_schema': {'type': 'int'}},
            {'type': 'set', 'items_schema': {'type': 'float'}},
            {'type': 'float'},
        ],
    }