-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added singlepoint files + added route to main (#11)
* Added singlepoint files + added route to main * Added tests for singlepoint helper * Fixed single file upload and test
- Loading branch information
1 parent
eeb9ca5
commit 1dd457c
Showing
7 changed files
with
206 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
"""Contains routes for performing singlepoint calculations.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from fastapi import APIRouter, HTTPException | ||
from janus_core.helpers.janus_types import Architectures, Properties | ||
from pydantic import BaseModel | ||
|
||
from api.utils.singlepoint_helper import singlepoint | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
router = APIRouter(prefix="/singlepoint", tags=["calculations"]) | ||
|
||
|
||
class SinglePointRequest(BaseModel): | ||
"""Class validation for singlepoint requests.""" | ||
|
||
struct: str | ||
arch: Architectures | ||
properties: list[Properties] | ||
range_selector: str | ||
|
||
|
||
@router.post("/") | ||
def get_singlepoint(request: SinglePointRequest) -> dict[str, Any]: | ||
""" | ||
Endpoint to perform single point calculations and return results. | ||
Parameters | ||
---------- | ||
request : SinglePointRequest | ||
The request body containing the parameters for the calculation. | ||
Returns | ||
------- | ||
dict[str, Any] | ||
Results of the single point calculations. | ||
Raises | ||
------ | ||
HTTPException | ||
If there is an error during the call. | ||
""" | ||
base_dir = Path("data") | ||
struct_path = base_dir / request.struct | ||
logger.info("Request contents:", request) | ||
try: | ||
return singlepoint( | ||
struct=struct_path, | ||
arch=request.arch, | ||
properties=request.properties, | ||
range_selector=request.range_selector, | ||
) | ||
except Exception as e: | ||
logger.error(e) | ||
raise HTTPException(status_code=500, detail="Internal Server Error") from e |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
"""Helper functions for performing singeploint calculations.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
from pathlib import Path | ||
from typing import Any | ||
|
||
from janus_core.calculations.single_point import SinglePoint | ||
from janus_core.helpers.janus_types import Architectures, Properties | ||
import numpy as np | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def convert_ndarray_to_list( | ||
data: dict[str, Any] | list | np.ndarray | float, | ||
) -> dict[str, Any]: | ||
""" | ||
Recursive function to convert numpy arrays into a useable format for fastAPI. | ||
Parameters | ||
---------- | ||
data : dict[str, Any] | list | np.ndarray | float | ||
The object to be checked and potentially converted. | ||
Returns | ||
------- | ||
dict[str, Any] | ||
Dictionary of properties calculated. | ||
""" | ||
if isinstance(data, np.ndarray): | ||
return data.tolist() | ||
if isinstance(data, dict): | ||
return {k: convert_ndarray_to_list(v) for k, v in data.items()} | ||
if isinstance(data, list): | ||
return [convert_ndarray_to_list(i) for i in data] | ||
return data | ||
|
||
|
||
def singlepoint( | ||
struct: Path, | ||
arch: Architectures | None = "mace_mp", | ||
properties: list[Properties] | None = None, | ||
range_selector: str | None = ":", | ||
) -> dict[str, Any]: | ||
""" | ||
Perform single point calculations and return results. | ||
Parameters | ||
---------- | ||
struct : str | ||
Filename of structure to simulate. | ||
arch : Architectures | ||
MLIP architecture to use for single point calculations. Default is "mace_mp". | ||
properties : List[Properties] | ||
Physical properties to calculate. Default is ("energy", "forces", "stress"). | ||
range_selector : str | ||
Range of indices to include from the structure. Default is all. | ||
Returns | ||
------- | ||
dict[str, Any] | ||
Results of the single point calculations. | ||
""" | ||
read_kwargs = {"index": range_selector} | ||
|
||
singlepoint_kwargs = { | ||
"struct_path": struct, | ||
"properties": properties, | ||
"arch": arch, | ||
"device": "cpu", | ||
"read_kwargs": read_kwargs, | ||
} | ||
|
||
s_point = SinglePoint(**singlepoint_kwargs) | ||
|
||
s_point.run() | ||
|
||
return convert_ndarray_to_list(s_point.results) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
"""Tests for singlepoint calculation helper functions.""" | ||
|
||
from __future__ import annotations | ||
|
||
import numpy as np | ||
|
||
from api.utils.singlepoint_helper import convert_ndarray_to_list | ||
|
||
|
||
def test_convert_ndarray_to_list_with_array(): | ||
"""Test convert_ndarray_to_list function with a np array.""" | ||
array = np.array([1, 2, 3]) | ||
result = convert_ndarray_to_list(array) | ||
assert result == [1, 2, 3] | ||
|
||
|
||
def test_convert_ndarray_to_list_with_nested_dict(): | ||
"""Test convert_ndarray_to_list function with a nested dict.""" | ||
nested_dict = {"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}} | ||
result = convert_ndarray_to_list(nested_dict) | ||
expected = {"a": [1, 2, 3], "b": {"c": [4, 5, 6]}} | ||
assert result == expected | ||
|
||
|
||
def test_convert_ndarray_to_list_with_list(): | ||
"""Test convert_ndarray_to_list function with list of np arrays.""" | ||
list_of_arrays = [np.array([1, 2]), np.array([3, 4])] | ||
result = convert_ndarray_to_list(list_of_arrays) | ||
assert result == [[1, 2], [3, 4]] | ||
|
||
|
||
def test_convert_ndarray_to_list_with_non_numpy(): | ||
"""Test convert_ndarray_to_list function with a float value.""" | ||
obj = -843.033131230741 | ||
result = convert_ndarray_to_list(obj) | ||
assert result == -843.033131230741 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters