Skip to content

Enable creation of nodes from dictionary #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 43 additions & 6 deletions src/cript/nodes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from jsonschema.exceptions import best_match
from uuid import uuid4

import cript
import inspect
from cript import Cript, NotFoundError, camel_case_to_snake_case, extract_node_from_result
from .schema import cript_schema

Expand All @@ -22,7 +24,7 @@ def __init__(self, *args, **kwargs):
self.__dict__["exists"] = self.__dict__.get("exists", False)
self.__dict__["client"] = self.__dict__.get("client", Cript())
self.__dict__["children"] = self.__dict__.get("children", {})
self.__dict__["initialized"] = False
self.__dict__["initialized"] = kwargs.get("initialized", False)
self.__dict__["parent"] = None
schema = copy.deepcopy(cript_schema)
schema["$ref"] = f"#/$defs/{self.__class__.__name__}Post"
Expand All @@ -37,6 +39,14 @@ def __init__(self, *args, **kwargs):
self.__dict__["validator_instance"] = cls(schema, resolver=resolver)
else:
self.__dict__["validator_instance"] = cls(schema)

allowed_attributes = self._allowed_attributes(kwargs)

# Early exit for initialized nodes
if self.initialized:
for key in allowed_attributes:
setattr(self, key, kwargs[key])
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you need to assign valid kwargs to self at this point and then return

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did I do it right now?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but we need to make sure not to include the extra fields that the API returns. You can refactor `retrieve_by_uuid` and create a separate method from:

allowed_data = {}
for key in data:
    if key in self.__dict__["schema"]["$defs"][f"{self.__class__.__name__}Post"]["properties"]:
        setattr(self, key, data[key])
        allowed_data[key] = data[key]
self.__dict__["__original__"] = copy.deepcopy(allowed_data)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the pointers.
I used that snippet to refactor it a little bit.

I think there is some room to streamline this a little bit.
But we can address that in a refactor down the line.

d = dict(*args, **kwargs)

if self._retrieve_on_init or kwargs.get("uuid"):
Expand Down Expand Up @@ -67,6 +77,34 @@ def __init__(self, *args, **kwargs):
self.final_update()
self.__dict__["initialized"] = True


def _allowed_attributes(self, attributes):
    """Filter *attributes* down to the keys permitted by this node's POST schema.

    Args:
        attributes: mapping of candidate attribute names to values.

    Returns:
        dict containing only the entries whose key appears in the
        ``{ClassName}Post`` properties of this node's JSON schema.
    """
    permitted = self.__dict__["schema"]["$defs"][f"{self.__class__.__name__}Post"]["properties"]
    return {name: attributes[name] for name in attributes if name in permitted}

@staticmethod
def _from_dict(json_dict: dict):
    """Instantiate the concrete CriptNode subclass described by *json_dict*.

    The dictionary must carry a ``"node"`` key holding a single-element
    list with the class name, e.g. ``{"node": ["Material"], ...}``.

    Args:
        json_dict: raw node dictionary (typically an API search result).

    Returns:
        An instance of the matching CriptNode subclass.

    Raises:
        ValueError: if ``"node"`` is missing or malformed, or names no
            known CriptNode subclass.
    """
    node_name_list = json_dict.get("node", None)
    # The API encodes the node type as a one-element list of class names.
    if node_name_list is None or not isinstance(node_name_list, list) or len(node_name_list) != 1:
        raise ValueError(f"Conversion of dictionary to CRIPT Node failed, since 'node' is {node_name_list} given.")
    node_name_str = node_name_list[0]

    # Scan every class exported by cript.nodes and pick the CriptNode
    # subclass whose name matches the requested node type.
    for key, pyclass in inspect.getmembers(cript.nodes, inspect.isclass):
        if CriptNode in inspect.getmro(pyclass):
            if key == node_name_str:
                return pyclass._cls_from_dict(json_dict)
    # Typo fix: original message read "Unknow".
    raise ValueError(f"Unknown node conversion attempt {json_dict}")

