Add a WalkCoreSchema #1099

Open
wants to merge 8 commits into base: main

adriangb
Member

@adriangb adriangb commented Nov 29, 2023

@samuelcolvin @dmontagu we discussed this previously and I shot it down: porting what we have in pydantic to Rust as-is would not be much faster (aside from the speedup of calling CPython APIs from Rust). For every 2-3 key accesses done in Rust (which would be faster) we'd call into Python and back, so the absolute gain may not be very large and there's the FFI overhead to contend with.

I thought about it some more, and I think we can get a much larger speedup if we change the API we have in pydantic to this one. Essentially, instead of a single callback for all schemas, there is a separate callback for each schema type. This acts as a sort of "filter" that minimizes calls into Python. Out of the ~3 "walks" we do in pydantic this covers two:

However, this does not cover the case where we need to visit every schema:
https://github.com/pydantic/pydantic/blob/667cd3776ee40e06018d0b7ff477c6cd0199b098/pydantic/_internal/_core_utils.py#L449-L450
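To make the trade-off concrete, here is a minimal pure-Python sketch (not the Rust implementation in this PR) that walks plain dicts shaped like core schemas. The names `walk_single`, `walk_by_type`, and `_children` are hypothetical, and the `Schema` alias is only a stand-in for `CoreSchema`. It counts how often each style would hand a schema to a user callback, and shows why collecting every ref still needs a visit to every schema.

```python
from typing import Any, Callable, Dict, Iterator, List

Schema = Dict[str, Any]  # stand-in for CoreSchema; an assumption for this sketch


def _children(schema: Schema) -> Iterator[Schema]:
    """Yield nested schemas found directly in values or in lists of values."""
    for value in schema.values():
        candidates = value if isinstance(value, list) else [value]
        for item in candidates:
            if isinstance(item, dict) and 'type' in item:
                yield item


def walk_single(schema: Schema, callback: Callable[[Schema], None]) -> None:
    """Single-callback style: the callback runs for every schema."""
    callback(schema)
    for child in _children(schema):
        walk_single(child, callback)


def walk_by_type(schema: Schema, callbacks: Dict[str, Callable[[Schema], None]]) -> None:
    """Per-type style: only schemas whose type has a callback leave the walker."""
    callback = callbacks.get(schema['type'])
    if callback is not None:
        callback(schema)
    for child in _children(schema):
        walk_by_type(child, callbacks)


example: Schema = {
    'type': 'list',
    'items_schema': {
        'type': 'union',
        'choices': [{'type': 'int'}, {'type': 'str'}, {'type': 'int', 'ref': 'my-int'}],
    },
}

single_calls: List[Schema] = []
walk_single(example, single_calls.append)

typed_calls: List[Schema] = []
walk_by_type(example, {'union': typed_calls.append})

# Collecting refs needs every schema, which only the single-callback walk provides.
refs = [s['ref'] for s in single_calls if 'ref' in s]
print(len(single_calls), len(typed_calls), refs)
```

In this toy tree the single-callback walk invokes the callback five times, while the per-type walk invokes it once; the per-type walk never sees the schema carrying the ref unless a callback is registered for its type.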

For that last case I see a couple of options:

  • Add a visit_all_schemas callback that slows things down significantly but allows visiting all schemas (and hence collecting all refs).
  • Add a visit_schema_with_ref that gets called for any schema with a ref. This seems somewhat reasonable but it may be a bit too "specialized" of an implementation for our current use case. That is, it's a bandaid solution to a poor API.
  • Add a more powerful filter predicate system. For example, you could have Walk(visit=[if_schema_has_key("ref")(callback), if_schema_has_type("int")(callback), (if_schema_has_type("int") & if_schema_has_key("ref"))(callback)]). This might also get rid of the dozens of constructor arguments this implementation currently has.
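The third option could look roughly like the sketch below. It is a pure-Python illustration under stated assumptions, not the API of this PR: `Predicate`, `has_key`, `has_type`, and `walk` are hypothetical names, schemas are plain dicts, and visitors are passed as `(predicate, callback)` pairs rather than with the `predicate(callback)` call syntax shown in the bullet.

```python
from typing import Any, Callable, Dict, List, Tuple

Schema = Dict[str, Any]  # stand-in for CoreSchema; an assumption for this sketch


class Predicate:
    """Wraps a bool-returning function and supports `&` and `|`."""

    def __init__(self, func: Callable[[Schema], bool]) -> None:
        self.func = func

    def __call__(self, schema: Schema) -> bool:
        return self.func(schema)

    def __and__(self, other: 'Predicate') -> 'Predicate':
        return Predicate(lambda s: self(s) and other(s))

    def __or__(self, other: 'Predicate') -> 'Predicate':
        return Predicate(lambda s: self(s) or other(s))


def has_key(key: str) -> Predicate:
    return Predicate(lambda s: key in s)


def has_type(type_: str) -> Predicate:
    return Predicate(lambda s: s.get('type') == type_)


Visitor = Tuple[Predicate, Callable[[Schema], None]]


def walk(schema: Schema, visitors: List[Visitor]) -> None:
    """Call each callback whose predicate matches, then recurse into children."""
    for predicate, callback in visitors:
        if predicate(schema):
            callback(schema)
    for value in schema.values():
        candidates = value if isinstance(value, list) else [value]
        for child in candidates:
            if isinstance(child, dict) and 'type' in child:
                walk(child, visitors)


example: Schema = {
    'type': 'union',
    'choices': [{'type': 'int'}, {'type': 'int', 'ref': 'a'}, {'type': 'str', 'ref': 'b'}],
}
with_ref: List[str] = []
int_with_ref: List[str] = []
walk(
    example,
    [
        (has_key('ref'), lambda s: with_ref.append(s['ref'])),
        (has_type('int') & has_key('ref'), lambda s: int_with_ref.append(s['ref'])),
    ],
)
print(with_ref, int_with_ref)
```

With this shape the walker takes one list of visitors instead of one constructor argument per schema type, and the "schema with a ref" case becomes just another predicate.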

@adriangb adriangb requested review from davidhewitt and samuelcolvin and removed request for davidhewitt November 29, 2023 21:50
@adriangb adriangb self-assigned this Nov 29, 2023

codecov bot commented Nov 29, 2023

Codecov Report

Merging #1099 (772c8c3) into main (7fa450d) will increase coverage by 0.12%.
The diff coverage is 93.04%.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1099      +/-   ##
==========================================
+ Coverage   89.70%   89.83%   +0.12%     
==========================================
  Files         106      107       +1     
  Lines       16364    16982     +618     
  Branches       35       35              
==========================================
+ Hits        14680    15255     +575     
- Misses       1677     1720      +43     
  Partials        7        7              
Files                               Coverage Δ
python/pydantic_core/__init__.py    92.59% <ø> (ø)
src/lib.rs                          87.50% <100.00%> (+0.35%) ⬆️
src/walk_core_schema.rs             93.01% <93.01%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data


codspeed-hq bot commented Nov 29, 2023

CodSpeed Performance Report

Merging #1099 will degrade performance by 18.81%

Comparing walk-core-schema (772c8c3) with main (7fa450d)

Summary

❌ 1 regressions
✅ 139 untouched benchmarks

⚠️ Please fix the performance issues or acknowledge them on CodSpeed.

Benchmarks breakdown

Benchmark               main      walk-core-schema   Change
test_core_future_str    31.5 µs   38.8 µs            -18.81%
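The reported change is consistent with expressing the branch's speed relative to main (main time divided by branch time, minus one). The quick check below uses only the two rounded timings from the table; that this is how the percentage is derived is an assumption, since the report does not state its formula.

```python
main_us = 31.5    # time on main, from the table above
branch_us = 38.8  # time on walk-core-schema, from the table above

# Speed of the branch relative to main (negative means slower).
change = main_us / branch_us - 1
# The same regression expressed as extra time per call.
slowdown = branch_us / main_us - 1

print(f'{change:.2%}', f'{slowdown:.2%}')
```

Read as extra time per call, the same regression is roughly 23%, so the two figures should not be compared directly.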

@davidhewitt
Contributor

I like the predicate filter idea. Hopefully it'll lead to a smaller implementation that's also easier to maintain if we introduce new schema types in the future.

You could even make the predicates user-suppliable from Python, which may ease adoption of the implementation (then move predicates to Rust as needed).

@adriangb
Member Author

adriangb commented Dec 1, 2023

That's an interesting idea. So we'd define something like:

VisitPredicate = Callable[[CoreSchema], bool]

That gets run from Rust, and we write any predicates we need in Python. For example:

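from dataclasses import dataclass
from typing import Callable

from pydantic_core import CoreSchema


class CombinablePredicate:
    def __or__(self, other: Callable[[CoreSchema], bool]) -> 'CombinedPredicate':
        # Matches when either predicate matches.
        return CombinedPredicate(lambda s: self(s) or other(s))


@dataclass
class CombinedPredicate(CombinablePredicate):
    # Inherits from CombinablePredicate so combined predicates can be chained.
    call: Callable[[CoreSchema], bool]

    def __call__(self, schema: CoreSchema) -> bool:
        return self.call(schema)


class HasRef(CombinablePredicate):
    def __call__(self, schema: CoreSchema) -> bool:
        return bool(schema.get('ref', False))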

And once those are stabilized we can move them to Rust. Is that what you had in mind?

@adriangb
Member Author

adriangb commented Dec 1, 2023

@davidhewitt I implemented the filter API as discussed above.

@adriangb adriangb marked this pull request as ready for review December 1, 2023 20:15
@adriangb adriangb changed the title from "Add a WalkCoreSchema implementation" to "Add a WalkCoreSchema" Dec 1, 2023
@adriangb
Member Author

adriangb commented Dec 3, 2023

@davidhewitt I benchmarked this and it comes out no faster than our existing Python version, even with no filter (so it never calls into Python). The Python version does the traversal in Python and also calls a Python function at every level.

import timeit
from typing import Any, Callable

from pydantic._internal._core_utils import walk_core_schema

from pydantic_core import CoreSchema, WalkCoreSchema
from pydantic_core import core_schema as cs


def plain_ser_func(x: Any) -> str:
    return 'abc'


def wrap_ser_func(x: Any, handler: cs.SerializerFunctionWrapHandler) -> Any:
    return handler(x)


def no_info_val_func(x: Any) -> Any:
    return x


def no_info_wrap_val_func(x: Any, handler: cs.ValidatorFunctionWrapHandler) -> Any:
    return handler(x)


class NamedClass:
    pass


schema = cs.union_schema(
    [
        cs.any_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.none_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.bool_schema(serialization=cs.simple_ser_schema('bool')),
        cs.int_schema(serialization=cs.simple_ser_schema('int')),
        cs.float_schema(serialization=cs.simple_ser_schema('float')),
        cs.decimal_schema(serialization=cs.plain_serializer_function_ser_schema(plain_ser_func)),
        cs.str_schema(serialization=cs.simple_ser_schema('str')),
        cs.bytes_schema(serialization=cs.simple_ser_schema('bytes')),
        cs.date_schema(serialization=cs.simple_ser_schema('date')),
        cs.time_schema(serialization=cs.simple_ser_schema('time')),
        cs.datetime_schema(serialization=cs.simple_ser_schema('datetime')),
        cs.timedelta_schema(serialization=cs.simple_ser_schema('timedelta')),
        cs.literal_schema(
            expected=[1, 2, 3],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.is_instance_schema(
            cls=NamedClass,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.is_subclass_schema(
            cls=NamedClass,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.callable_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.list_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tuple_positional_schema(
            [cs.int_schema(serialization=cs.simple_ser_schema('int'))],
            extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tuple_variable_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.set_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.frozenset_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.generator_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.dict_schema(
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_after_validator_function(
            no_info_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_before_validator_function(
            no_info_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_wrap_validator_function(
            no_info_wrap_val_func,
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.no_info_plain_validator_function(
            no_info_val_func,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.with_default_schema(
            cs.int_schema(),
            default=1,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.nullable_schema(
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.union_schema(
            [
                cs.int_schema(),
                cs.str_schema(),
            ],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.tagged_union_schema(
            {
                'a': cs.int_schema(),
                'b': cs.str_schema(),
            },
            'type',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.chain_schema(
            [
                cs.int_schema(),
                cs.str_schema(),
            ],
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.lax_or_strict_schema(
            cs.int_schema(),
            cs.str_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.json_or_python_schema(
            cs.int_schema(),
            cs.str_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.typed_dict_schema(
            {'a': cs.typed_dict_field(cs.int_schema())},
            computed_fields=[
                cs.computed_field(
                    'b',
                    cs.int_schema(),
                )
            ],
            extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.model_schema(
            NamedClass,
            cs.model_fields_schema(
                {'a': cs.model_field(cs.int_schema())},
                extras_schema=cs.int_schema(serialization=cs.simple_ser_schema('int')),
                computed_fields=[
                    cs.computed_field(
                        'b',
                        cs.int_schema(),
                    )
                ],
            ),
        ),
        cs.dataclass_schema(
            NamedClass,
            cs.dataclass_args_schema(
                'Model',
                [cs.dataclass_field('a', cs.int_schema())],
                computed_fields=[
                    cs.computed_field(
                        'b',
                        cs.int_schema(),
                    )
                ],
            ),
            ['a'],
        ),
        cs.call_schema(
            cs.arguments_schema(
                [cs.arguments_parameter('x', cs.int_schema())],
                serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
            ),
            no_info_val_func,
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.custom_error_schema(
            cs.int_schema(),
            custom_error_type='CustomError',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.json_schema(
            cs.int_schema(),
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.url_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.multi_host_url_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.definitions_schema(
            cs.int_schema(),
            [
                cs.int_schema(ref='#/definitions/int'),
            ],
        ),
        cs.definition_reference_schema(
            '#/definitions/int',
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
        cs.uuid_schema(
            serialization=cs.plain_serializer_function_ser_schema(plain_ser_func),
        ),
    ]
)


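def walk_core() -> None:
    # Rust walker with no filters, so it never calls back into Python.
    WalkCoreSchema().walk(schema)


Recurse = Callable[[cs.CoreSchema, 'Walk'], cs.CoreSchema]
Walk = Callable[[cs.CoreSchema, Recurse], cs.CoreSchema]


def visit_pydantic(schema: cs.CoreSchema, recurse: Recurse) -> CoreSchema:
    # No-op visitor: just keep recursing.
    return recurse(schema, visit_pydantic)


def walk_pydantic() -> None:
    # Existing pure-Python walker, calling visit_pydantic at every level.
    walk_core_schema(schema, visit_pydantic)


# Total seconds for 1000 walks: Rust walker first, then the Python walker.
print(timeit.timeit(walk_core, number=1000))
print(timeit.timeit(walk_pydantic, number=1000))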
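A single `timeit.timeit` total can be noisy, so a comparison like this one is often made more robust by repeating the measurement and taking the best run. The sketch below shows one way to do that. `best_per_call` and `compare` are hypothetical helpers, and `cheap` and `costly` are placeholder workloads standing in for `walk_core` and `walk_pydantic`, so the block runs without pydantic installed.

```python
import timeit
from typing import Callable, Dict


def best_per_call(func: Callable[[], object], number: int = 1000, repeat: int = 5) -> float:
    """Return the best observed seconds-per-call across `repeat` runs."""
    totals = timeit.repeat(func, number=number, repeat=repeat)
    return min(totals) / number


def compare(candidates: Dict[str, Callable[[], object]]) -> Dict[str, float]:
    """Time each candidate and express it relative to the fastest one."""
    per_call = {name: best_per_call(func) for name, func in candidates.items()}
    fastest = min(per_call.values())
    return {name: seconds / fastest for name, seconds in per_call.items()}


# Placeholder workloads; swap in walk_core and walk_pydantic to reuse this.
def cheap() -> int:
    return sum(range(10))


def costly() -> int:
    return sum(range(1000))


ratios = compare({'cheap': cheap, 'costly': costly})
for name, ratio in ratios.items():
    print(f'{name}: {ratio:.2f}x the fastest')
```

Taking the minimum rather than the mean follows the guidance in the `timeit` documentation: higher values usually reflect interference from other processes rather than the code under test.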

@samuelcolvin
Member

Let's revisit this once #1085 gets merged, which might improve performance significantly.
