Skip to content

Commit

Permalink
Use anyio based async file io operations
Browse files Browse the repository at this point in the history
  • Loading branch information
jameshilliard committed Sep 9, 2024
1 parent bb7f43b commit f7a0ce7
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 43 deletions.
13 changes: 7 additions & 6 deletions goosebit/api/v1/software/routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path

import aiofiles
from anyio import Path
from fastapi import APIRouter, File, Form, HTTPException, Security, UploadFile
from fastapi.requests import Request

Expand Down Expand Up @@ -42,8 +41,8 @@ async def software_delete(_: Request, delete_req: SoftwareDeleteRequest) -> Stat

if software.local:
path = software.path
if path.exists():
path.unlink()
if await path.exists():
await path.unlink()

await software.delete()
success = True
Expand All @@ -68,11 +67,13 @@ async def post_update(_: Request, file: UploadFile | None = File(None), url: str
software = await create_software_update(url, None)
elif file is not None:
# local file
file_path = config.artifacts_dir.joinpath(file.filename)
artifacts_dir = Path(config.artifacts_dir)
file_path = artifacts_dir.joinpath(file.filename)

async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f:
await f.write(await file.read())
software = await create_software_update(file_path.absolute().as_uri(), Path(str(f.name)))
absolute = await file_path.absolute()
software = await create_software_update(absolute.as_uri(), Path(f.name))
else:
raise HTTPException(422)

Expand Down
4 changes: 2 additions & 2 deletions goosebit/db/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from enum import IntEnum
from pathlib import Path
from typing import Self
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

import semver
from anyio import Path
from tortoise import Model, fields

from goosebit.api.telemetry.metrics import devices_count
Expand Down Expand Up @@ -132,7 +132,7 @@ async def latest(cls, device: Device) -> Self | None:
)[0]

@property
def path(self):
def path(self) -> Path:
return Path(url2pathname(unquote(urlparse(self.uri).path)))

@property
Expand Down
20 changes: 10 additions & 10 deletions goosebit/ui/bff/software/routes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

import aiofiles
from anyio import Path, open_file
from fastapi import APIRouter, Form, HTTPException, Security, UploadFile
from fastapi.requests import Request
from tortoise.expressions import Q
Expand Down Expand Up @@ -64,20 +64,20 @@ async def post_update(
await create_software_update(url, None)
else:
# local file
file = config.artifacts_dir.joinpath(filename)
config.artifacts_dir.mkdir(parents=True, exist_ok=True)
artifacts_dir = Path(config.artifacts_dir)
file = artifacts_dir.joinpath(filename)
await artifacts_dir.mkdir(parents=True, exist_ok=True)

temp_file = file.with_suffix(".tmp")
if init:
temp_file.unlink(missing_ok=True)
await temp_file.unlink(missing_ok=True)

contents = await chunk.read()

async with aiofiles.open(temp_file, mode="ab") as f:
await f.write(contents)
async with await open_file(temp_file, "ab") as f:
await f.write(await chunk.read())

if done:
try:
await create_software_update(file.absolute().as_uri(), temp_file)
absolute = await file.absolute()
await create_software_update(absolute.as_uri(), temp_file)
finally:
temp_file.unlink(missing_ok=True)
await temp_file.unlink(missing_ok=True)
12 changes: 6 additions & 6 deletions goosebit/updates/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import shutil
from pathlib import Path
from urllib.parse import unquote, urlparse
from urllib.request import url2pathname

from anyio import Path
from fastapi import HTTPException
from fastapi.requests import Request
from tortoise.expressions import Q
Expand Down Expand Up @@ -46,10 +45,11 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software:
# for local file: rename temp file to final name
if parsed_uri.scheme == "file" and temp_file is not None:
filename = Path(url2pathname(unquote(parsed_uri.path))).name
path = config.artifacts_dir.joinpath(update_info["hash"], filename)
path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(temp_file, path)
uri = path.absolute().as_uri()
path = Path(config.artifacts_dir).joinpath(update_info["hash"], filename)
await path.parent.mkdir(parents=True, exist_ok=True)
await temp_file.rename(path)
absolute = await path.absolute()
uri = absolute.as_uri()

# create software
software = await Software.create(
Expand Down
25 changes: 18 additions & 7 deletions goosebit/updates/swdesc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import hashlib
import logging
from pathlib import Path
from typing import Any

import aiofiles
import httpx
import libconf
import semver
from anyio import AsyncFile, Path, open_file

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -41,7 +41,7 @@ def parse_descriptor(swdesc: libconf.AttrDict[Any, Any | None]):


async def parse_file(file: Path):
async with aiofiles.open(file, "r+b") as f:
async with await open_file(file, "r+b") as f:
# get file size
size = int((await f.read(110))[54:62], 16)
filename = b""
Expand All @@ -59,8 +59,9 @@ async def parse_file(file: Path):
swdesc = libconf.loads((await f.read(size)).decode("utf-8"))

swdesc_attrs = parse_descriptor(swdesc)
swdesc_attrs["size"] = file.stat().st_size
swdesc_attrs["hash"] = _sha1_hash_file(file)
stat = await file.stat()
swdesc_attrs["size"] = stat.st_size
swdesc_attrs["hash"] = await _sha1_hash_file(f)
return swdesc_attrs


Expand All @@ -72,7 +73,17 @@ async def parse_remote(url: str):
return await parse_file(Path(str(f.name)))


def _sha1_hash_file(file_path: Path):
with file_path.open("rb") as f:
sha1_hash = hashlib.file_digest(f, "sha1")
async def _sha1_hash_file(fileobj: AsyncFile):
last = await fileobj.tell()
await fileobj.seek(0)
sha1_hash = hashlib.sha1()
buf = bytearray(2**18)
view = memoryview(buf)
while True:
size = await fileobj.readinto(buf)
if size == 0:
break
sha1_hash.update(view[:size])

await fileobj.seek(last)
return sha1_hash.hexdigest()
20 changes: 11 additions & 9 deletions tests/api/v1/software/test_routes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from pathlib import Path

import pytest
from anyio import Path, open_file


@pytest.mark.asyncio
async def test_create_software_local(async_client, test_data):
path = Path(__file__).resolve().parent / "software-header.swu"
with open(path, "rb") as file:
files = {"file": file}
response = await async_client.post(f"/api/v1/software", files=files)
resolved = await Path(__file__).resolve()
path = resolved.parent / "software-header.swu"
async with await open_file(path, "rb") as file:
files = {"file": await file.read()}

response = await async_client.post(f"/api/v1/software", files=files)

assert response.status_code == 200
software = response.json()
Expand All @@ -17,9 +18,10 @@ async def test_create_software_local(async_client, test_data):

@pytest.mark.asyncio
async def test_create_software_remote(async_client, httpserver, test_data):
path = Path(__file__).resolve().parent / "software-header.swu"
with open(path, "rb") as file:
byte_array = file.read()
resolved = await Path(__file__).resolve()
path = resolved.parent / "software-header.swu"
async with await open_file(path, "rb") as file:
byte_array = await file.read()

httpserver.expect_request("/software-header.swu").respond_with_data(byte_array)

Expand Down
6 changes: 3 additions & 3 deletions tests/updates/test_swdesc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path

import pytest
from anyio import Path
from libconf import AttrDict

from goosebit.updates.swdesc import parse_descriptor, parse_file
Expand Down Expand Up @@ -102,7 +101,8 @@ def test_parse_descriptor_several_boardname():

@pytest.mark.asyncio
async def test_parse_software_header():
swdesc_attrs = await parse_file(Path(__file__).resolve().parent / "software-header.swu")
resolved = await Path(__file__).resolve()
swdesc_attrs = await parse_file(resolved.parent / "software-header.swu")
assert str(swdesc_attrs["version"]) == "8.8.1-11-g8c926e5+188370"
assert swdesc_attrs["compatibility"] == [
{"hw_model": "smart-gateway-mt7688", "hw_revision": "0.5"},
Expand Down

0 comments on commit f7a0ce7

Please sign in to comment.