roll down parent inheritance recursively #14

Merged: 18 commits, Oct 1, 2024
Changes from 1 commit
getting there, working rolldown of extra attributes, but something still funny in patchclampseries children w.r.t. losing attributes in data
sneakers-the-rat committed Sep 14, 2024
commit cad57554fd04095fa49ffecb1ca7d122258d7891
260 changes: 183 additions & 77 deletions nwb_linkml/src/nwb_linkml/adapters/namespaces.py
@@ -8,7 +8,6 @@
 import contextlib
 from copy import copy
 from pathlib import Path
-from pprint import pformat
 from typing import Dict, Generator, List, Optional
 
 from linkml_runtime.dumpers import yaml_dumper
@@ -19,7 +18,6 @@
 from nwb_linkml.adapters.schema import SchemaAdapter
 from nwb_linkml.lang_elements import NwbLangSchema
 from nwb_linkml.ui import AdapterProgress
-from nwb_linkml.util import merge_dicts
 from nwb_schema_language import Dataset, Group, Namespaces


@@ -188,93 +186,105 @@
             if not cls.neurodata_type_inc:
                 continue
 
-            # get parents
-            parent = self.get(cls.neurodata_type_inc)
-            parents = [parent]
-            while parent.neurodata_type_inc:
-                parent = self.get(parent.neurodata_type_inc)
-                parents.insert(0, parent)
-            parents.append(cls)
+            parents = self._get_class_ancestors(cls, include_child=True)
 
             # merge and cast
             # note that we don't want to exclude_none in the model dump here,
             # if the child class has a field completely unset, we want to inherit it
             # from the parent without rolling it down - we are only rolling down
             # the things that need to be modified/merged in the child
             new_cls: dict = {}
-            for parent in parents:
-                new_cls = merge_dicts(
-                    new_cls,
-                    parent.model_dump(exclude_unset=True),
-                    list_key="name",
-                    exclude=["neurodata_type_def"],
-                )
+            for i, parent in enumerate(parents):
+                # if parent.neurodata_type_def == "PatchClampSeries":
+                #     pdb.set_trace()
+                complete = True
+                if i == len(parents) - 1:
+                    complete = False
+                new_cls = roll_down_nwb_class(new_cls, parent, complete=complete)
             new_cls: Group | Dataset = type(cls)(**new_cls)
             new_cls.parent = cls.parent
 
-            # reinsert
-            if new_cls.parent:
-                if isinstance(cls, Dataset):
-                    new_cls.parent.datasets[new_cls.parent.datasets.index(cls)] = new_cls
-                else:
-                    new_cls.parent.groups[new_cls.parent.groups.index(cls)] = new_cls
+            self._overwrite_class(new_cls, cls)
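The loop above folds the whole ancestry chain into a single dict; a minimal sketch of the same control flow, with the chain contents assumed for illustration:

```python
# Hypothetical sketch, not part of the diff. With include_child=True the
# chain ends with the class itself, e.g.:
# parents == [TimeSeries, PatchClampSeries, VoltageClampSeries]
new_cls = {}
for i, parent in enumerate(parents):
    # Ancestors are collapsed wholesale (complete=True); the final pass over
    # the child itself uses complete=False so only its overrides roll down.
    new_cls = roll_down_nwb_class(new_cls, parent, complete=(i != len(parents) - 1))
```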

+    def _get_class_ancestors(
+        self, cls: Dataset | Group, include_child: bool = True
+    ) -> list[Dataset | Group]:
+        """
+        Get the chain of ancestor classes inherited via ``neurodata_type_inc``
+
+        Args:
+            cls (:class:`.Dataset` | :class:`.Group`): The class to get ancestors of
+            include_child (bool): If ``True`` (default), include ``cls`` in the output list
+        """
+        parent = self.get(cls.neurodata_type_inc)
+        parents = [parent]
+        while parent.neurodata_type_inc:
+            parent = self.get(parent.neurodata_type_inc)
+            parents.insert(0, parent)
+
+        if include_child:
+            parents.append(cls)
+
+        return parents
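To illustrate the returned ordering, a minimal sketch assuming an `adapter` built over the NWB core schema (the type names come from NWB, not from this diff):

```python
# Hypothetical sketch, not part of the diff. In the NWB core schema,
# VoltageClampSeries inherits from PatchClampSeries, which inherits
# from TimeSeries.
vcs = adapter.get("VoltageClampSeries")
chain = adapter._get_class_ancestors(vcs, include_child=True)
print([c.neurodata_type_def for c in chain])
# ['TimeSeries', 'PatchClampSeries', 'VoltageClampSeries'] - oldest first, child last
```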

+    def _overwrite_class(self, new_cls: Dataset | Group, old_cls: Dataset | Group):
+        """
+        Overwrite the version of a dataset or group that is stored in our schemas
+        """
+        if old_cls.parent:
+            if isinstance(old_cls, Dataset):
+                new_cls.parent.datasets[new_cls.parent.datasets.index(old_cls)] = new_cls
+            else:
+                new_cls.parent.groups[new_cls.parent.groups.index(old_cls)] = new_cls
+        else:
+            # top level class, need to go and find it
+            schema = self.find_type_source(old_cls)
+            if isinstance(new_cls, Dataset):
+                schema.datasets[schema.datasets.index(old_cls)] = new_cls
+            else:
+                schema.groups[schema.groups.index(old_cls)] = new_cls
-            else:
-                # top level class, need to go and find it
-                found = False
-                for schema in self.all_schemas():
-                    if isinstance(cls, Dataset):
-                        if cls in schema.datasets:
-                            schema.datasets[schema.datasets.index(cls)] = new_cls
-                            found = True
-                            break
-                    else:
-                        if cls in schema.groups:
-                            schema.groups[schema.groups.index(cls)] = new_cls
-                            found = True
-                            break
-                if not found:
-                    raise KeyError(
-                        f"Unable to find source schema for {cls} when reinserting after rolling"
-                        " down!"
-                    )
-
-    def find_type_source(self, name: str) -> SchemaAdapter:
+
+    def find_type_source(self, cls: str | Dataset | Group, fast: bool = False) -> SchemaAdapter:
"""
Given some neurodata_type_inc, find the schema that it's defined in.
Given some type (as `neurodata_type_def`), find the schema that it's defined in.

Rather than returning as soon as a match is found, ensure that duplicates are
not found within the primary schema, then so the same for all imported schemas.

Args:
cls (str | :class:`.Dataset` | :class:`.Group`): The ``neurodata_type_def``
to look for the source of. If a Dataset or Group, look for the object itself
(cls in schema.datasets), otherwise look for a class with a matching name.
fast (bool): If ``True``, return as soon as a match is found.
If ``False`, return after checking all schemas for duplicates.

Returns:
:class:`.SchemaAdapter`

Rather than returning as soon as a match is found, check all
Raises:
KeyError: if multiple schemas or no schemas are found
"""
# First check within the main schema
internal_matches = []
for schema in self.schemas:
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
if name in class_names:
internal_matches.append(schema)

if len(internal_matches) > 1:
raise KeyError(
f"Found multiple schemas in namespace that define {name}:\ninternal:"
f" {pformat(internal_matches)}\nimported:{pformat(internal_matches)}"
)
elif len(internal_matches) == 1:
return internal_matches[0]

import_matches = []
for imported_ns in self.imported:
for schema in imported_ns.schemas:
class_names = [cls.neurodata_type_def for cls in schema.created_classes]
if name in class_names:
import_matches.append(schema)

if len(import_matches) > 1:
raise KeyError(
f"Found multiple schemas in namespace that define {name}:\ninternal:"
f" {pformat(internal_matches)}\nimported:{pformat(import_matches)}"
)
elif len(import_matches) == 1:
return import_matches[0]
matches = []
for schema in self.all_schemas():
in_schema = False
if isinstance(cls, str) and cls in [
c.neurodata_type_def for c in schema.created_classes
]:
in_schema = True
elif isinstance(cls, Dataset) and cls in schema.datasets:
in_schema = True
elif isinstance(cls, Group) and cls in schema.groups:
in_schema = True

if in_schema:
if fast:
return schema
else:
matches.append(schema)

if len(matches) > 1:
raise KeyError(f"Found multiple schemas in namespace that define {cls}:\n{matches}")
elif len(matches) == 1:
return matches[0]
else:
raise KeyError(f"No schema found that define {name}")
raise KeyError(f"No schema found that define {cls}")

     def _populate_imports(self) -> "NamespacesAdapter":
         """
@@ -378,3 +388,99 @@
         for imported in self.imported:
             for sch in imported.schemas:
                 yield sch


+def roll_down_nwb_class(
+    source: Group | Dataset | dict, target: Group | Dataset | dict, complete: bool = False
+) -> dict:
+    """
+    Merge an ancestor (via ``neurodata_type_inc``) source class with a
+    child ``target`` class.
+
+    On the first recursive pass, only those values that are set on the target are copied from the
+    source class - this isn't a true merging: what we are after is to recursively merge all the
+    values that are modified in the child class with those of the parent class below the top
+    level; the top-level attributes will be carried through via normal inheritance.
+
+    Rather than re-instantiating the child class, we return the dictionary so that this
+    function can be used in series to merge a whole ancestry chain within
+    :class:`.NamespacesAdapter`, but this isn't exposed in the function since
+    class definitions can be spread out over many schemas, and we need the orchestration
+    of the adapter to have them in all cases we'd be using this.
+
+    Args:
+        source (dict): source dictionary
+        target (dict): target dictionary (values merged over source)
+        complete (bool): (default ``False``) do a complete merge, merging everything
+            from source to target without trying to minimize redundancy.
+            Used to collapse ancestor classes before the terminal class.
+
+    References:
+        https://github.com/NeurodataWithoutBorders/pynwb/issues/1954
+
+    """
+    if isinstance(source, (Group, Dataset)):
+        source = source.model_dump(exclude_unset=True, exclude_none=True)
+    if isinstance(target, (Group, Dataset)):
+        target = target.model_dump(exclude_unset=True, exclude_none=True)
+
+    exclude = ("neurodata_type_def",)
+
+    # if we are on the first recursion, we exclude top-level items that are not set in the target
+    if complete:
+        ret = {k: v for k, v in source.items() if k not in exclude}
+    else:
+        ret = {k: v for k, v in source.items() if k not in exclude and k in target}
+
+    for key, value in target.items():
+        if key not in ret:
+            ret[key] = value
+        elif isinstance(value, dict):
+            if key in ret:
+                ret[key] = roll_down_nwb_class(ret[key], value, complete=True)
+            else:
+                ret[key] = value
+        elif isinstance(value, list) and all([isinstance(v, dict) for v in value]):
+            src_keys = {v["name"]: ret[key].index(v) for v in ret.get(key, {}) if "name" in v}
+            target_keys = {v["name"]: value.index(v) for v in value if "name" in v}
+
+            new_val = []
+            # screwy double iteration to preserve dict order
+            # all dicts not in target, if in depth > 0
+            if complete:
+                new_val.extend(
+                    [
+                        ret[key][src_keys[k]]
+                        for k in src_keys
+                        if k in set(src_keys.keys()) - set(target_keys.keys())
+                    ]
+                )
+            # all dicts not in source
+            new_val.extend(
+                [
+                    value[target_keys[k]]
+                    for k in target_keys
+                    if k in set(target_keys.keys()) - set(src_keys.keys())
+                ]
+            )
+            # merge dicts in both
+            new_val.extend(
+                [
+                    roll_down_nwb_class(ret[key][src_keys[k]], value[target_keys[k]], complete=True)
+                    for k in target_keys
+                    if k in set(src_keys.keys()).intersection(set(target_keys.keys()))
+                ]
+            )
+            new_val = sorted(new_val, key=lambda i: i["name"])
+            # add any dicts that don't have the list_key
+            # they can't be merged since they can't be matched
+            if complete:
+                new_val.extend([v for v in ret.get(key, {}) if "name" not in v])
+                new_val.extend([v for v in value if "name" not in v])
+
+            ret[key] = new_val
+
+        else:
+            ret[key] = value
+
+    return ret
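To make the merge semantics concrete, a small worked example with invented dicts (not from this diff): on a non-complete pass, only keys that the child sets are rolled down, and list items are matched by `"name"` and merged recursively:

```python
# Hypothetical illustration, not part of the diff.
parent = {
    "neurodata_type_def": "TimeSeries",
    "doc": "base doc",  # unset on the child: inherited normally, not rolled down
    "attributes": [{"name": "unit", "doc": "unit of measurement"}],
}
child = {
    "neurodata_type_def": "PatchClampSeries",
    "attributes": [{"name": "unit", "value": "volts"}],
}

merged = roll_down_nwb_class(parent, child, complete=False)
# {
#     "attributes": [{"name": "unit", "doc": "unit of measurement", "value": "volts"}],
#     "neurodata_type_def": "PatchClampSeries",
# }
```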
15 changes: 9 additions & 6 deletions nwb_linkml/src/nwb_linkml/generators/pydantic.py
@@ -136,7 +136,7 @@ def after_generate_class(self, cls: ClassResult, sv: SchemaView) -> ClassResult:
"""Customize dynamictable behavior"""
cls = AfterGenerateClass.inject_dynamictable(cls)
cls = AfterGenerateClass.wrap_dynamictable_columns(cls, sv)
cls = AfterGenerateClass.inject_elementidentifiers(cls, sv, self._get_element_import)
cls = AfterGenerateClass.inject_dynamictable_imports(cls, sv, self._get_element_import)
cls = AfterGenerateClass.strip_vector_data_slots(cls, sv)
return cls

@@ -346,19 +346,22 @@ def wrap_dynamictable_columns(cls: ClassResult, sv: SchemaView) -> ClassResult:
         return cls
 
     @staticmethod
-    def inject_elementidentifiers(
+    def inject_dynamictable_imports(
         cls: ClassResult, sv: SchemaView, import_method: Callable[[str], Import]
     ) -> ClassResult:
         """
-        Inject ElementIdentifiers into module that define dynamictables -
-        needed to handle ID columns
+        Ensure that schema that contain dynamictables have all the imports needed to use them
         """
         if (
             cls.source.is_a == "DynamicTable"
             or "DynamicTable" in sv.class_ancestors(cls.source.name)
         ) and sv.schema.name != "hdmf-common.table":
-            imp = import_method("ElementIdentifiers")
-            cls.imports += [imp]
+            imp = [
+                import_method("ElementIdentifiers"),
+                import_method("VectorData"),
+                import_method("VectorIndex"),
+            ]
+            cls.imports += imp
         return cls
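The practical effect, sketched as the import block of a generated module for a DynamicTable subclass (the module path here is an assumption based on hdmf-common naming, not taken from this diff):

```python
# Hypothetical sketch of a generated module's imports, not part of the diff.
from ...hdmf_common.v1_8_0.hdmf_common_table import (
    DynamicTable,
    ElementIdentifiers,  # backs the id column
    VectorData,          # column values
    VectorIndex,         # index for ragged columns
)
```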

     @staticmethod