Skip to content

Commit 811cc65

Browse files
[chore] Refactored API to match torchFastText
- Added torch / torchFastText dependencies - Used dataloaders for inference - Small fixes in input and response processing (the batch endpoint now returns a list of PredictionResponse) - Loaded "mappings" from S3
1 parent 7c4fa78 commit 811cc65

File tree

9 files changed

+563
-49
lines changed

9 files changed

+563
-49
lines changed

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ dependencies = [
1616
"pydantic>=2.11.1",
1717
"requests>=2.32.3",
1818
"s3fs>=2025.3.2",
19+
"torch>=2.6.0",
20+
"torchfasttext",
1921
"tqdm>=4.67.1",
2022
"unidecode>=1.3.8",
2123
"uvicorn>=0.34.0",
@@ -35,3 +37,6 @@ line-length = 130
3537

3638
[tool.uv]
3739
default-groups = ["dev"]
40+
41+
[tool.uv.sources]
42+
torchfasttext = { git = "https://github.com/InseeFrLab/torch-fastText.git", branch = "dataset-api" }

setup.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#!/bin/bash
22
git config --global credential.helper store
33

4-
pip install -r requirements.txt
5-
pip install pre-commit
6-
pre-commit install
4+
pip install uv
5+
uv sync
6+
uv run pre-commit install
7+
uv run -m nltk.downloader stopwords
78

89
AWS_ACCESS_KEY_ID=`vault kv get -field=ACCESS_KEY onyxia-kv/projet-ape/s3` && export AWS_ACCESS_KEY_ID
910
AWS_SECRET_ACCESS_KEY=`vault kv get -field=SECRET_KEY onyxia-kv/projet-ape/s3` && export AWS_SECRET_ACCESS_KEY

src/api/constants/models.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
VALID_TYPE_FORM = {"A", "B", "C", "D", "E", "G", "I", "L", "M", "N", "P", "R", "S", "X", "Y", "Z"}
2-
VALID_SURFACE = {"1", "2", "3", "4"}
3-
VALID_ACTIV_PERM = {"P", "S"}
1+
VALID_TYPE_FORM = {"A", "B", "C", "D", "E", "G", "I", "L", "M", "N", "P", "R", "S", "X", "Y", "Z", "NaN"}
2+
VALID_SURFACE = {"1", "2", "3", "4", "NaN"}
3+
VALID_ACTIV_PERM = {"P", "S", "NaN"}

src/api/main.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,7 @@ async def lifespan(app: FastAPI):
2626
logger.info("🚀 Starting API lifespan")
2727

2828
model_uri = f"models:/{os.environ['MLFLOW_MODEL_NAME']}/{os.environ['MLFLOW_MODEL_VERSION']}"
29-
app.state.model = mlflow.pyfunc.load_model(model_uri)
30-
run_params = mlflow.get_run(app.state.model.metadata.run_id).data.params
31-
32-
app.state.training_names = [
33-
run_params["text_feature"],
34-
*(v for k, v in run_params.items() if k.startswith("textual_features")),
35-
*(v for k, v in run_params.items() if k.startswith("categorical_features")),
36-
]
29+
app.state.model = mlflow.pytorch.load_model(model_uri)
3730

3831
libs_path = Path("api/data/libs.yaml")
3932
app.state.libs = yaml.safe_load(libs_path.read_text())

src/api/routes/predict_batch.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1-
from typing import Annotated
1+
from typing import Annotated, List
22

3+
import numpy as np
4+
import torch
35
from fastapi import APIRouter, Depends, Request
46
from fastapi.security import HTTPBasicCredentials
7+
from torchFastText.datasets import FastTextModelDataset
58

69
from api.models.forms import BatchForms
710
from api.models.responses import PredictionResponse
811
from utils.logging import log_prediction
912
from utils.prediction import process_response
10-
from utils.preprocessing import preprocess_inputs
13+
from utils.preprocessing import categorical_features, mappings, preprocess_inputs, text_feature
1114
from utils.security import get_credentials
1215

16+
router = APIRouter(prefix="/single", tags=["Predict an activity"])
17+
18+
APE_NIV5_MAPPING = mappings["nace2025"]
19+
INV_APE_NIV5_MAPPING = {v: k for k, v in APE_NIV5_MAPPING.items()}
20+
1321
router = APIRouter(prefix="/batch", tags=["Predict a batch of activity"])
1422

1523

16-
@router.post("/predict", response_model=PredictionResponse)
24+
@router.post("/predict", response_model=List[PredictionResponse])
1725
async def predict(
1826
credentials: Annotated[HTTPBasicCredentials, Depends(get_credentials)],
1927
request: Request,
@@ -33,13 +41,32 @@ async def predict(
3341
Returns:
3442
list: The list of predicted responses.
3543
"""
36-
query = preprocess_inputs(request.app.state.training_names, forms.forms)
44+
query = preprocess_inputs(forms.forms)
45+
46+
text, categorical_variables = (
47+
query[text_feature].values,
48+
query[categorical_features].values,
49+
)
50+
51+
dataset = FastTextModelDataset(
52+
texts=text,
53+
categorical_variables=categorical_variables,
54+
tokenizer=request.app.state.model.model.tokenizer,
55+
)
56+
57+
batch_size = len(text) if len(text) < 256 else 256
58+
dataloader = dataset.create_dataloader(batch_size=batch_size, shuffle=False, num_workers=12)
59+
60+
batch = next(iter(dataloader))
61+
scores = request.app.state.model(batch).detach()
62+
probs = torch.nn.functional.softmax(scores, dim=1)
63+
sorted_probs, sorted_probs_indices = probs.sort(descending=True, axis=1)
3764

38-
predictions = request.app.state.model.predict(query, params={"k": nb_echos_max})
65+
predicted_class = sorted_probs_indices[:, :nb_echos_max].numpy()
66+
predicted_probs = sorted_probs[:, :nb_echos_max].numpy()
3967

40-
response = [
41-
process_response(predictions, i, nb_echos_max, prob_min, request.app.state.libs) for i in range(len(predictions[0]))
42-
]
68+
predicted_class = np.vectorize(INV_APE_NIV5_MAPPING.get)(predicted_class)
69+
predictions = (predicted_class, predicted_probs)
4370

4471
responses = []
4572
for i in range(len(predictions[0])):

src/api/routes/predict_single.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from typing import Annotated
22

3+
import numpy as np
4+
import torch
35
from fastapi import APIRouter, Depends, Request
46
from fastapi.security import HTTPBasicCredentials
7+
from torchFastText.datasets import FastTextModelDataset
58

69
from api.models.forms import SingleForm
710
from api.models.responses import PredictionResponse
811
from utils.logging import log_prediction
912
from utils.prediction import process_response
10-
from utils.preprocessing import preprocess_inputs
13+
from utils.preprocessing import categorical_features, mappings, preprocess_inputs, text_feature
1114
from utils.security import get_credentials
1215

1316
router = APIRouter(prefix="/single", tags=["Predict an activity"])
1417

18+
APE_NIV5_MAPPING = mappings["nace2025"]
19+
INV_APE_NIV5_MAPPING = {v: k for k, v in APE_NIV5_MAPPING.items()}
20+
1521

1622
@router.post("/predict", response_model=PredictionResponse)
1723
async def predict(
@@ -35,9 +41,30 @@ async def predict(
3541
dict: Response containing APE codes.
3642
"""
3743

38-
query = preprocess_inputs(request.app.state.training_names, [form])
44+
query = preprocess_inputs([form])
45+
46+
text, categorical_variables = (
47+
query[text_feature].values,
48+
query[categorical_features].values,
49+
)
50+
51+
dataset = FastTextModelDataset(
52+
texts=text,
53+
categorical_variables=categorical_variables,
54+
tokenizer=request.app.state.model.model.tokenizer,
55+
)
56+
dataloader = dataset.create_dataloader(batch_size=1, shuffle=False, num_workers=1)
57+
58+
batch = next(iter(dataloader))
59+
scores = request.app.state.model(batch).detach()
60+
probs = torch.nn.functional.softmax(scores, dim=1)
61+
sorted_probs, sorted_probs_indices = probs.sort(descending=True, axis=1)
62+
63+
predicted_class = sorted_probs_indices[:, :nb_echos_max].numpy()
64+
predicted_probs = sorted_probs[:, :nb_echos_max].numpy()
3965

40-
predictions = request.app.state.model.predict(query, params={"k": max(2, nb_echos_max)})
66+
predicted_class = np.vectorize(INV_APE_NIV5_MAPPING.get)(predicted_class)
67+
predictions = (predicted_class, predicted_probs)
4168

4269
response = process_response(predictions, 0, nb_echos_max, prob_min, request.app.state.libs)
4370

src/utils/prediction.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ def process_response(
1818
pred_labels = labels[liasse_nb]
1919
pred_probs = probs[liasse_nb]
2020

21-
valid_predictions = [
22-
(label.replace("__label__", ""), prob) for label, prob in zip(pred_labels, pred_probs) if prob >= prob_min
23-
][:nb_echos_max]
21+
valid_preds = []
22+
mask = pred_probs >= prob_min
23+
valid_predicted_class = pred_labels[mask]
24+
valid_predicted_confidence = pred_probs[mask]
25+
valid_preds.append(tuple(zip(valid_predicted_class, valid_predicted_confidence)))
2426

25-
if not valid_predictions:
27+
if not valid_preds:
2628
raise HTTPException(
2729
status_code=400,
2830
detail="No prediction exceeds the minimum probability threshold.",
@@ -34,10 +36,10 @@ def process_response(
3436
probabilite=float(prob),
3537
libelle=libs[label],
3638
)
37-
for i, (label, prob) in enumerate(valid_predictions)
39+
for i, (label, prob) in enumerate(valid_preds[0])
3840
}
3941

40-
ic = response_data["1"].probabilite - float(pred_probs[1])
41-
response_data["IC"] = ic
42+
confidence_score = pred_probs[0] - pred_probs[1]
43+
response_data["IC"] = confidence_score
4244

4345
return PredictionResponse(response_data)

0 commit comments

Comments
 (0)