@classmethod
def _cls_from_dict(cls, json_dict: dict):
    """Construct an instance of *cls* from a plain dictionary.

    Marks the node as ``initialized`` so ``__init__`` takes its early-exit
    path (assigning allowed attributes without retrieving from the API).

    Args:
        json_dict: attribute dictionary for the new node.

    Returns:
        A new instance of *cls*.
    """
    # Build a shallow copy so the caller's dictionary is not mutated by
    # the injected "initialized" flag (the original wrote into json_dict).
    return cls(**{**json_dict, "initialized": True})

@property
def name_url(self):
return camel_case_to_snake_case(self.__class__.__name__)
Expand Down Expand Up @@ -316,11 +354,10 @@ def retrieve_by_uuid(self, uuid):
result = self.__dict__["client"].nodes.retrieve(node=self.name_url, uuid=uuid)
data = extract_node_from_result(result.data)
self.__dict__["exists"] = True
allowed_data = {}
for key in data:
if key in self.__dict__["schema"]["$defs"][f"{self.__class__.__name__}Post"]["properties"]:
setattr(self, key, data[key])
allowed_data[key] = data[key]

allowed_data = self._allowed_attributes(data)
for key in allowed_data:
setattr(self, key, data[key])
self.__dict__["__original__"] = copy.deepcopy(allowed_data)
except NotFoundError:
self.__dict__["exists"] = False
Expand Down
154 changes: 154 additions & 0 deletions src/cript/resources/child.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import httpx
from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
from .._utils import (
maybe_transform,
async_maybe_transform,
)
from .._compat import cached_property
from .._response import (
to_raw_response_wrapper,
to_streamed_response_wrapper,
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from .._base_client import (
make_request_options,
)
import cript
import inspect
from ..types.shared.search import Search
from .._resource import SyncAPIResource

class ChildPaginator:
    """Lazily pages through the children of a parent node.

    Fetches pages from the API on demand and exposes the combined results
    as an iterable, random-access, sized sequence.

    Args:
        parent: parent node whose children are paginated.
        child: attribute name of the child relation on the parent.
        client: API client; defaults to ``parent.client``.
        raw_dict_output: when True, yield raw API dictionaries instead of
            typed CriptNode instances.
    """

    # TODO consider writing operations
    def __init__(self, parent, child, client=None, raw_dict_output=False):
        if client is None:
            client = parent.client
        self._client = client
        self._parent = parent
        self._child = child

        self._current_child_list = []      # accumulated results across fetched pages
        self._current_child_position = 0   # cursor for iteration
        self._current_page = 0             # next page index to request
        self._count = None                 # total element count reported by the API
        self.raw_dict_output = raw_dict_output

    def __iter__(self):
        # Restart iteration from the beginning; already-fetched pages are reused.
        self._current_child_position = 0
        return self

    def __next__(self):
        if self._current_child_position >= len(self._current_child_list):
            self._fetch_next_page()
        try:
            next_node = self._current_child_list[self._current_child_position]
        except IndexError:
            raise StopIteration

        self._current_child_position += 1

        if not self.raw_dict_output:
            # Convert the raw API dict into a typed CRIPT node.
            # NOTE(review): the scraped original had an empty branch here;
            # reconstructed from the PR test's use of CriptNode._from_dict.
            next_node = cript.nodes.CriptNode._from_dict(next_node)
        return next_node

    def _fetch_next_page(self):
        """Request the next page from the API and append its results."""
        if self._finished_fetching:
            raise StopIteration

        response = self._client._child.child(self._parent, self._child, self._current_page)
        self._current_page += 1
        # Guard against the server-side collection changing mid-pagination.
        if self._count is not None and self._count != int(response.data.count):
            raise RuntimeError("The number of elements for a child iteration changed during pagination. This may lead to inconsistencies. Please try again.")
        self._count = int(response.data.count)

        self._current_child_list += response.data.result

    # Make it a random access iterator, since ppl expect it to behave like a list.
    def __getitem__(self, key):
        """Return the element at *key*, fetching pages as needed.

        NOTE(review): returns the raw stored entry regardless of
        ``raw_dict_output`` — confirm whether conversion is intended here.
        """
        key_index = int(key)
        previous_pos = self._current_child_position
        try:
            if key_index < 0:
                # Negative indices need the full list, so drain the paginator.
                while not self._finished_fetching:
                    next(self)

            while len(self._current_child_list) <= key_index:
                try:
                    next(self)
                except StopIteration:
                    break
        finally:
            # Indexing must not disturb an in-progress iteration.
            self._current_child_position = previous_pos
        # We don't need explicit bounds checking, since the list access does that for us.
        return self._current_child_list[key_index]

    def __len__(self):
        previous_pos = self._current_child_position
        try:
            if self._count is None:
                # Force one fetch so the API tells us the total count.
                try:
                    next(iter(self))
                except StopIteration:
                    self._count = 0
        finally:
            self._current_child_position = previous_pos
        return self._count

    @property
    def _finished_fetching(self):
        # Unknown count means we have not fetched anything yet.
        if self._count is None:
            return False
        return len(self._current_child_list) == self._count


class ChildResource(SyncAPIResource):
    """Synchronous API resource for fetching children of a node."""

    @cached_property
    def with_raw_response(self):
        return ChildResourceWithRawResponse(self)

    @cached_property
    def with_streaming_response(self):
        return ChildResourceWithStreamingResponse(self)

    def child(
        self,
        parent,
        child: str,
        page: int,
        *,
        extra_headers: Headers | None = None,
        extra_query: Query | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Search:
        """
        Obtain all children of parent node.

        Args:
          parent: parent node
          child: attribute name of the child node
        """
        # Build the endpoint path and request options up front, then issue
        # a single GET that deserializes into a Search result.
        endpoint = f"/{parent.name_url}/{parent.uuid}/{child}"
        request_options = make_request_options(
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            query={"page": page},
            timeout=timeout,
        )
        return self._get(endpoint, options=request_options, cast_to=Search)


class ChildResourceWithRawResponse:
    """Proxy exposing ChildResource methods wrapped to return raw responses."""

    def __init__(self, child: ChildResource) -> None:
        self._child = child

        # Bug fix: the original wrapped ``child.node``, but ChildResource
        # defines no ``node`` method — construction raised AttributeError.
        # Wrap the actual ``child`` method and mirror its name.
        self.child = to_raw_response_wrapper(child.child)

class ChildResourceWithStreamingResponse:
    """Proxy exposing ChildResource methods wrapped for streamed responses."""

    def __init__(self, child: ChildResource) -> None:
        self._child = child

        # Bug fix: the original wrapped ``child.node``, but ChildResource
        # defines no ``node`` method — construction raised AttributeError.
        # Wrap the actual ``child`` method and mirror its name.
        self.child = to_streamed_response_wrapper(child.child)
12 changes: 11 additions & 1 deletion tests/api_resources/test_cript.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import pytest

import cript
from cript import *

base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_create_project(self) -> None:
notes="my notes",
)
assert node.get("name") is not None

def test_create_collection_exisiting_project(self) -> None:
col1=Collection(name=generic_collection)
proj = Project(uuid=CREATED_UUID, collection=[col1])
Expand All @@ -67,6 +68,15 @@ def test_create_experiment_exisiting_collection(self) -> None:
proj1 = Project(uuid=CREATED_UUID, collection=[col1])
assert exp1.get("name") == generic_experiment

@pytest.mark.parametrize("query", ["tol", "styrene"])
def test_create_children(self, query) -> None:
    # Search for materials matching the query, then round-trip each raw
    # result dict back into a typed node via CriptNode._from_dict.
    result = Search(node="Material", q=query, filters={"limit": 10})
    for i, d in enumerate(result):
        node = cript.nodes.CriptNode._from_dict(d)
        assert isinstance(node, cript.nodes.CriptNode)
        # Both queries are material names, so each hit should be a Material.
        assert isinstance(node, cript.Material)
        assert query in node.name.lower()
    # NOTE(review): indentation was lost in the diff; placed after the loop,
    # where it verifies the search yielded more than one result — inside the
    # loop it would fail at i == 0. Confirm intended placement.
    assert i > 0

def test_create_material(self) -> None:
comp_forcefield= ComputationalForcefield(key="mmff", building_block="atom")
Expand Down