Skip to content

Commit

Permalink
working pydantic generator
Browse files Browse the repository at this point in the history
  • Loading branch information
sneakers-the-rat committed Jul 3, 2024
1 parent 087064b commit 01cfb54
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 88 deletions.
3 changes: 2 additions & 1 deletion nwb_linkml/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ dependencies = [
"linkml-runtime>=1.7.7",
"nwb-schema-language>=0.1.3",
"rich>=13.5.2",
"linkml>=1.7.10",
#"linkml>=1.7.10",
"linkml @ git+https://github.com/sneakers-the-rat/linkml@arrays-numpydantic",
"nptyping>=2.5.0",
"pydantic>=2.3.0",
"h5py>=3.9.0",
Expand Down
153 changes: 66 additions & 87 deletions nwb_linkml/src/nwb_linkml/generators/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,26 @@
# ruff: noqa

import inspect
import pdb
import sys
import warnings
from copy import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type, Union

from jinja2 import Template
from linkml.generators import PydanticGenerator
from linkml.generators.pydanticgen.array import ArrayRepresentation
from linkml.generators.common.type_designators import (
get_type_designator_value,
)
from linkml.utils.ifabsent_functions import ifabsent_value_declaration
from linkml_runtime.linkml_model.meta import (
Annotation,
AnonymousSlotExpression,
ArrayExpression,
ClassDefinition,
ClassDefinitionName,
ElementName,
Expand All @@ -53,8 +57,9 @@
SlotDefinitionName,
)
from linkml_runtime.utils.compile_python import file_text
from linkml_runtime.utils.formatutils import camelcase, underscore
from linkml_runtime.utils.formatutils import camelcase, underscore, remove_empty_items
from linkml_runtime.utils.schemaview import SchemaView

from pydantic import BaseModel

from nwb_linkml.maps import flat_to_nptyping
Expand Down Expand Up @@ -268,6 +273,9 @@ class NWBPydanticGenerator(PydanticGenerator):
versions: dict = None
"""See :meth:`.LinkMLProvider.build` for usage - a list of specific versions to import from"""
pydantic_version = "2"
array_representations: List[ArrayRepresentation] = field(
default_factory=lambda: [ArrayRepresentation.NUMPYDANTIC])
black: bool = True

def _locate_imports(self, needed_classes: List[str], sv: SchemaView) -> Dict[str, List[str]]:
"""
Expand Down Expand Up @@ -427,7 +435,16 @@ def _check_anyof(
): # pragma: no cover
# Confirm that the slot's own range (ignoring the default range filled in by
# induced_slot) is not defined in addition to any_of
allowed_keys = ('array',)

if len(s.any_of) > 0 and sv.get_slot(sn).range is not None:
allowed = True
for option in s.any_of:
items = remove_empty_items(option)
if not all([key in allowed_keys for key in items.keys()]):
allowed=False
if allowed:
return
base_range_subsumes_any_of = False
base_range = sv.get_slot(sn).range
base_range_cls = sv.get_class(base_range, strict=False)
Expand All @@ -436,75 +453,6 @@ def _check_anyof(
if not base_range_subsumes_any_of:
raise ValueError("Slot cannot have both range and any_of defined")

def _make_npytyping_range(self, attrs: Dict[str, SlotDefinition]) -> str:
    """
    Build an ``NDArray[Shape["..."], dtype]`` annotation string from dimension attributes.

    Each attribute contributes one ``"<size> <name>"`` entry to the shape string, where
    ``<size>`` is the attribute's ``maximum_cardinality`` if it is truthy and ``*`` otherwise.
    The dtype is looked up in ``flat_to_nptyping`` using the range of the *first* attribute
    only (all dimensions are expected to share a dtype).

    If the first attribute's range is not a key of ``flat_to_nptyping``, a warning is emitted
    and a ``List[range] | range`` annotation is returned instead.
    """
    # Applied strictly in this order: the "__" collapse runs after "," and " " have become
    # underscores, but before "|" and "-" are substituted.
    substitutions = (
        (",", "_"),
        (" ", "_"),
        ("__", "_"),
        ("|", "_"),
        ("-", "_"),
        ("+", "plus"),
    )

    dims = []
    for dim in attrs.values():
        size = str(dim.maximum_cardinality) if dim.maximum_cardinality else "*"
        label = dim.name
        for old, new in substitutions:
            label = label.replace(old, new)
        dims.append(f"{size} {label}")

    try:
        dtype = flat_to_nptyping[list(attrs.values())[0].range]
    except KeyError as err:  # pragma: no cover
        warnings.warn(str(err))
        fallback = list(attrs.values())[0].range
        return f"List[{fallback}] | {fallback}"

    return 'NDArray[Shape["' + ", ".join(dims) + '"], ' + dtype + "]"

def _get_numpy_slot_range(self, cls: ClassDefinition) -> str:
    """
    Build the array annotation for a class whose attributes describe array dimensions.

    When every attribute is required there is exactly one possible shape, and the result of
    :meth:`_make_npytyping_range` for all attributes is returned directly.

    Otherwise a ``Union[...]`` of shapes is returned: first the required-only shape (if there
    are any required attributes), then one further shape per optional attribute, each adding
    the next optional dimension in declaration order. Not every permutation is generated;
    optional dimensions are expected to come last.
    """
    attributes = cls.attributes
    if all(attr.required for attr in attributes.values()):  # pragma: no cover
        return self._make_npytyping_range(attributes)

    # Partition into required and optional dimensions, preserving declaration order.
    required = {}
    optional = []
    for name, attr in attributes.items():
        if attr.required:
            required[name] = attr
        else:
            optional.append((name, attr))

    variants = []
    if required:
        variants.append(self._make_npytyping_range(required))

    # Grow the dimension set one optional attribute at a time.
    cumulative = dict(required)
    for name, attr in optional:
        cumulative[name] = attr
        variants.append(self._make_npytyping_range(dict(cumulative)))

    indent = " " * 8
    return "Union[\n" + indent + (",\n" + indent).join(variants) + "\n" + " " * 4 + "]"

def _get_linkml_classvar(self, cls: ClassDefinition) -> SlotDefinition:
"""A class variable that holds additional linkml attrs"""
Expand Down Expand Up @@ -566,17 +514,6 @@ def sort_classes(
self.sorted_class_names += [camelcase(c.name) for c in slist]
return slist

def get_class_slot_range(self, slot_range: str, inlined: bool, inlined_as_list: bool) -> str:
    """
    Monkeypatch that turns Array-typed slots and classes into npytyped hints.

    The class named by ``slot_range`` is looked up in the schema view. If its ``is_a`` is
    ``"Arraylike"``, the annotation comes from :meth:`_get_numpy_slot_range`; any other class
    is passed through to :meth:`_get_class_slot_range_origin` with the arguments unchanged.
    """
    target = self.schemaview.get_class(slot_range)
    if target.is_a != "Arraylike":
        return self._get_class_slot_range_origin(slot_range, inlined, inlined_as_list)
    return self._get_numpy_slot_range(target)

def _get_class_slot_range_origin(
self, slot_range: str, inlined: bool, inlined_as_list: bool
) -> str:
Expand Down Expand Up @@ -694,6 +631,36 @@ def get_predefined_slot_value(

return slot_value

def generate_python_range(
    self, slot_range, slot_def: SlotDefinition, class_def: ClassDefinition
) -> str:
    """
    Generate the python range (type annotation string) for a slot range value.

    ``slot_range`` is dispatched on its type:

    * ``ArrayExpression`` - the array is wrapped in a temporary slot and annotated around the
      python range of ``slot_def.range``.
    * ``AnonymousSlotExpression`` - the expression's own ``range`` is used, falling back to
      ``slot_def.range`` when it is ``None``; if the expression carries an ``array``, the
      result is wrapped in the array annotation.
    * anything else - if ``slot_def.array`` is set, the array annotation for the slot is
      returned, otherwise the call is delegated unchanged to the parent generator.

    Only the first array representation returned by the parent generator is used.

    Raises:
        TypeError: if ``slot_range`` is a plain ``dict``, which none of the branches above
            can convert into an annotation.
    """
    if isinstance(slot_range, ArrayExpression):
        temp_slot = SlotDefinition(name="array", array=slot_range)
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(temp_slot, inner_range)
        return results[0].annotation

    if isinstance(slot_range, AnonymousSlotExpression):
        base_range = slot_def.range if slot_range.range is None else slot_range.range
        inner_range = super().generate_python_range(base_range, slot_def, class_def)
        if slot_range.array is not None:
            temp_slot = SlotDefinition(name="array", array=slot_range.array)
            results = super().get_array_representations_range(temp_slot, inner_range)
            inner_range = results[0].annotation
        return inner_range

    if isinstance(slot_range, dict):
        # This branch previously dropped into pdb and then implicitly returned None.
        # Fail loudly instead so that non-interactive builds do not stall on a debugger prompt.
        raise TypeError(
            f"Cannot generate a python range for slot {slot_def.name!r} from a dict; "
            "expected a range name, ArrayExpression, or AnonymousSlotExpression"
        )

    if slot_def.array is not None:
        inner_range = super().generate_python_range(slot_def.range, slot_def, class_def)
        results = super().get_array_representations_range(slot_def, inner_range)
        return results[0].annotation

    return super().generate_python_range(slot_range, slot_def, class_def)

def serialize(self) -> str:
predefined_slot_values = {}
"""splitting up parent class :meth:`.get_predefined_slot_values`"""
Expand Down Expand Up @@ -763,15 +730,22 @@ def serialize(self) -> str:
s.description = s.description.replace('"', '\\"')
class_def.attributes[s.name] = s

slot_ranges: List[str] = []
slot_ranges: List[Union[str, ArrayExpression, AnonymousSlotExpression]] = []

self._check_anyof(s, sn, sv)

if s.any_of is not None and len(s.any_of) > 0:
# list comprehension here is pulling ranges from within AnonymousSlotExpression
slot_ranges.extend([r.range for r in s.any_of])
if isinstance(s.any_of, dict):
any_ofs = list(s.any_of.values())
else:
any_ofs = s.any_of
slot_ranges.extend(any_ofs)
else:
slot_ranges.append(s.range)
if s.array is not None:
slot_ranges.append(s.array)
else:
slot_ranges.append(s.range)

pyranges = [
self.generate_python_range(slot_range, s, class_def)
Expand All @@ -798,7 +772,12 @@ def serialize(self) -> str:

if s.multivalued:
if s.inlined or s.inlined_as_list:
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
try:
collection_key = self.generate_collection_key(slot_ranges, s, class_def)
except TypeError:
# TypeError is raised because an AnonymousSlotExpression cannot be hashed.
# HACK: this can be fixed by merging the upstream pydantic generator cleanup
collection_key = None
else: # pragma: no cover
collection_key = None
if (
Expand Down

0 comments on commit 01cfb54

Please sign in to comment.