diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0e16b00a..924364db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,8 +27,7 @@ repos: rev: v1.11.2 hooks: - id: mypy - additional_dependencies: - [types-pyyaml==6.0.12.20240808, types-aiofiles==24.1.0.20240626] + additional_dependencies: [types-pyyaml==6.0.12.20240808] - repo: local hooks: - id: pytest diff --git a/goosebit/api/v1/software/routes.py b/goosebit/api/v1/software/routes.py index 10dd3462..081928f2 100644 --- a/goosebit/api/v1/software/routes.py +++ b/goosebit/api/v1/software/routes.py @@ -1,5 +1,7 @@ -import aiofiles -from anyio import Path +import random +import string + +from anyio import Path, open_file from fastapi import APIRouter, File, Form, HTTPException, Security, UploadFile from fastapi.requests import Request @@ -69,11 +71,15 @@ async def post_update(_: Request, file: UploadFile | None = File(None), url: str # local file artifacts_dir = Path(config.artifacts_dir) file_path = artifacts_dir.joinpath(file.filename) + tmp_file_path = artifacts_dir.joinpath("tmp", ("".join(random.choices(string.ascii_lowercase, k=12)) + ".tmp")) + await tmp_file_path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f: + async with await open_file(tmp_file_path, "w+b") as f: await f.write(await file.read()) - absolute = await file_path.absolute() - software = await create_software_update(absolute.as_uri(), Path(f.name)) + absolute = await file_path.absolute() + tmp_absolute = await tmp_file_path.absolute() + software = await create_software_update(absolute.as_uri(), tmp_absolute) + await tmp_file_path.unlink(missing_ok=True) else: raise HTTPException(422) diff --git a/goosebit/updates/__init__.py b/goosebit/updates/__init__.py index 35c35dbf..84919397 100644 --- a/goosebit/updates/__init__.py +++ b/goosebit/updates/__init__.py @@ -19,7 +19,9 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software: parsed_uri = urlparse(uri) # parse swu header into update_info - if parsed_uri.scheme == "file" and temp_file is not None: + if parsed_uri.scheme == "file": + if temp_file is None: + raise HTTPException(422, "Temporary file missing, cannot parse file information") try: update_info = await swdesc.parse_file(temp_file) except Exception: @@ -30,7 +32,6 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software: update_info = await swdesc.parse_remote(uri) except Exception: raise HTTPException(422, "Software swu header cannot be parsed") - else: raise HTTPException(422, "Software URI protocol unknown") @@ -43,11 +44,11 @@ async def create_software_update(uri: str, temp_file: Path | None) -> Software: raise HTTPException(409, "Software with same version and overlapping compatibility already exists") # for local file: rename temp file to final name - if parsed_uri.scheme == "file" and temp_file is not None: + if parsed_uri.scheme == "file": filename = Path(url2pathname(unquote(parsed_uri.path))).name path = Path(config.artifacts_dir).joinpath(update_info["hash"], filename) await path.parent.mkdir(parents=True, exist_ok=True) - await temp_file.rename(path) + await temp_file.replace(path) absolute = await path.absolute() uri = absolute.as_uri() diff --git a/goosebit/updates/swdesc.py b/goosebit/updates/swdesc.py index c6e5685c..c5f2abc3 100644 --- a/goosebit/updates/swdesc.py +++ b/goosebit/updates/swdesc.py @@ -1,13 +1,16 @@ import hashlib import logging +import random +import string from typing import Any -import aiofiles import httpx import libconf import semver from anyio import AsyncFile, Path, open_file +from goosebit.settings import config + logger = logging.getLogger(__name__) @@ -68,9 +71,14 @@ async def parse_file(file: Path): async def parse_remote(url: str): async with httpx.AsyncClient() as c: file = await c.get(url) - async with aiofiles.tempfile.NamedTemporaryFile("w+b") as f: + artifacts_dir = Path(config.artifacts_dir) + tmp_file_path = artifacts_dir.joinpath("tmp", ("".join(random.choices(string.ascii_lowercase, k=12)) + ".tmp")) + await tmp_file_path.parent.mkdir(parents=True, exist_ok=True) + async with await open_file(tmp_file_path, "w+b") as f: await f.write(file.content) - return await parse_file(Path(str(f.name))) + file_data = await parse_file(Path(str(f.name))) + await tmp_file_path.unlink() + return file_data async def _sha1_hash_file(fileobj: AsyncFile): diff --git a/pyproject.toml b/pyproject.toml index d6632b96..a7baf727 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,6 @@ jinja2 = "^3.1.4" itsdangerous = "^2.2.0" tortoise-orm = "^0.21.4" aerich = "^0.7.2" -aiofiles = "^24.1.0" websockets = "^12.0" argon2-cffi = "^23.1.0" joserfc = "^1.0.0"