Skip to content

Commit

Permalink
Add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
manics committed Jun 20, 2024
1 parent 3f32519 commit caff3bc
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 16 deletions.
14 changes: 14 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,18 @@ repos:
- id: bandit
entry: bandit
exclude: ^tests/
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
hooks:
- id: mypy
# TODO: Get rid of Any types. mypy doesn't take jsonschema.validate
# into account
# args:
# - --strict
additional_dependencies:
- pandas-stubs
- types-jsonschema
- types-PyYAML
files: ^aws_project_costs

exclude: "^tests/dummy-.*/.*.md$"
2 changes: 1 addition & 1 deletion aws_project_costs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .schema import validate


def main():
def main() -> None:
parser = ArgumentParser("aws-project-costs")
parser.add_argument(
"--config", required=True, help="Project costs configuration file"
Expand Down
42 changes: 28 additions & 14 deletions aws_project_costs/project_costs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
from enum import Enum
from operator import itemgetter
from typing import Any, Hashable, Optional

import pandas as pd

Expand All @@ -12,19 +13,21 @@ class CostSourceType(Enum):
SHARED = "shared"


def _get_account_cfg(config, accountname):
def _get_account_cfg(config: dict[str, Any], accountname: str) -> dict[str, Any]:
for acc in config["accounts"]:
if acc["name"].lower() == accountname.lower():
return acc
raise ValueError(f"Configuration for account {accountname} not found")


def _get_project_costshares_in_groups(config, groupnames):
def _get_project_costshares_in_groups(
config: dict[str, Any], groupnames: list[str]
) -> dict[str, int]:
"""
Get the combined set of all projects in groupnames, and their costshare
If a project is in multiple groups it must have the same costshare
"""
projects = dict()
projects: dict[str, int] = dict()
for group in groupnames:
for p in config["project-groups"][group]:
name = p["name"]
Expand All @@ -38,7 +41,12 @@ def _get_project_costshares_in_groups(config, groupnames):
return projects


def _project_specific_account(start, description, project_name, costs_dict):
def _project_specific_account(
start: pd.Timestamp,
description: str,
project_name: str,
costs_dict: list[dict[Hashable, Any]],
) -> list[tuple[str, str, str, Any, str, Any]]:
rows = []

for item in costs_dict:
Expand All @@ -57,14 +65,14 @@ def _project_specific_account(start, description, project_name, costs_dict):


def _shared_account(
start,
description,
proj_tag_names_map,
shared_tag_values,
project_tagname,
projects_in_group,
costs_dict,
):
start: pd.Timestamp,
description: str,
proj_tag_names_map: dict[str, str],
shared_tag_values: list[str],
project_tagname: Optional[str],
projects_in_group: dict[str, int],
costs_dict: list[dict[Hashable, Any]],
) -> list[tuple[str, str, str, Any, str, Any]]:
rows = []

for item in costs_dict:
Expand Down Expand Up @@ -105,7 +113,9 @@ def _shared_account(
return rows


def allocate_costs(*, accountname, config, start, df):
def allocate_costs(
*, accountname: str, config: dict[str, Any], start: pd.Timestamp, df: pd.DataFrame
) -> list[tuple[str, str, str, Any, str, Any]]:
account_cfg = _get_account_cfg(config, accountname)

costs = df[(df["START"] == start) & (df["accountname"] == accountname)]
Expand Down Expand Up @@ -150,7 +160,11 @@ def allocate_costs(*, accountname, config, start, df):
return rows


def analyse_costs_csv(config, costs_csv_filename, output_csv_filename=None):
def analyse_costs_csv(
config: dict[str, Any],
costs_csv_filename: str,
output_csv_filename: Optional[str] = None,
) -> None:
df = pd.read_csv(
costs_csv_filename,
dtype={"accountname": str, f"{PROJECT_TAG}$": str, "COST": float},
Expand Down
3 changes: 2 additions & 1 deletion aws_project_costs/schema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
import sys
from pathlib import Path
from typing import Any

import jsonschema
import yaml


def validate(projects, raise_on_error=False):
def validate(projects: dict[str, Any], raise_on_error: bool = False) -> list[str]:
with (Path(__file__).resolve().parent / "project-cost-schema.json").open() as f:
schema = json.load(f)

Expand Down

0 comments on commit caff3bc

Please sign in to comment.