Skip to content

Commit

Permalink
Added singlepoint files + added route to main (#11)
Browse files Browse the repository at this point in the history
* Added singlepoint files + added route to main

* Added tests for singlepoint helper

* Fixed single file upload and test
  • Loading branch information
Cbameron12 authored Jan 29, 2025
1 parent eeb9ca5 commit 1dd457c
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 31 deletions.
61 changes: 61 additions & 0 deletions api/endpoints/singlepoint_route.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Contains routes for performing singlepoint calculations."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from fastapi import APIRouter, HTTPException
from janus_core.helpers.janus_types import Architectures, Properties
from pydantic import BaseModel

from api.utils.singlepoint_helper import singlepoint

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/singlepoint", tags=["calculations"])


class SinglePointRequest(BaseModel):
"""Class validation for singlepoint requests."""

struct: str
arch: Architectures
properties: list[Properties]
range_selector: str


@router.post("/")
def get_singlepoint(request: SinglePointRequest) -> dict[str, Any]:
"""
Endpoint to perform single point calculations and return results.
Parameters
----------
request : SinglePointRequest
The request body containing the parameters for the calculation.
Returns
-------
dict[str, Any]
Results of the single point calculations.
Raises
------
HTTPException
If there is an error during the call.
"""
base_dir = Path("data")
struct_path = base_dir / request.struct
logger.info("Request contents:", request)
try:
return singlepoint(
struct=struct_path,
arch=request.arch,
properties=request.properties,
range_selector=request.range_selector,
)
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail="Internal Server Error") from e
2 changes: 1 addition & 1 deletion api/endpoints/upload_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ async def upload_single(
file_content = await file.read()
logger.info(f"Hash matches: {calculate_md5_checksum(file_content, file_hash)}")

save_file(file)
save_file(file_content, file.filename)
except Exception as e:
logger.error(f"Error during file upload: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
Expand Down
80 changes: 80 additions & 0 deletions api/utils/singlepoint_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""Helper functions for performing singeploint calculations."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

from janus_core.calculations.single_point import SinglePoint
from janus_core.helpers.janus_types import Architectures, Properties
import numpy as np

logger = logging.getLogger(__name__)


def convert_ndarray_to_list(
data: dict[str, Any] | list | np.ndarray | float,
) -> dict[str, Any]:
"""
Recursive function to convert numpy arrays into a useable format for fastAPI.
Parameters
----------
data : dict[str, Any] | list | np.ndarray | float
The object to be checked and potentially converted.
Returns
-------
dict[str, Any]
Dictionary of properties calculated.
"""
if isinstance(data, np.ndarray):
return data.tolist()
if isinstance(data, dict):
return {k: convert_ndarray_to_list(v) for k, v in data.items()}
if isinstance(data, list):
return [convert_ndarray_to_list(i) for i in data]
return data


def singlepoint(
struct: Path,
arch: Architectures | None = "mace_mp",
properties: list[Properties] | None = None,
range_selector: str | None = ":",
) -> dict[str, Any]:
"""
Perform single point calculations and return results.
Parameters
----------
struct : str
Filename of structure to simulate.
arch : Architectures
MLIP architecture to use for single point calculations. Default is "mace_mp".
properties : List[Properties]
Physical properties to calculate. Default is ("energy", "forces", "stress").
range_selector : str
Range of indices to include from the structure. Default is all.
Returns
-------
dict[str, Any]
Results of the single point calculations.
"""
read_kwargs = {"index": range_selector}

singlepoint_kwargs = {
"struct_path": struct,
"properties": properties,
"arch": arch,
"device": "cpu",
"read_kwargs": read_kwargs,
}

s_point = SinglePoint(**singlepoint_kwargs)

s_point.run()

return convert_ndarray_to_list(s_point.results)
47 changes: 24 additions & 23 deletions api/utils/upload_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
DATA_DIR = Path("/home/ubuntu/janus-api/janus-web/data")


def save_file(file_contents: bytes, filename: str, directory: Path = DATA_DIR):
"""
Save a file to the specified directory.
Parameters
----------
file_contents : bytes
The file to be saved.
filename : str
Name of the file to be saved.
directory : Path, optional
The directory where the file will be saved (default is DATA_DIR).
Returns
-------
str
The path where the file was saved.
"""
directory.mkdir(parents=True, exist_ok=True)
file_path = directory / filename
file_path.write_bytes(file_contents)
return file_path


def save_chunk(
file: bytes, chunk_number: int, original_filename: str, directory: Path = DATA_DIR
) -> str:
Expand Down Expand Up @@ -66,29 +90,6 @@ def reassemble_file(
return output_path


def save_file(file: bytes, directory: Path = DATA_DIR):
"""
Save a file to the specified directory.
Parameters
----------
file : bytes
The file to be saved.
directory : Path, optional
The directory where the file will be saved (default is DATA_DIR).
Returns
-------
str
The path where the file was saved.
"""
directory.mkdir(parents=True, exist_ok=True)
file_path = directory / file.filename
with file_path.open("wb") as buffer:
buffer.write(file.file.read())
return file_path


def calculate_md5_checksum(file_chunk: bytes, received_hash: str) -> bool:
"""
Calculate the MD5 checksum of a file chunk.
Expand Down
3 changes: 2 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from fastapi import FastAPI

from api.endpoints import upload_route
from api.endpoints import singlepoint_route, upload_route
import logging_config

app = FastAPI()

app.include_router(upload_route.router)
app.include_router(singlepoint_route.router)

if __name__ == "__main__":
import uvicorn
Expand Down
36 changes: 36 additions & 0 deletions tests/test_singlepoint_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Tests for singlepoint calculation helper functions."""

from __future__ import annotations

import numpy as np

from api.utils.singlepoint_helper import convert_ndarray_to_list


def test_convert_ndarray_to_list_with_array():
"""Test convert_ndarray_to_list function with a np array."""
array = np.array([1, 2, 3])
result = convert_ndarray_to_list(array)
assert result == [1, 2, 3]


def test_convert_ndarray_to_list_with_nested_dict():
"""Test convert_ndarray_to_list function with a nested dict."""
nested_dict = {"a": np.array([1, 2, 3]), "b": {"c": np.array([4, 5, 6])}}
result = convert_ndarray_to_list(nested_dict)
expected = {"a": [1, 2, 3], "b": {"c": [4, 5, 6]}}
assert result == expected


def test_convert_ndarray_to_list_with_list():
"""Test convert_ndarray_to_list function with list of np arrays."""
list_of_arrays = [np.array([1, 2]), np.array([3, 4])]
result = convert_ndarray_to_list(list_of_arrays)
assert result == [[1, 2], [3, 4]]


def test_convert_ndarray_to_list_with_non_numpy():
"""Test convert_ndarray_to_list function with a float value."""
obj = -843.033131230741
result = convert_ndarray_to_list(obj)
assert result == -843.033131230741
8 changes: 2 additions & 6 deletions tests/test_upload_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,10 @@ def test_reassemble_file(tmp_path):

def test_save_file(tmp_path):
"""Test if save file function saves a given file in the correct directory."""
from io import BytesIO

from fastapi import UploadFile

file_content = b"Test file content"
file = UploadFile(filename="testfile.txt", file=BytesIO(file_content))
filename = "testfile.txt"

file_path = save_file(file, tmp_path)
file_path = save_file(file_content, filename, tmp_path)

assert Path(file_path).exists()
assert Path(file_path).read_bytes() == file_content
Expand Down

0 comments on commit 1dd457c

Please sign in to comment.