Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix JSON schema generator to work with pydantic v2 #252

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions extra/gen_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,43 @@

os.environ["BUMPS_USE_PYDANTIC"] = "True"

from pydantic.schema import get_model, schema as make_schema
from bumps.parameter import Expression, Parameter # , UnaryExpression
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import PydanticOmit, core_schema
from pydantic import BaseModel
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic import TypeAdapter


class CustomsGenerateJsonSchema(GenerateJsonSchema):
def handle_invalid_for_json_schema(self, schema: core_schema.CoreSchema, error_info: str) -> JsonSchemaValue:
print(f"Handling invalid schema: {error_info}")
raise PydanticOmit

def generate(self, schema: core_schema.CoreSchema, mode: str = "validation") -> JsonSchemaValue:
json_schema = super().generate(schema, mode=mode)
# Filter out properties starting with underscore
filtered_properties = {k: v for k, v in json_schema.get("properties", {}).items() if not k.startswith("_")}
json_schema["properties"] = filtered_properties
filtered_requires = [k for k in json_schema.get("required", []) if not k.startswith("_")]
json_schema["required"] = filtered_requires
return json_schema

from refl1d.names import FitProblem, Repeat, Stack
def handle_property(self, property_name: str, property_value: JsonSchemaValue):
if property_name.startswith("_"):
return JsonSchemaValue(None)
return super().handle_property(property_name, property_value)


from bumps.parameter import Expression, Parameter # , UnaryExpression

base_model = get_model(FitProblem)
from refl1d.bumps_interface.fitproblem import FitProblem

# resolve circular dependencies and self-references
# TODO: this will be unnecessary in python 3.7+ with
# 'from __future__ import annotations'
# and in python 4.0+ presumably that can be removed as well.
to_resolve = [
Expression,
Parameter, # UnaryExpression
Repeat,
Stack,
]
for module in to_resolve:
get_model(module).update_forward_refs()
ta = TypeAdapter(FitProblem)
ta.rebuild()

schema = {"$schema": "https://json-schema.org/draft-07/schema#", "$id": "refl1d-draft-01"}
# schema.update(get_model(FitProblem).schema())
schema.update(FitProblem.schema())
schema.update(ta.json_schema(schema_generator=CustomsGenerateJsonSchema))


def remove_default_typename(schema):
Expand Down Expand Up @@ -71,4 +85,5 @@ def convert_inf(obj):
remove_proptitles(schema)
schema = convert_inf(schema)

os.makedirs("schema", exist_ok=True)
open("schema/refl1d.schema.json", "w").write(json.dumps(schema, allow_nan=False, indent=2))
2 changes: 1 addition & 1 deletion refl1d/names.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
)
from .sample.cheby import ChebyVF, FreeformCheby, cheby_approx, cheby_points
from .sample.flayer import FunctionalMagnetism, FunctionalProfile
from .sample.layers import Slab, Stack
from .sample.layers import Slab, Stack, Repeat, Layer
from .sample.magnetism import FreeMagnetism, Magnetism, MagnetismStack, MagnetismTwist
from .sample.material import SLD, Compound, Material, Mixture

Expand Down
Loading