diff --git a/src/cript/nodes/main.py b/src/cript/nodes/main.py index db95ecc..f1d4912 100644 --- a/src/cript/nodes/main.py +++ b/src/cript/nodes/main.py @@ -7,6 +7,8 @@ from jsonschema.exceptions import best_match from uuid import uuid4 +import cript +import inspect from cript import Cript, NotFoundError, camel_case_to_snake_case, extract_node_from_result from .schema import cript_schema @@ -22,7 +24,7 @@ def __init__(self, *args, **kwargs): self.__dict__["exists"] = self.__dict__.get("exists", False) self.__dict__["client"] = self.__dict__.get("client", Cript()) self.__dict__["children"] = self.__dict__.get("children", {}) - self.__dict__["initialized"] = False + self.__dict__["initialized"] = kwargs.get("initialized", False) self.__dict__["parent"] = None schema = copy.deepcopy(cript_schema) schema["$ref"] = f"#/$defs/{self.__class__.__name__}Post" @@ -37,6 +39,14 @@ def __init__(self, *args, **kwargs): self.__dict__["validator_instance"] = cls(schema, resolver=resolver) else: self.__dict__["validator_instance"] = cls(schema) + + allowed_attributes = self._allowed_attributes(kwargs) + + # Early exit for initialized nodes + if self.initialized: + for key in allowed_attributes: + setattr(self, key, kwargs[key]) + return d = dict(*args, **kwargs) if self._retrieve_on_init or kwargs.get("uuid"): @@ -67,6 +77,34 @@ def __init__(self, *args, **kwargs): self.final_update() self.__dict__["initialized"] = True + + def _allowed_attributes(self, attributes): + allowed_data = {} + for key in attributes: + if key in self.__dict__["schema"]["$defs"][f"{self.__class__.__name__}Post"]["properties"]: + allowed_data[key] = attributes[key] + return allowed_data + + @staticmethod + def _from_dict(json_dict: dict): + node_name_list = json_dict.get("node", None) + if node_name_list is None or not isinstance(node_name_list, list) or len(node_name_list) != 1: + raise ValueError(f"Conversion of dictionary to CRIPT Node failed, since 'node' is {node_name_list} given.") + node_name_str = node_name_list[0] + + for key, pyclass in inspect.getmembers(cript.nodes, inspect.isclass): + if CriptNode in inspect.getmro(pyclass): + if key == node_name_str: + next_node = pyclass._cls_from_dict(json_dict) + return next_node + raise ValueError(f"Unknow node conversion attempt {json_dict}") + + @classmethod + def _cls_from_dict(cls, json_dict: dict): + json_dict["initialized"] = True + next_node = cls(**json_dict) + return next_node + @property def name_url(self): return camel_case_to_snake_case(self.__class__.__name__) @@ -316,11 +354,10 @@ def retrieve_by_uuid(self, uuid): result = self.__dict__["client"].nodes.retrieve(node=self.name_url, uuid=uuid) data = extract_node_from_result(result.data) self.__dict__["exists"] = True - allowed_data = {} - for key in data: - if key in self.__dict__["schema"]["$defs"][f"{self.__class__.__name__}Post"]["properties"]: - setattr(self, key, data[key]) - allowed_data[key] = data[key] + + allowed_data = self._allowed_attributes(data) + for key in allowed_data: + setattr(self, key, data[key]) self.__dict__["__original__"] = copy.deepcopy(allowed_data) except NotFoundError: self.__dict__["exists"] = False diff --git a/src/cript/resources/child.py b/src/cript/resources/child.py new file mode 100644 index 0000000..f9d01fc --- /dev/null +++ b/src/cript/resources/child.py @@ -0,0 +1,154 @@ +import httpx +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( + maybe_transform, + async_maybe_transform, +) +from .._compat import cached_property +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import ( + make_request_options, +) +import cript +import inspect +from ..types.shared.search import Search +from .._resource import SyncAPIResource + +class ChildPaginator: + # TODO consider writing operations + def __init__(self, parent, child, client=None, raw_dict_output=False): + if client is None: + client = parent.client + self._client = client + self._parent = parent + self._child = child + + self._current_child_list = [] + self._current_child_position = 0 + self._current_page = 0 + self._count = None + self.raw_dict_output = raw_dict_output + + def __iter__(self): + self._current_child_position = 0 + return self + + def __next__(self): + if self._current_child_position >= len(self._current_child_list): + self._fetch_next_page() + try: + next_node = self._current_child_list[self._current_child_position] + except IndexError: + raise StopIteration + + self._current_child_position += 1 + + if not self.raw_dict_output: + + return next_node + + def _fetch_next_page(self): + if self._finished_fetching: + raise StopIteration + + response = self._client._child.child(self._parent, self._child, self._current_page) + self._current_page += 1 + if self._count is not None and self._count != int(response.data.count): + raise RuntimeError("The number of elements for a child iteration changed during pagination. This may lead to inconsistencies. Please try again.") + self._count = int(response.data.count) + + self._current_child_list += response.data.result + + # Make it a random access iterator, since ppl expect it to behave list a list + def __getitem__(self, key): + key_index = int(key) + previous_pos = self._current_child_position + try: + if key_index < 0: + while not self._finished_fetching: + next(self) + + while len(self._current_child_list) <= key_index: + try: + next(self) + except StopIteration: + break + finally: + self._current_child_position = previous_pos + # We don't need explicit bounds checking, since the list access does that for us. + return self._current_child_list[key_index] + + def __len__(self): + previous_pos = self._current_child_position + try: + if self._count is None: + try: + next(iter(self)) + except StopIteration: + self._count = 0 + finally: + self._current_child_position = previous_pos + return self._count + + @property + def _finished_fetching(self): + if self._count is None: + return False + return len(self._current_child_list) == self._count + + +class ChildResource(SyncAPIResource): + @cached_property + def with_raw_response(self): + return ChildResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self): + return ChildResourceWithStreamingResponse(self) + + def child( + self, + parent, + child: str, + page: int, + *, + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Search: + """ + Obtain all children of parent node. + + Args: + parent: parent node + child: attribute name of the child node + """ + return self._get(f"/{parent.name_url}/{parent.uuid}/{child}", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + query={"page": page}, + timeout=timeout, + ), + cast_to=Search, + ) + + +class ChildResourceWithRawResponse: + def __init__(self, child:ChildResource) -> None: + self._child = child + + self.node = to_raw_response_wrapper(child.node) + +class ChildResourceWithStreamingResponse: + def __init__(self, child:ChildResource) -> None: + self._child = child + + self.node = to_streamed_response_wrapper(child.node) diff --git a/tests/api_resources/test_cript.py b/tests/api_resources/test_cript.py index b1c2cde..ec1f7cb 100644 --- a/tests/api_resources/test_cript.py +++ b/tests/api_resources/test_cript.py @@ -7,6 +7,7 @@ import pytest +import cript from cript import * base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -48,7 +49,7 @@ def test_create_project(self) -> None: notes="my notes", ) assert node.get("name") is not None - + def test_create_collection_exisiting_project(self) -> None: col1=Collection(name=generic_collection) proj = Project(uuid=CREATED_UUID, collection=[col1]) @@ -67,6 +68,15 @@ def test_create_experiment_exisiting_collection(self) -> None: proj1 = Project(uuid=CREATED_UUID, collection=[col1]) assert exp1.get("name") == generic_experiment + @pytest.mark.parametrize("query", ["tol", "styrene"]) + def test_create_children(self, query) -> None: + result = Search(node="Material", q=query, filters={"limit": 10}) + for i, d in enumerate(result): + node = cript.nodes.CriptNode._from_dict(d) + assert isinstance(node, cript.nodes.CriptNode) + assert isinstance(node, cript.Material) + assert query in node.name.lower() + assert i > 0 def test_create_material(self) -> None: comp_forcefield= ComputationalForcefield(key="mmff", building_block="atom")