Skip to content

Commit 9b868bb

Browse files
committed
add run_id, refactor output response
1 parent 8eafa0c commit 9b868bb

File tree

4 files changed

+89
-13
lines changed

4 files changed

+89
-13
lines changed

src/api/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ async def lifespan(app: FastAPI):
2424
logger.info("🚀 Starting API lifespan")
2525

2626
app.state.model = load_model()
27+
app.state.run_id = app.state.model.metadata.run_id
2728

2829
yield
2930
logger.info("🛑 Shutting down API lifespan")

src/api/models/responses.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Dict, Union
1+
from typing import Any, Dict, Mapping, Union
22

3-
from pydantic import BaseModel, RootModel
3+
from pydantic import BaseModel, RootModel, model_validator
44

55

66
class Prediction(BaseModel):
@@ -9,5 +9,72 @@ class Prediction(BaseModel):
99
libelle: str
1010

1111

12-
class PredictionResponse(RootModel[Dict[str, Union[Prediction, float, str]]]):
13-
pass
12+
class OutputResponse(RootModel[Dict[str, Union[Prediction, float, str]]]):
13+
"""
14+
Contract for the output response of the API including:
15+
- KV of Prediction: normalized prediction entries generated by the model artifact
16+
- MLversion: run_id as version of the ML model
17+
18+
Expected flat structure after normalization:
19+
20+
{
21+
"1": Prediction,
22+
"2": Prediction,
23+
...,
24+
"IC": float, # required confidence score
25+
"MLversion": str # required run_id as model version
26+
}
27+
28+
Notes:
29+
- The output reflects what the model artifact produces, but the API applies
30+
`model_dump()` in `predict()` before returning to ensure schema consistency.
31+
- Any changes to the output schema (e.g., new fields, renaming) must be documented
32+
both here and in the training repo (codif-ape-training) to maintain API contract clarity.
33+
"""
34+
35+
@model_validator(mode="after")
36+
@classmethod
37+
def _normalize(cls, data: Any) -> "OutputResponse":
38+
# unwrap root if called with an instance
39+
raw = data.root if isinstance(data, cls) else data
40+
41+
if not isinstance(raw, Mapping):
42+
raise TypeError("OutputResponse: expected a dict/mapping")
43+
44+
# IC (required) - accept numbers or numeric strings
45+
try:
46+
ic = float(raw["IC"])
47+
except KeyError:
48+
raise ValueError("OutputResponse: missing required key 'IC'")
49+
except (TypeError, ValueError) as e:
50+
raise ValueError(f"OutputResponse: 'IC' not convertible to float: {e}") from e
51+
52+
# MLversion (required)
53+
try:
54+
ml_version = str(raw["MLversion"])
55+
except KeyError:
56+
raise ValueError("OutputResponse: missing required key 'MLversion'")
57+
except Exception as e:
58+
raise ValueError(f"OutputResponse: 'MLversion' not convertible to str: {e}") from e
59+
60+
# allow only digit keys + IC + MLversion
61+
allowed = {k for k in raw.keys() if k.isdigit()} | {"IC", "MLversion"}
62+
extra = set(raw.keys()) - allowed
63+
if extra:
64+
raise ValueError(f"OutputResponse: unexpected keys: {sorted(extra)}")
65+
66+
# ensure digit keys map to Prediction
67+
for k in (k for k in raw.keys() if k.isdigit()):
68+
val = raw[k]
69+
if not isinstance(val, (Mapping, Prediction)):
70+
raise ValueError(
71+
f"OutputResponse: value for key '{k}' must be a mapping or Prediction"
72+
)
73+
74+
# normalize in place
75+
normalized = dict(raw)
76+
normalized["IC"] = ic
77+
normalized["MLversion"] = ml_version
78+
79+
data.root = normalized
80+
return data

src/api/routes/predict.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
from fastapi.security import HTTPBasicCredentials
55

66
from api.models.forms import BatchForms
7-
from api.models.responses import PredictionResponse
7+
from api.models.responses import OutputResponse
88
from utils.security import get_credentials
99

1010
router = APIRouter(prefix="/predict", tags=["Predict NACE code for a list of activities"])
1111

1212

13-
@router.post("/", response_model=List[PredictionResponse])
13+
@router.post("/", response_model=List[OutputResponse])
1414
async def predict(
1515
credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
1616
request: Request,
@@ -27,12 +27,17 @@ async def predict(
2727
credentials (HTTPBasicCredentials): The credentials for authentication.
2828
forms (Forms): The input data in the form of Forms object.
2929
nb_echos_max (int, optional): The maximum number of predictions to return. Defaults to 5.
30-
prob_min (float, optional): The minimum probability threshold for predictions. Defaults to 0.01.
31-
num_workers (int, optional): Number of CPU for multiprocessing in Dataloader. Defaults to 1.
30+
prob_min (float, optional): The minimum probability threshold for predictions.
31+
Defaults to 0.01.
32+
num_workers (int, optional): Number of CPU for multiprocessing in Dataloader.
33+
Defaults to 1.
3234
batch_size (int, optional): Size of a batch for batch prediction.
3335
34-
For single predictions, we recommend keeping num_workers and batch_size to 1 for better performance.
35-
For batched predictions, consider increasing these two parameters (num_workers can range from 4 to 12, batch size can be increased up to 256) to optimize performance.
36+
For single predictions, we recommend keeping num_workers and batch_size to 1
37+
for better performance.
38+
For batched predictions, consider increasing these two parameters
39+
(num_workers can range from 4 to 12, batch size can be increased up to 256)
40+
to optimize performance.
3641
3742
Returns:
3843
list: The list of predicted responses.
@@ -51,4 +56,7 @@ async def predict(
5156
}
5257

5358
output = request.app.state.model.predict(input_data, params=params_dict)
54-
return [out.model_dump() for out in output]
59+
return [
60+
OutputResponse({**out.model_dump(), "MLversion": request.app.state.run_id})
61+
for out in output
62+
]

src/utils/logging.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from api.models.responses import PredictionResponse
3+
from api.models.responses import OutputResponse
44

55

66
def configure_logging():
@@ -13,6 +13,6 @@ def configure_logging():
1313
)
1414

1515

16-
def log_prediction(query: dict, response: PredictionResponse, index: int = 0):
16+
def log_prediction(query: dict, response: OutputResponse, index: int = 0):
1717
query_line = {key: value[index] for key, value in query.items()}
1818
logging.info(f"{{'Query': {query_line}, 'Response': {response.model_dump()}}}")

0 commit comments

Comments
 (0)