From caff3bc225befbac58cbe30bb00ba7978a2decb1 Mon Sep 17 00:00:00 2001 From: Simon Li Date: Thu, 20 Jun 2024 12:49:13 +0100 Subject: [PATCH] Add type hints --- .pre-commit-config.yaml | 14 ++++++++++ aws_project_costs/main.py | 2 +- aws_project_costs/project_costs.py | 42 ++++++++++++++++++++---------- aws_project_costs/schema.py | 3 ++- 4 files changed, 45 insertions(+), 16 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a531570..bb06934 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -42,4 +42,18 @@ repos: - id: bandit entry: bandit exclude: ^tests/ + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.10.0 + hooks: + - id: mypy + # TODO: Get rid of Any types. mypy doesn't take jsonschema.validate + # into account + # args: + # - --strict + additional_dependencies: + - pandas-stubs + - types-jsonschema + - types-PyYAML + files: ^aws_project_costs + exclude: "^tests/dummy-.*/.*.md$" diff --git a/aws_project_costs/main.py b/aws_project_costs/main.py index f95d798..9688fa3 100644 --- a/aws_project_costs/main.py +++ b/aws_project_costs/main.py @@ -6,7 +6,7 @@ from .schema import validate -def main(): +def main() -> None: parser = ArgumentParser("aws-project-costs") parser.add_argument( "--config", required=True, help="Project costs configuration file" diff --git a/aws_project_costs/project_costs.py b/aws_project_costs/project_costs.py index bb46a47..6388fae 100644 --- a/aws_project_costs/project_costs.py +++ b/aws_project_costs/project_costs.py @@ -1,6 +1,7 @@ import json from enum import Enum from operator import itemgetter +from typing import Any, Hashable, Optional import pandas as pd @@ -12,19 +13,21 @@ class CostSourceType(Enum): SHARED = "shared" -def _get_account_cfg(config, accountname): +def _get_account_cfg(config: dict[str, Any], accountname: str) -> dict[str, Any]: for acc in config["accounts"]: if acc["name"].lower() == accountname.lower(): return acc raise ValueError(f"Configuration for account {accountname} not found") -def _get_project_costshares_in_groups(config, groupnames): +def _get_project_costshares_in_groups( + config: dict[str, Any], groupnames: list[str] +) -> dict[str, int]: """ Get the combined set of all projects in groupnames, and their costshare If a project is in multiple groups it must have the same costshare """ - projects = dict() + projects: dict[str, int] = dict() for group in groupnames: for p in config["project-groups"][group]: name = p["name"] @@ -38,7 +41,12 @@ def _get_project_costshares_in_groups(config, groupnames): return projects -def _project_specific_account(start, description, project_name, costs_dict): +def _project_specific_account( + start: pd.Timestamp, + description: str, + project_name: str, + costs_dict: list[dict[Hashable, Any]], +) -> list[tuple[str, str, str, Any, str, Any]]: rows = [] for item in costs_dict: @@ -57,14 +65,14 @@ def _project_specific_account(start, description, project_name, costs_dict): def _shared_account( - start, - description, - proj_tag_names_map, - shared_tag_values, - project_tagname, - projects_in_group, - costs_dict, -): + start: pd.Timestamp, + description: str, + proj_tag_names_map: dict[str, str], + shared_tag_values: list[str], + project_tagname: Optional[str], + projects_in_group: dict[str, int], + costs_dict: list[dict[Hashable, Any]], +) -> list[tuple[str, str, str, Any, str, Any]]: rows = [] for item in costs_dict: @@ -105,7 +113,9 @@ def _shared_account( return rows -def allocate_costs(*, accountname, config, start, df): +def allocate_costs( + *, accountname: str, config: dict[str, Any], start: pd.Timestamp, df: pd.DataFrame +) -> list[tuple[str, str, str, Any, str, Any]]: account_cfg = _get_account_cfg(config, accountname) costs = df[(df["START"] == start) & (df["accountname"] == accountname)] @@ -150,7 +160,11 @@ def allocate_costs(*, accountname, config, start, df): return rows -def analyse_costs_csv(config, costs_csv_filename, output_csv_filename=None): +def analyse_costs_csv( + config: dict[str, Any], + costs_csv_filename: str, + output_csv_filename: Optional[str] = None, +) -> None: df = pd.read_csv( costs_csv_filename, dtype={"accountname": str, f"{PROJECT_TAG}$": str, "COST": float}, diff --git a/aws_project_costs/schema.py b/aws_project_costs/schema.py index 59616e8..3b88cc9 100644 --- a/aws_project_costs/schema.py +++ b/aws_project_costs/schema.py @@ -1,12 +1,13 @@ import json import sys from pathlib import Path +from typing import Any import jsonschema import yaml -def validate(projects, raise_on_error=False): +def validate(projects: dict[str, Any], raise_on_error: bool = False) -> list[str]: with (Path(__file__).resolve().parent / "project-cost-schema.json").open() as f: schema = json.load(f)