1- from typing import Annotated
1+ from typing import Annotated , List
22
3+ import numpy as np
4+ import torch
35from fastapi import APIRouter , Depends , Request
46from fastapi .security import HTTPBasicCredentials
7+ from torchFastText .datasets import FastTextModelDataset
58
69from api .models .forms import BatchForms
710from api .models .responses import PredictionResponse
811from utils .logging import log_prediction
912from utils .prediction import process_response
10- from utils .preprocessing import preprocess_inputs
13+ from utils .preprocessing import categorical_features , mappings , preprocess_inputs , text_feature
1114from utils .security import get_credentials
1215
# Mapping from NACE 2025 code -> model class index, as produced by preprocessing.
APE_NIV5_MAPPING = mappings["nace2025"]
# Reverse lookup: model class index -> NACE 2025 code, used to decode predictions.
INV_APE_NIV5_MAPPING = {v: k for k, v in APE_NIV5_MAPPING.items()}

# NOTE(review): a second APIRouter(prefix="/single", tags=["Predict an activity"])
# was assigned to `router` just above this one and immediately overwritten — that
# dead assignment is removed; only the batch router was ever registered.
router = APIRouter(prefix="/batch", tags=["Predict a batch of activity"])
1422
1523
16- @router .post ("/predict" , response_model = PredictionResponse )
24+ @router .post ("/predict" , response_model = List [ PredictionResponse ] )
1725async def predict (
1826 credentials : Annotated [HTTPBasicCredentials , Depends (get_credentials )],
1927 request : Request ,
@@ -33,13 +41,32 @@ async def predict(
3341 Returns:
3442 list: The list of predicted responses.
3543 """
36- query = preprocess_inputs (request .app .state .training_names , forms .forms )
44+ query = preprocess_inputs (forms .forms )
45+
46+ text , categorical_variables = (
47+ query [text_feature ].values ,
48+ query [categorical_features ].values ,
49+ )
50+
51+ dataset = FastTextModelDataset (
52+ texts = text ,
53+ categorical_variables = categorical_variables ,
54+ tokenizer = request .app .state .model .model .tokenizer ,
55+ )
56+
57+ batch_size = len (text ) if len (text ) < 256 else 256
58+ dataloader = dataset .create_dataloader (batch_size = batch_size , shuffle = False , num_workers = 12 )
59+
60+ batch = next (iter (dataloader ))
61+ scores = request .app .state .model (batch ).detach ()
62+ probs = torch .nn .functional .softmax (scores , dim = 1 )
63+ sorted_probs , sorted_probs_indices = probs .sort (descending = True , axis = 1 )
3764
38- predictions = request .app .state .model .predict (query , params = {"k" : nb_echos_max })
65+ predicted_class = sorted_probs_indices [:, :nb_echos_max ].numpy ()
66+ predicted_probs = sorted_probs [:, :nb_echos_max ].numpy ()
3967
40- response = [
41- process_response (predictions , i , nb_echos_max , prob_min , request .app .state .libs ) for i in range (len (predictions [0 ]))
42- ]
68+ predicted_class = np .vectorize (INV_APE_NIV5_MAPPING .get )(predicted_class )
69+ predictions = (predicted_class , predicted_probs )
4370
4471 responses = []
4572 for i in range (len (predictions [0 ])):
0 commit comments