From 140f5d1e53f17e5b7f25ed669c432e773ee13883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Tue, 21 Jan 2025 19:35:18 +0900 Subject: [PATCH 01/17] feat: add new decorator for pydantic req/res in manager --- .../backend/manager/api/pydantic_handlers.py | 194 ++++++++++++++++++ 1 file changed, 194 insertions(+) create mode 100644 src/ai/backend/manager/api/pydantic_handlers.py diff --git a/src/ai/backend/manager/api/pydantic_handlers.py b/src/ai/backend/manager/api/pydantic_handlers.py new file mode 100644 index 0000000000..65a5dbc5de --- /dev/null +++ b/src/ai/backend/manager/api/pydantic_handlers.py @@ -0,0 +1,194 @@ +import functools +import inspect +import json +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Generic, Optional, Self, Type, TypeVar + +import yaml +from aiohttp import web +from pydantic import BaseModel + +from .exceptions import InvalidAPIParameters + +T = TypeVar("T", bound=BaseModel) + + +class Param(ABC, Generic[T]): + @abstractmethod + def from_request(self, request: web.Request) -> T: + pass + + +class QueryParam(Param[T]): + def __init__(self, model: Type[T]): + self.model = model + + def from_request(self, request: web.Request) -> T: + return self.model.model_validate(request.query) + + +class HeaderParam(Param[T]): + def __init__(self, model: Type[T]): + self.model = model + + def from_request(self, request: web.Request) -> T: + return self.model.model_validate(dict(request.headers)) + + +class PathParam(Param[T]): + def __init__(self, model: Type[T]): + self.model = model + + def from_request(self, request: web.Request) -> T: + return self.model.model_validate(dict(request.match_info)) + + +class MiddlewareParam(Param): + @abstractmethod + def from_request(cls, request: web.Request) -> Self: + pass + + +@dataclass +class Parameter: + name: str + model: Type[BaseModel] + default: Any + + +async def extract_param_value(request: web.Request, param: Parameter) -> Optional[Any]: + match param: + # MiddlewareParam Type + case Parameter(model=model) if isinstance(model, type) and isinstance( + model, MiddlewareParam + ): + return model.from_request(request) + + # HeaderParam, QueryParam, PathParam Type + case Parameter(default=default) if isinstance(default, Param): + return default.from_request(request) + + # Body + case Parameter(model=model) if isinstance(model, type) and not issubclass(model, Param): + if not request.can_read_body: + raise InvalidAPIParameters("Malformed body") + + body = await request.text() + if not body: + raise InvalidAPIParameters("Malformed body") + + if request.content_type == "text/yaml": + data = yaml.load(body, Loader=yaml.BaseLoader) + else: + data = json.loads(body) + + return model.model_validate(data) + + case _: + raise InvalidAPIParameters( + f"Parameter '{param.name}' must be MiddlewareParam, use Param as default value, or be a BaseModel for body" + ) + + +class HandlerParameters: + def __init__(self): + self.params: dict[str, Any] = {} + + def add(self, name: str, value: Any) -> None: + if value is not None: + self.params[name] = value + + def get_all(self) -> dict[str, Any]: + return self.params + + +async def pydantic_handler(request: web.Request, handler) -> web.Response: + signature = inspect.signature(handler) + handler_params = HandlerParameters() + for name, param in signature.parameters.items(): + # Raise error when parameter has no type hint or not wrapped by 'Annotated' + if param.default is inspect.Parameter.empty and isinstance(param.annotation, type(None)): + raise InvalidAPIParameters(f"Type hint or Annotated must be added: {param.name}") + + param_info = Parameter( + name=name, + model=param.annotation, + default=param.default, + ) + + value = await extract_param_value(request, param_info) + handler_params.add(name, value) + + response = await handler(**handler_params.get_all()) + + if not isinstance(response, BaseModel): + raise InvalidAPIParameters(f"Only Pydantic Response can be handle: {type(response)}") + + return web.json_response(response.model_dump(mode="json")) + + +def pydantic_api_handler(handler): + @functools.wraps(handler) + async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: + return await pydantic_handler(request, handler) + + return wrapped + + +""" +This decorator processes HTTP request parameters using Pydantic models. +It supports four types of parameters: + +1. Request Body (automatically parsed as JSON/YAML): + @pydantic_api_handler + async def handler(user: UserModel): # UserModel is a Pydantic model + return Response(user=user) + +2. Query Parameters: + @pydantic_api_handler + async def handler(query: QueryModel = QueryParam(QueryModel)): + return Response(query=query) + +3. Headers: + @pydantic_api_handler + async def handler(headers: HeaderModel = HeaderParam(HeaderModel)): + return Response(headers=headers) + +4. Path Parameters: + @pydantic_api_handler + async def handler(path: PathModel = PathParam(PathModel)): + return Response(path=path) + +5. Middleware Parameters: + # Need to extend MiddlewareParam and implement 'from_request' + class AuthMiddlewareParam(MiddlewareParam): + user_id: str + user_email: str + @classmethod + def from_request(cls, request: web.Request) -> Self: + # Extract and validate data from request + user_id = request["user"]["uuid"] + user_email = request["user"]["email"] + return cls(user_id=user_id) + + @pydantic_api_handler + async def handler(auth: AuthMiddlewareParam): # No default value + return Response(user_id=auth.user_id) + +6. Multiple Parameters: + @pydantic_api_handler + async def handler( + user: UserModel, # body + query: QueryModel = QueryParam(QueryModel), # query parameters + headers: HeaderModel = HeaderParam(HeaderModel), # headers + auth: AuthMiddleware, # middleware parameter + ): + return Response(user=user, query=query, headers=headers, user_id=auth.user_id) + +Note: +- All parameters must have type hints or wrapped by Annotated +- Response must be a Pydantic model +- Request body is parsed from JSON by default, or from YAML if content-type is 'text/yaml' +- MiddlewareParam classes must implement the from_request classmethod +""" From 649cc2c5917a50447876607053cbe361ad819cc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Tue, 21 Jan 2025 10:49:35 +0000 Subject: [PATCH 02/17] chore: update api schema dump Co-authored-by: octodog --- docs/manager/rest-reference/openapi.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/manager/rest-reference/openapi.json b/docs/manager/rest-reference/openapi.json index f6205a8529..a4ffe21404 100644 --- a/docs/manager/rest-reference/openapi.json +++ b/docs/manager/rest-reference/openapi.json @@ -3,7 +3,7 @@ "info": { "title": "Backend.AI Manager API", "description": "Backend.AI Manager REST API specification", - "version": "24.12.1", + "version": "25.1.1", "contact": { "name": "Lablup Inc.", "url": "https://docs.backend.ai", From 7e854992ff9648e4d17a1689afe28e04764afa16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Wed, 22 Jan 2025 11:28:56 +0900 Subject: [PATCH 03/17] refactor: move pydantic handler into common pkg --- src/ai/backend/common/exception.py | 12 +++++++++++ .../api => common}/pydantic_handlers.py | 20 +++++++++++-------- 2 files changed, 24 insertions(+), 8 deletions(-) rename src/ai/backend/{manager/api => common}/pydantic_handlers.py (88%) diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index da40412a49..b32376a718 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -110,3 +110,15 @@ class VolumeUnmountFailed(RuntimeError): """ Represents a umount process failure. """ + + +class MalformedRequestBody(ValueError): + """ + Represents a malformed request body. + """ + + +class InvalidAPIParametersModel(ValueError): + """ + Exception raised for invalid API parameters. + """ diff --git a/src/ai/backend/manager/api/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py similarity index 88% rename from src/ai/backend/manager/api/pydantic_handlers.py rename to src/ai/backend/common/pydantic_handlers.py index 65a5dbc5de..e50c0d1d35 100644 --- a/src/ai/backend/manager/api/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -9,7 +9,7 @@ from aiohttp import web from pydantic import BaseModel -from .exceptions import InvalidAPIParameters +from .exception import InvalidAPIParametersModel, MalformedRequestBody T = TypeVar("T", bound=BaseModel) @@ -33,7 +33,7 @@ def __init__(self, model: Type[T]): self.model = model def from_request(self, request: web.Request) -> T: - return self.model.model_validate(dict(request.headers)) + return self.model.model_validate(request.headers) class PathParam(Param[T]): @@ -41,7 +41,7 @@ def __init__(self, model: Type[T]): self.model = model def from_request(self, request: web.Request) -> T: - return self.model.model_validate(dict(request.match_info)) + return self.model.model_validate(request.match_info) class MiddlewareParam(Param): @@ -72,11 +72,15 @@ async def extract_param_value(request: web.Request, param: Parameter) -> Optiona # Body case Parameter(model=model) if isinstance(model, type) and not issubclass(model, Param): if not request.can_read_body: - raise InvalidAPIParameters("Malformed body") + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) body = await request.text() if not body: - raise InvalidAPIParameters("Malformed body") + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) if request.content_type == "text/yaml": data = yaml.load(body, Loader=yaml.BaseLoader) @@ -86,7 +90,7 @@ async def extract_param_value(request: web.Request, param: Parameter) -> Optiona return model.model_validate(data) case _: - raise InvalidAPIParameters( + raise InvalidAPIParametersModel( f"Parameter '{param.name}' must be MiddlewareParam, use Param as default value, or be a BaseModel for body" ) @@ -109,7 +113,7 @@ async def pydantic_handler(request: web.Request, handler) -> web.Response: for name, param in signature.parameters.items(): # Raise error when parameter has no type hint or not wrapped by 'Annotated' if param.default is inspect.Parameter.empty and isinstance(param.annotation, type(None)): - raise InvalidAPIParameters(f"Type hint or Annotated must be added: {param.name}") + raise InvalidAPIParametersModel(f"Type hint or Annotated must be added: {param.name}") param_info = Parameter( name=name, @@ -123,7 +127,7 @@ async def pydantic_handler(request: web.Request, handler) -> web.Response: response = await handler(**handler_params.get_all()) if not isinstance(response, BaseModel): - raise InvalidAPIParameters(f"Only Pydantic Response can be handle: {type(response)}") + raise InvalidAPIParametersModel(f"Only Pydantic Response can be handle: {type(response)}") return web.json_response(response.model_dump(mode="json")) From bc5b480a986e8b879fbd68231a0f6d460f2009c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 23 Jan 2025 17:40:01 +0900 Subject: [PATCH 04/17] refactor: make decorator usage more simple using generic --- src/ai/backend/common/exception.py | 83 ++++++- src/ai/backend/common/pydantic_handlers.py | 239 +++++++++++++-------- 2 files changed, 229 insertions(+), 93 deletions(-) diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index b32376a718..dca1d2392e 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -1,4 +1,7 @@ -from typing import Any, Mapping +import json +from typing import Any, Mapping, Optional + +from aiohttp import web class ConfigurationError(Exception): @@ -112,13 +115,79 @@ class VolumeUnmountFailed(RuntimeError): """ -class MalformedRequestBody(ValueError): +class BackendError(web.HTTPError): """ - Represents a malformed request body. + An RFC-7807 error class as a drop-in replacement of the original + aiohttp.web.HTTPError subclasses. """ + error_type: str = "https://api.backend.ai/probs/general-error" + error_title: str = "General Backend API Error." -class InvalidAPIParametersModel(ValueError): - """ - Exception raised for invalid API parameters. - """ + content_type: str + extra_msg: Optional[str] + + body_dict: dict[str, Any] + + def __init__(self, extra_msg: str | None = None, extra_data: Optional[Any] = None, **kwargs): + super().__init__(**kwargs) + self.args = (self.status_code, self.reason, self.error_type) + self.empty_body = False + self.content_type = "application/problem+json" + self.extra_msg = extra_msg + self.extra_data = extra_data + body = { + "type": self.error_type, + "title": self.error_title, + } + if extra_msg is not None: + body["msg"] = extra_msg + if extra_data is not None: + body["data"] = extra_data + self.body_dict = body + self.body = json.dumps(body).encode() + + def __str__(self): + lines = [] + if self.extra_msg: + lines.append(f"{self.error_title} ({self.extra_msg})") + else: + lines.append(self.error_title) + if self.extra_data: + lines.append(" -> extra_data: " + repr(self.extra_data)) + return "\n".join(lines) + + def __repr__(self): + lines = [] + if self.extra_msg: + lines.append( + f"<{type(self).__name__}({self.status}): {self.error_title} ({self.extra_msg})>" + ) + else: + lines.append(f"<{type(self).__name__}({self.status}): {self.error_title}>") + if self.extra_data: + lines.append(" -> extra_data: " + repr(self.extra_data)) + return "\n".join(lines) + + def __reduce__(self): + return ( + type(self), + (), # empty the constructor args to make unpickler to use + # only the exact current state in __dict__ + self.__dict__, + ) + + +class MalformedRequestBody(BackendError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/generic-bad-request" + error_title = "Malformed request body." + + +class InvalidAPIParameters(BackendError, web.HTTPBadRequest): + error_type = "https://api.backend.ai/probs/generic-bad-request" + error_title = "Invalid or Missing API Parameters." + + +class MiddlewareParamParsingFailed(BackendError, web.HTTPInternalServerError): + error_type = "https://api.backend.ai/probs/internal-server-error" + error_title = "Middleware parameter parsing failed." diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index e50c0d1d35..2bb0076fd4 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -3,100 +3,153 @@ import json from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Generic, Optional, Self, Type, TypeVar +from typing import Any, Generic, Optional, Self, Type, TypeVar, get_args, get_origin -import yaml from aiohttp import web +from aiohttp.web_urldispatcher import UrlMappingMatchInfo +from multidict import CIMultiDictProxy, MultiMapping from pydantic import BaseModel +from pydantic_core._pydantic_core import ValidationError -from .exception import InvalidAPIParametersModel, MalformedRequestBody +from .exception import InvalidAPIParameters, MalformedRequestBody, MiddlewareParamParsingFailed T = TypeVar("T", bound=BaseModel) -class Param(ABC, Generic[T]): - @abstractmethod - def from_request(self, request: web.Request) -> T: - pass +class BodyParam(Generic[T]): + def __init__(self, model: Type[T]) -> None: + self.model = model + self._parsed: Optional[T] = None + @property + def parsed(self) -> T: + if not self._parsed: + raise ValueError("Data not yet parsed") + return self._parsed -class QueryParam(Param[T]): - def __init__(self, model: Type[T]): + def from_body(self, json_body: str) -> Self: + self._parsed = self.model.model_validate(json_body) + return self + + +class QueryParam(Generic[T]): + def __init__(self, model: Type[T]) -> None: self.model = model + self._parsed: Optional[T] = None + + @property + def parsed(self) -> T: + if not self._parsed: + raise ValueError("Data not yet parsed") + return self._parsed - def from_request(self, request: web.Request) -> T: - return self.model.model_validate(request.query) + def from_query(self, query: MultiMapping[str]) -> Self: + self._parsed = self.model.model_validate(query) + return self -class HeaderParam(Param[T]): - def __init__(self, model: Type[T]): +class HeaderParam(Generic[T]): + def __init__(self, model: Type[T]) -> None: self.model = model + self._parsed: Optional[T] = None - def from_request(self, request: web.Request) -> T: - return self.model.model_validate(request.headers) + @property + def parsed(self) -> T: + if not self._parsed: + raise ValueError("Data not yet parsed") + return self._parsed + def from_header(self, headers: CIMultiDictProxy[str]) -> Self: + self._parsed = self.model.model_validate(headers) + return self -class PathParam(Param[T]): - def __init__(self, model: Type[T]): + +class PathParam(Generic[T]): + def __init__(self, model: Type[T]) -> None: self.model = model + self._parsed: Optional[T] = None + + @property + def parsed(self) -> T: + if not self._parsed: + raise ValueError("Data not yet parsed") + return self._parsed - def from_request(self, request: web.Request) -> T: - return self.model.model_validate(request.match_info) + def from_path(self, match_info: UrlMappingMatchInfo) -> Self: + self._parsed = self.model.model_validate(match_info) + return self -class MiddlewareParam(Param): +class MiddlewareParam(ABC, BaseModel): + @classmethod @abstractmethod def from_request(cls, request: web.Request) -> Self: pass @dataclass -class Parameter: +class _ParsedSignature: name: str - model: Type[BaseModel] - default: Any + param_type: Any + + +@dataclass +class BaseResponse: + data: BaseModel + status_code: int = 200 -async def extract_param_value(request: web.Request, param: Parameter) -> Optional[Any]: - match param: +async def extract_param_value( + request: web.Request, parsed_signature: _ParsedSignature +) -> Optional[Any]: + try: + param_type = parsed_signature.param_type + # MiddlewareParam Type - case Parameter(model=model) if isinstance(model, type) and isinstance( - model, MiddlewareParam - ): - return model.from_request(request) - - # HeaderParam, QueryParam, PathParam Type - case Parameter(default=default) if isinstance(default, Param): - return default.from_request(request) - - # Body - case Parameter(model=model) if isinstance(model, type) and not issubclass(model, Param): - if not request.can_read_body: - raise MalformedRequestBody( - f"Malformed body - URL: {request.url}, Method: {request.method}" - ) - - body = await request.text() - if not body: - raise MalformedRequestBody( - f"Malformed body - URL: {request.url}, Method: {request.method}" - ) - - if request.content_type == "text/yaml": - data = yaml.load(body, Loader=yaml.BaseLoader) - else: - data = json.loads(body) - - return model.model_validate(data) - - case _: - raise InvalidAPIParametersModel( - f"Parameter '{param.name}' must be MiddlewareParam, use Param as default value, or be a BaseModel for body" - ) + if get_origin(param_type) is None and issubclass(param_type, MiddlewareParam): + try: + return param_type.from_request(request) + except ValidationError: + raise MiddlewareParamParsingFailed(f"Failed while parsing {parsed_signature.name}") + + # If origin type name is BodyParam/QueryParam/HeaderParam/PathParam + origin_name = get_origin(param_type).__name__ + pydantic_model = get_args(param_type)[0] + param_instance = param_type(pydantic_model) + + match origin_name: + case "BodyParam": + if not request.can_read_body: + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) + try: + body = await request.json() + except json.decoder.JSONDecodeError: + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) + return param_instance.from_body(body) + + case "QueryParam": + return param_instance.from_query(request.query) + + case "HeaderParam": + return param_instance.from_header(request.headers) + + case "PathParam": + return param_instance.from_path(request.match_info) + + raise InvalidAPIParameters( + f"Parameter '{parsed_signature.name}' must use one of QueryParam, PathParam, HeaderParam, MiddlewareParam, BodyParam" + ) + except ValidationError as e: + raise InvalidAPIParameters(str(e)) -class HandlerParameters: - def __init__(self): + +class _HandlerParameters: + def __init__(self) -> None: self.params: dict[str, Any] = {} def add(self, name: str, value: Any) -> None: @@ -109,27 +162,32 @@ def get_all(self) -> dict[str, Any]: async def pydantic_handler(request: web.Request, handler) -> web.Response: signature = inspect.signature(handler) - handler_params = HandlerParameters() + handler_params = _HandlerParameters() for name, param in signature.parameters.items(): # Raise error when parameter has no type hint or not wrapped by 'Annotated' - if param.default is inspect.Parameter.empty and isinstance(param.annotation, type(None)): - raise InvalidAPIParametersModel(f"Type hint or Annotated must be added: {param.name}") + if param.annotation is inspect.Parameter.empty: + raise InvalidAPIParameters( + f"Type hint or Annotated must be added in API handler signature: {param.name}" + ) - param_info = Parameter( - name=name, - model=param.annotation, - default=param.default, - ) + parsed_signature = _ParsedSignature(name=name, param_type=param.annotation) + value = await extract_param_value(request=request, parsed_signature=parsed_signature) + + if not value: + raise InvalidAPIParameters( + f"Type hint or Annotated must be added in API handler signature: {param.name}" + ) - value = await extract_param_value(request, param_info) handler_params.add(name, value) response = await handler(**handler_params.get_all()) - if not isinstance(response, BaseModel): - raise InvalidAPIParametersModel(f"Only Pydantic Response can be handle: {type(response)}") + if not isinstance(response, BaseResponse): + raise InvalidAPIParameters( + f"Only Response wrapped by BaseResponse Class can be handle: {type(response)}" + ) - return web.json_response(response.model_dump(mode="json")) + return web.json_response(response.data.model_dump(mode="json"), status=response.status_code) def pydantic_api_handler(handler): @@ -146,23 +204,27 @@ async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: 1. Request Body (automatically parsed as JSON/YAML): @pydantic_api_handler - async def handler(user: UserModel): # UserModel is a Pydantic model - return Response(user=user) + async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model + user = body.parsed # 'parsed' property gets pydantic model you defined + return BaseResponse(user=user) 2. Query Parameters: @pydantic_api_handler - async def handler(query: QueryModel = QueryParam(QueryModel)): - return Response(query=query) + async def handler(query: QueryParam[QueryPathModel]): + query_path = query.parsed + return Response(query=query_path) 3. Headers: @pydantic_api_handler - async def handler(headers: HeaderModel = HeaderParam(HeaderModel)): - return Response(headers=headers) + async def handler(headers: HeaderParam[HeaderModel]): + parsed_header = headers.parsed + return Response(headers=parsed_headers) 4. Path Parameters: @pydantic_api_handler async def handler(path: PathModel = PathParam(PathModel)): - return Response(path=path) + parsed_path = path.parsed + return Response(path=parsed_path) 5. Middleware Parameters: # Need to extend MiddlewareParam and implement 'from_request' @@ -177,22 +239,27 @@ def from_request(cls, request: web.Request) -> Self: return cls(user_id=user_id) @pydantic_api_handler - async def handler(auth: AuthMiddlewareParam): # No default value + async def handler(auth: AuthMiddlewareParam): # No generic return Response(user_id=auth.user_id) 6. Multiple Parameters: @pydantic_api_handler async def handler( - user: UserModel, # body - query: QueryModel = QueryParam(QueryModel), # query parameters - headers: HeaderModel = HeaderParam(HeaderModel), # headers + user: BodyParam[UserModel], # body + query: QueryParam[QueryModel], # query parameters + headers: HeaderParam[HeaderModel], # headers auth: AuthMiddleware, # middleware parameter ): - return Response(user=user, query=query, headers=headers, user_id=auth.user_id) + return Response( + user=user.parsed.user_id, + query=query.parsed.page, + headers=headers.parsed.auth, + user_id=auth.user_id + ) Note: - All parameters must have type hints or wrapped by Annotated -- Response must be a Pydantic model -- Request body is parsed from JSON by default, or from YAML if content-type is 'text/yaml' +- Response must be a Pydantic model (Use BaseResponse) +- Request body is parsed must be json format - MiddlewareParam classes must implement the from_request classmethod """ From 329dab52497ccffceb96af70b951d3a631428b57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 23 Jan 2025 17:40:50 +0900 Subject: [PATCH 05/17] test: add test for pydantic api handler decorator --- python.lock | 123 ++++++---- requirements.txt | 1 + tests/common/conftest.py | 20 ++ tests/common/test_pydantic_handlers.py | 304 +++++++++++++++++++++++++ 4 files changed, 409 insertions(+), 39 deletions(-) create mode 100644 tests/common/test_pydantic_handlers.py diff --git a/python.lock b/python.lock index 1901cfc340..b7cb52a22f 100644 --- a/python.lock +++ b/python.lock @@ -72,6 +72,7 @@ // "pydantic~=2.9.2", // "pyhumps~=3.8.0", // "pyroscope-io~=0.8.8", +// "pytest-aiohttp~=1.0.5", // "pytest-dependency>=0.6.0", // "pytest>=8.3.3", // "python-dateutil>=2.9", @@ -458,13 +459,13 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "6975f31fe5e7f2113a41bd387221f31854f285ecbc05527272cd8ba4c50764a3", - "url": "https://files.pythonhosted.org/packages/92/23/04a00b3714803e5a58f893eec230b58956e1e8289d3e223d9e294dac3cda/aioresponses-0.7.7-py2.py3-none-any.whl" + "hash": "b73bd4400d978855e55004b23a3a84cb0f018183bcf066a85ad392800b5b9a94", + "url": "https://files.pythonhosted.org/packages/12/b7/584157e43c98aa89810bc2f7099e7e01c728ecf905a66cf705106009228f/aioresponses-0.7.8-py2.py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "66292f1d5c94a3cb984f3336d806446042adb17347d3089f2d3962dd6e5ba55a", - "url": "https://files.pythonhosted.org/packages/27/eb/a69466280306dc9976687cda06d2c9195ff72533192184627f5e7b1d3f1e/aioresponses-0.7.7.tar.gz" + "hash": "b861cdfe5dc58f3b8afac7b0a6973d5d7b2cb608dd0f6253d16b8ee8eaf6df11", + "url": "https://files.pythonhosted.org/packages/de/03/532bbc645bdebcf3b6af3b25d46655259d66ce69abba7720b71ebfabbade/aioresponses-0.7.8.tar.gz" } ], "project_name": "aioresponses", @@ -473,7 +474,7 @@ "packaging>=22.0" ], "requires_python": null, - "version": "0.7.7" + "version": "0.7.8" }, { "artifacts": [ @@ -1046,66 +1047,61 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "83e560faaec38a956dfb3d62e05e1703ee50432b45b788c09e25107c5058bd71", - "url": "https://files.pythonhosted.org/packages/65/77/8bbca82f70b062181cf0ae53fd43f1ac6556f3078884bfef9da2269c06a3/boto3-1.35.99-py3-none-any.whl" - }, - { - "algorithm": "sha256", - "hash": "e0abd794a7a591d90558e92e29a9f8837d25ece8e3c120e530526fe27eba5fca", - "url": "https://files.pythonhosted.org/packages/f7/99/3e8b48f15580672eda20f33439fc1622bd611f6238b6d05407320e1fb98c/boto3-1.35.99.tar.gz" + "hash": "f9843a5d06f501d66ada06f5a5417f671823af2cf319e36ceefa1bafaaaaa953", + "url": "https://files.pythonhosted.org/packages/79/97/4697aa8050e306d6139815996adeb263ddc83024399a188e8b42587665db/boto3-1.36.3-py3-none-any.whl" } ], "project_name": "boto3", "requires_dists": [ - "botocore<1.36.0,>=1.35.99", + "botocore<1.37.0,>=1.36.3", "botocore[crt]<2.0a0,>=1.21.0; extra == \"crt\"", "jmespath<2.0.0,>=0.7.1", - "s3transfer<0.11.0,>=0.10.0" + "s3transfer<0.12.0,>=0.11.0" ], "requires_python": ">=3.8", - "version": "1.35.99" + "version": "1.36.3" }, { "artifacts": [ { "algorithm": "sha256", - "hash": "b22d27b6b617fc2d7342090d6129000af2efd20174215948c0d7ae2da0fab445", - "url": "https://files.pythonhosted.org/packages/fc/dd/d87e2a145fad9e08d0ec6edcf9d71f838ccc7acdd919acc4c0d4a93515f8/botocore-1.35.99-py3-none-any.whl" + "hash": "536ab828e6f90dbb000e3702ac45fd76642113ae2db1b7b1373ad24104e89255", + "url": "https://files.pythonhosted.org/packages/9f/14/f952fed35b9c04aa66453b5fb5d1262a5a9f5dfdcb396d387c1ff0c6da41/botocore-1.36.3-py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "1eab44e969c39c5f3d9a3104a0836c24715579a455f12b3979a31d7cde51b3c3", - "url": "https://files.pythonhosted.org/packages/7c/9c/1df6deceee17c88f7170bad8325aa91452529d683486273928eecfd946d8/botocore-1.35.99.tar.gz" + "hash": "775b835e979da5c96548ed1a0b798101a145aec3cd46541d62e27dda5a94d7f8", + "url": "https://files.pythonhosted.org/packages/3a/61/69eb06a803c83e0da733b60b2bc65880c18ef2dee19ee10cf8732794a3c1/botocore-1.36.3.tar.gz" } ], "project_name": "botocore", "requires_dists": [ - "awscrt==0.22.0; extra == \"crt\"", + "awscrt==0.23.4; extra == \"crt\"", "jmespath<2.0.0,>=0.7.1", "python-dateutil<3.0.0,>=2.1", "urllib3!=2.2.0,<3,>=1.25.4; python_version >= \"3.10\"", "urllib3<1.27,>=1.25.4; python_version < \"3.10\"" ], "requires_python": ">=3.8", - "version": "1.35.99" + "version": "1.36.3" }, { "artifacts": [ { "algorithm": "sha256", - "hash": "02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292", - "url": "https://files.pythonhosted.org/packages/a4/07/14f8ad37f2d12a5ce41206c21820d8cb6561b728e51fad4530dff0552a67/cachetools-5.5.0-py3-none-any.whl" + "hash": "b76651fdc3b24ead3c648bbdeeb940c1b04d365b38b4af66788f9ec4a81d42bb", + "url": "https://files.pythonhosted.org/packages/ec/4e/de4ff18bcf55857ba18d3a4bd48c8a9fde6bb0980c9d20b263f05387fd88/cachetools-5.5.1-py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a", - "url": "https://files.pythonhosted.org/packages/c3/38/a0f315319737ecf45b4319a8cd1f3a908e29d9277b46942263292115eee7/cachetools-5.5.0.tar.gz" + "hash": "70f238fbba50383ef62e55c6aff6d9673175fe59f7c6782c7a0b9e38f4a9df95", + "url": "https://files.pythonhosted.org/packages/d9/74/57df1ab0ce6bc5f6fa868e08de20df8ac58f9c44330c7671ad922d2bbeae/cachetools-5.5.1.tar.gz" } ], "project_name": "cachetools", "requires_dists": [], "requires_python": ">=3.7", - "version": "5.5.0" + "version": "5.5.1" }, { "artifacts": [ @@ -3031,21 +3027,21 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e", - "url": "https://files.pythonhosted.org/packages/a9/6a/fd08d94654f7e67c52ca30523a178b3f8ccc4237fce4be90d39c938a831a/prompt_toolkit-3.0.48-py3-none-any.whl" + "hash": "9b6427eb19e479d98acff65196a307c555eb567989e6d88ebbb1b509d9779198", + "url": "https://files.pythonhosted.org/packages/e4/ea/d836f008d33151c7a1f62caf3d8dd782e4d15f6a43897f64480c2b8de2ad/prompt_toolkit-3.0.50-py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90", - "url": "https://files.pythonhosted.org/packages/2d/4f/feb5e137aff82f7c7f3248267b97451da3644f6cdc218edfe549fb354127/prompt_toolkit-3.0.48.tar.gz" + "hash": "544748f3860a2623ca5cd6d2795e7a14f3d0e1c3c9728359013f79877fc89bab", + "url": "https://files.pythonhosted.org/packages/a1/e1/bd15cb8ffdcfeeb2bdc215de3c3cffca11408d829e4b8416dcfe71ba8854/prompt_toolkit-3.0.50.tar.gz" } ], "project_name": "prompt-toolkit", "requires_dists": [ "wcwidth" ], - "requires_python": ">=3.7.0", - "version": "3.0.48" + "requires_python": ">=3.8.0", + "version": "3.0.50" }, { "artifacts": [ @@ -3567,6 +3563,54 @@ "requires_python": ">=3.8", "version": "8.3.4" }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "63a5360fd2f34dda4ab8e6baee4c5f5be4cd186a403cabd498fced82ac9c561e", + "url": "https://files.pythonhosted.org/packages/9a/a7/6e50ba2c0a27a34859a952162e63362a13142ce3c646e925b76de440e102/pytest_aiohttp-1.0.5-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "880262bc5951e934463b15e3af8bb298f11f7d4d3ebac970aab425aff10a780a", + "url": "https://files.pythonhosted.org/packages/28/ad/7915ae42ca364a66708755517c5d669a7a4921d70d1070d3b660ea716a3e/pytest-aiohttp-1.0.5.tar.gz" + } + ], + "project_name": "pytest-aiohttp", + "requires_dists": [ + "aiohttp>=3.8.1", + "coverage==6.2; extra == \"testing\"", + "mypy==0.931; extra == \"testing\"", + "pytest-asyncio>=0.17.2", + "pytest>=6.1.0" + ], + "requires_python": ">=3.7", + "version": "1.0.5" + }, + { + "artifacts": [ + { + "algorithm": "sha256", + "hash": "0d0bb693f7b99da304a0634afc0a4b19e49d5e0de2d670f38dc4bfa5727c5075", + "url": "https://files.pythonhosted.org/packages/61/d8/defa05ae50dcd6019a95527200d3b3980043df5aa445d40cb0ef9f7f98ab/pytest_asyncio-0.25.2-py3-none-any.whl" + }, + { + "algorithm": "sha256", + "hash": "3f8ef9a98f45948ea91a0ed3dc4268b5326c0e7bce73892acc654df4262ad45f", + "url": "https://files.pythonhosted.org/packages/72/df/adcc0d60f1053d74717d21d58c0048479e9cab51464ce0d2965b086bd0e2/pytest_asyncio-0.25.2.tar.gz" + } + ], + "project_name": "pytest-asyncio", + "requires_dists": [ + "coverage>=6.2; extra == \"testing\"", + "hypothesis>=5.7.1; extra == \"testing\"", + "pytest<9,>=8.2", + "sphinx-rtd-theme>=1; extra == \"docs\"", + "sphinx>=5.3; extra == \"docs\"" + ], + "requires_python": ">=3.9", + "version": "0.25.2" + }, { "artifacts": [ { @@ -3929,22 +3973,22 @@ "artifacts": [ { "algorithm": "sha256", - "hash": "244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e", - "url": "https://files.pythonhosted.org/packages/66/05/7957af15543b8c9799209506df4660cba7afc4cf94bfb60513827e96bed6/s3transfer-0.10.4-py3-none-any.whl" + "hash": "8fa0aa48177be1f3425176dfe1ab85dcd3d962df603c3dbfc585e6bf857ef0ff", + "url": "https://files.pythonhosted.org/packages/5f/ce/22673f4a85ccc640735b4f8d12178a0f41b5d3c6eda7f33756d10ce56901/s3transfer-0.11.1-py3-none-any.whl" }, { "algorithm": "sha256", - "hash": "29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7", - "url": "https://files.pythonhosted.org/packages/c0/0a/1cdbabf9edd0ea7747efdf6c9ab4e7061b085aa7f9bfc36bb1601563b069/s3transfer-0.10.4.tar.gz" + "hash": "3f25c900a367c8b7f7d8f9c34edc87e300bde424f779dc9f0a8ae4f9df9264f6", + "url": "https://files.pythonhosted.org/packages/1a/aa/fdd958c626b00e3f046d4004363e7f1a2aba4354f78d65ceb3b217fa5eb8/s3transfer-0.11.1.tar.gz" } ], "project_name": "s3transfer", "requires_dists": [ - "botocore<2.0a.0,>=1.33.2", - "botocore[crt]<2.0a.0,>=1.33.2; extra == \"crt\"" + "botocore<2.0a.0,>=1.36.0", + "botocore[crt]<2.0a.0,>=1.36.0; extra == \"crt\"" ], "requires_python": ">=3.8", - "version": "0.10.4" + "version": "0.11.1" }, { "artifacts": [ @@ -5110,6 +5154,7 @@ "pydantic~=2.9.2", "pyhumps~=3.8.0", "pyroscope-io~=0.8.8", + "pytest-aiohttp~=1.0.5", "pytest-dependency>=0.6.0", "pytest>=8.3.3", "python-dateutil>=2.9", diff --git a/requirements.txt b/requirements.txt index 1365f1064d..ef5064bc3c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -83,6 +83,7 @@ zipstream-new~=1.1.8 # required by ai.backend.test (integration test suite) pytest>=8.3.3 +pytest-aiohttp~=1.0.5 pytest-dependency>=0.6.0 # type stubs diff --git a/tests/common/conftest.py b/tests/common/conftest.py index 532c2b44c7..68ffeae674 100644 --- a/tests/common/conftest.py +++ b/tests/common/conftest.py @@ -1,7 +1,9 @@ import asyncio import secrets import time +import uuid from decimal import Decimal +from unittest.mock import MagicMock import pytest @@ -152,3 +154,21 @@ def allow_and_block_list(): @pytest.fixture def allow_and_block_list_has_union(): return {"cuda"}, {"cuda"} + + +@pytest.fixture +def mock_authenticated_request(): + mock_request = MagicMock() + mock_request["user"] = { + "uuid": uuid.uuid4(), + "role": "user", + "email": "test@email.com", + "domain_name": "default", + } + mock_request["keypair"] = { + "access_key": "TESTKEY", + "resource_policy": {"allowed_vfolder_hosts": ["local"]}, + } + vfolder_id = str(uuid.uuid4()) + mock_request.match_info = {"vfolder_id": vfolder_id} + return mock_request diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py new file mode 100644 index 0000000000..c4ec9904cb --- /dev/null +++ b/tests/common/test_pydantic_handlers.py @@ -0,0 +1,304 @@ +from typing import Optional, Self + +import pytest +from aiohttp import web +from pydantic import BaseModel, Field + +from ai.backend.common.pydantic_handlers import ( + BaseResponse, + BodyParam, + HeaderParam, + MiddlewareParam, + PathParam, + QueryParam, + pydantic_api_handler, +) + + +class PostUserModel(BaseModel): + name: str + age: int + + +class PostUserResponse(BaseModel): + name: str + age: int + + +@pytest.mark.asyncio +async def test_body_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(user: BodyParam[PostUserModel]): + parsed_user = user.parsed + return BaseResponse( + status_code=200, data=PostUserResponse(name=parsed_user.name, age=parsed_user.age) + ) + + app = web.Application() + app.router.add_route("POST", "/test", handler) + + client = await aiohttp_client(app) + + test_data = {"name": "John", "age": 30} + resp = await client.post("/test", json=test_data) + + assert resp.status == 200 + data = await resp.json() + assert data["name"] == "John" + assert data["age"] == 30 + + +class SearchQueryModel(BaseModel): + search: str + page: Optional[int] = Field(default=1) + + +class SearchQueryResponse(BaseModel): + search: str + page: Optional[int] = Field(default=1) + + +@pytest.mark.asyncio +async def test_query_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(query: QueryParam[SearchQueryModel]): + parsed_query = query.parsed + return BaseResponse( + data=SearchQueryResponse(search=parsed_query.search, page=parsed_query.page) + ) + + app = web.Application() + app.router.add_get("/test", handler) + + client = await aiohttp_client(app) + resp = await client.get("/test?search=test&page=2") + + assert resp.status == 200 + data = await resp.json() + assert data["search"] == "test" + assert data["page"] == 2 + + +class AuthHeaderModel(BaseModel): + authorization: str + + +class AuthHeaderResponse(BaseModel): + authorization: str + + +@pytest.mark.asyncio +async def test_header_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(headers: HeaderParam[AuthHeaderModel]): + parsed_headers = headers.parsed + return BaseResponse(data=AuthHeaderResponse(authorization=parsed_headers.authorization)) + + app = web.Application() + app.router.add_get("/test", handler) + + client = await aiohttp_client(app) + headers = {"Authorization": "Bearer token123"} + resp = await client.get("/test", headers=headers) + + assert resp.status == 200 + data = await resp.json() + assert data["authorization"] == "Bearer token123" + + +class UserPathModel(BaseModel): + user_id: str + + +class UserPathResponse(BaseModel): + user_id: str + + +@pytest.mark.asyncio +async def test_path_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(path: PathParam[UserPathModel]): + parsed_path = path.parsed + return BaseResponse(data=UserPathResponse(user_id=parsed_path.user_id)) + + app = web.Application() + app.router.add_get("/test/{user_id}", handler) + + client = await aiohttp_client(app) + resp = await client.get("/test/123") + + assert resp.status == 200 + data = await resp.json() + assert data["user_id"] == "123" + + +class AuthInfo(MiddlewareParam): + is_authorized: bool = Field(default=False) + + @classmethod + def from_request(cls, request: web.Request) -> Self: + return cls(is_authorized=request.get("is_authorized", False)) + + +class AuthResponse(BaseModel): + is_authorized: bool = Field(default=False) + + +@pytest.mark.asyncio +async def test_middleware_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(auth: AuthInfo): + return BaseResponse(data=AuthResponse(is_authorized=auth.is_authorized)) + + @web.middleware + async def auth_middleware(request, handler): + request["is_authorized"] = True + return await handler(request) + + app = web.Application() + app.middlewares.append(auth_middleware) + app.router.add_get("/test", handler) + client = await aiohttp_client(app) + + resp = await client.get("/test") + + assert resp.status == 200 + data = await resp.json() + assert data["is_authorized"] + + +@pytest.mark.asyncio +async def test_middleware_parameter_invalid_type(aiohttp_client): + @pydantic_api_handler + async def handler(auth: AuthInfo): + return BaseResponse(data=AuthResponse(is_authorized=auth.is_authorized)) + + @web.middleware + async def broken_auth_middleware(request, handler): + request["is_authorized"] = "not_a_boolean" + return await handler(request) + + app = web.Application() + app.middlewares.append(broken_auth_middleware) + app.router.add_get("/test", handler) + client = await aiohttp_client(app) + + resp = await client.get("/test") + assert resp.status == 500 + + error_data = await resp.json() + assert error_data["type"] == "https://api.backend.ai/probs/internal-server-error" + assert "Middleware parameter parsing failed" in error_data["title"] + + +class MiddlewareTestModel(MiddlewareParam): + is_authorized: bool + + @classmethod + def from_request(cls, request: web.Request) -> Self: + return cls(is_authorized=request.get("is_authorized", False)) + + +class CreateUserModel(BaseModel): + user_name: str + + +class SearchParamModel(BaseModel): + query: str + + +class CombinedResponse(BaseModel): + user_name: str + query: str + is_authorized: bool + + +@pytest.mark.asyncio +async def test_multiple_parameters(aiohttp_client): + @pydantic_api_handler + async def handler( + body: BodyParam[CreateUserModel], + auth: MiddlewareTestModel, + query: QueryParam[SearchParamModel], + ): + parsed_body = body.parsed + parsed_query = query.parsed + + return BaseResponse( + data=CombinedResponse( + user_name=parsed_body.user_name, + query=parsed_query.query, + is_authorized=auth.is_authorized, + ) + ) + + @web.middleware + async def auth_middleware(request, handler): + request["is_authorized"] = True + return await handler(request) + + app = web.Application() + app.middlewares.append(auth_middleware) + app.router.add_post("/test", handler) + + client = await aiohttp_client(app) + test_data = {"user_name": "John"} + resp = await client.post("/test?query=yes", json=test_data) + + assert resp.status == 200 + data = await resp.json() + assert data["user_name"] == "John" + assert data["query"] == "yes" + assert data["is_authorized"] + + +class RegisterUserModel(BaseModel): + name: str + age: int + + +class RegisterUserResponse(BaseModel): + name: str + age: int + + +@pytest.mark.asyncio +async def test_invalid_body(aiohttp_client): + @pydantic_api_handler + async def handler(user: BodyParam[RegisterUserModel]): + test_user = user.parsed + return BaseResponse(data=RegisterUserResponse(name=test_user.name, age=test_user.age)) + + app = web.Application() + app.router.add_post("/test", handler) + client = await aiohttp_client(app) + + test_data = {"name": "John"} # age field missing + error_response = await client.post("/test", json=test_data) + assert error_response.status == 400 + + +class ProductSearchModel(BaseModel): + search: str + page: Optional[int] = Field(default=1) + + +class ProductSearchResponse(BaseModel): + search: str + page: Optional[int] = Field(default=1) + + +@pytest.mark.asyncio +async def test_invalid_query_parameter(aiohttp_client): + @pydantic_api_handler + async def handler(query: QueryParam[ProductSearchModel]): + parsed_query = query.parsed + return BaseResponse( + data=ProductSearchResponse(search=parsed_query.search, page=parsed_query.page) + ) + + app = web.Application() + app.router.add_get("/test", handler) + client = await aiohttp_client(app) + error_response = await client.get("/test") # request with no query parameter + assert error_response.status == 400 # InvalidAPIParameters Error raised From 1683901ebf7260b498c876ca9a1469f1effa3f55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Thu, 23 Jan 2025 19:00:03 +0900 Subject: [PATCH 06/17] doc: add annotation about how to use pydantic_api_handler --- src/ai/backend/common/pydantic_handlers.py | 30 +++++++++++----------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 2bb0076fd4..dc85cfaee2 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -200,31 +200,30 @@ async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: """ This decorator processes HTTP request parameters using Pydantic models. -It supports four types of parameters: -1. Request Body (automatically parsed as JSON/YAML): +1. Request Body: @pydantic_api_handler async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model user = body.parsed # 'parsed' property gets pydantic model you defined - return BaseResponse(user=user) + return BaseResponse(data=YourResponseModel(user=user.id)) 2. Query Parameters: @pydantic_api_handler async def handler(query: QueryParam[QueryPathModel]): - query_path = query.parsed - return Response(query=query_path) + parsed_query = query.parsed + return BaseResponse(data=YourResponseModel(search=parsed_query.query)) 3. Headers: @pydantic_api_handler async def handler(headers: HeaderParam[HeaderModel]): parsed_header = headers.parsed - return Response(headers=parsed_headers) + return BaseResponse(data=YourResponseModel(data=parsed_header.token)) 4. Path Parameters: @pydantic_api_handler async def handler(path: PathModel = PathParam(PathModel)): parsed_path = path.parsed - return Response(path=parsed_path) + return BaseResponse(data=YourResponseModel(path=parsed_path)) 5. Middleware Parameters: # Need to extend MiddlewareParam and implement 'from_request' @@ -239,8 +238,8 @@ def from_request(cls, request: web.Request) -> Self: return cls(user_id=user_id) @pydantic_api_handler - async def handler(auth: AuthMiddlewareParam): # No generic - return Response(user_id=auth.user_id) + async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' + return BaseResponse(data=YourResponseModel(author_name=auth.name)) 6. Multiple Parameters: @pydantic_api_handler @@ -250,16 +249,17 @@ async def handler( headers: HeaderParam[HeaderModel], # headers auth: AuthMiddleware, # middleware parameter ): - return Response( - user=user.parsed.user_id, - query=query.parsed.page, - headers=headers.parsed.auth, - user_id=auth.user_id + return BaseResponse(data=YourResponseModel( + user=user.parsed.user_id, + query=query.parsed.page, + headers=headers.parsed.auth, + user_id=auth.user_id + ) ) Note: - All parameters must have type hints or wrapped by Annotated -- Response must be a Pydantic model (Use BaseResponse) +- Response class must be BaseResponse. put your response model in BaseResponse.data - Request body is parsed must be json format - MiddlewareParam classes must implement the from_request classmethod """ From 8fbd2b6de2f78888855ab11aa1323fe8c59f9a77 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 24 Jan 2025 17:21:02 +0900 Subject: [PATCH 07/17] refactor: change param type matching using type system --- src/ai/backend/common/pydantic_handlers.py | 83 ++++++++++------------ tests/common/test_pydantic_handlers.py | 78 ++++++++++---------- 2 files changed, 75 insertions(+), 86 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index dc85cfaee2..41ca67d2d9 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -87,62 +87,52 @@ def from_request(cls, request: web.Request) -> Self: pass -@dataclass -class _ParsedSignature: - name: str - param_type: Any - - @dataclass class BaseResponse: data: BaseModel status_code: int = 200 -async def extract_param_value( - request: web.Request, parsed_signature: _ParsedSignature -) -> Optional[Any]: +async def _extract_param_value(request: web.Request, input_param_type: Any) -> Optional[Any]: try: - param_type = parsed_signature.param_type - # MiddlewareParam Type - if get_origin(param_type) is None and issubclass(param_type, MiddlewareParam): + if get_origin(input_param_type) is None and issubclass(input_param_type, MiddlewareParam): try: - return param_type.from_request(request) + return input_param_type.from_request(request) except ValidationError: - raise MiddlewareParamParsingFailed(f"Failed while parsing {parsed_signature.name}") + raise MiddlewareParamParsingFailed(f"Failed while parsing {input_param_type}") # If origin type name is BodyParam/QueryParam/HeaderParam/PathParam - origin_name = get_origin(param_type).__name__ - pydantic_model = get_args(param_type)[0] - param_instance = param_type(pydantic_model) - - match origin_name: - case "BodyParam": - if not request.can_read_body: - raise MalformedRequestBody( - f"Malformed body - URL: {request.url}, Method: {request.method}" - ) - try: - body = await request.json() - except json.decoder.JSONDecodeError: - raise MalformedRequestBody( - f"Malformed body - URL: {request.url}, Method: {request.method}" - ) - return param_instance.from_body(body) - - case "QueryParam": - return param_instance.from_query(request.query) - - case "HeaderParam": - return param_instance.from_header(request.headers) - - case "PathParam": - return param_instance.from_path(request.match_info) + origin_type = get_origin(input_param_type) + pydantic_model = get_args(input_param_type)[0] + param_instance = input_param_type(pydantic_model) + + if origin_type is BodyParam: + if not request.can_read_body: + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) + try: + body = await request.json() + except json.decoder.JSONDecodeError: + raise MalformedRequestBody( + f"Malformed body - URL: {request.url}, Method: {request.method}" + ) + return param_instance.from_body(body) - raise InvalidAPIParameters( - f"Parameter '{parsed_signature.name}' must use one of QueryParam, PathParam, HeaderParam, MiddlewareParam, BodyParam" - ) + elif origin_type is QueryParam: + return param_instance.from_query(request.query) + + elif origin_type is HeaderParam: + return param_instance.from_header(request.headers) + + elif origin_type is PathParam: + return param_instance.from_path(request.match_info) + + else: + raise InvalidAPIParameters( + f"Parameter '{input_param_type}' must use one of QueryParam, PathParam, HeaderParam, MiddlewareParam, BodyParam" + ) except ValidationError as e: raise InvalidAPIParameters(str(e)) @@ -160,7 +150,7 @@ def get_all(self) -> dict[str, Any]: return self.params -async def pydantic_handler(request: web.Request, handler) -> web.Response: +async def _pydantic_handler(request: web.Request, handler) -> web.Response: signature = inspect.signature(handler) handler_params = _HandlerParameters() for name, param in signature.parameters.items(): @@ -170,8 +160,7 @@ async def pydantic_handler(request: web.Request, handler) -> web.Response: f"Type hint or Annotated must be added in API handler signature: {param.name}" ) - parsed_signature = _ParsedSignature(name=name, param_type=param.annotation) - value = await extract_param_value(request=request, parsed_signature=parsed_signature) + value = await _extract_param_value(request=request, input_param_type=param.annotation) if not value: raise InvalidAPIParameters( @@ -193,7 +182,7 @@ async def pydantic_handler(request: web.Request, handler) -> web.Response: def pydantic_api_handler(handler): @functools.wraps(handler) async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: - return await pydantic_handler(request, handler) + return await _pydantic_handler(request, handler) return wrapped diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py index c4ec9904cb..24f5db3c82 100644 --- a/tests/common/test_pydantic_handlers.py +++ b/tests/common/test_pydantic_handlers.py @@ -15,12 +15,12 @@ ) -class PostUserModel(BaseModel): +class TestPostUserModel(BaseModel): name: str age: int -class PostUserResponse(BaseModel): +class TestPostUserResponse(BaseModel): name: str age: int @@ -28,10 +28,10 @@ class PostUserResponse(BaseModel): @pytest.mark.asyncio async def test_body_parameter(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[PostUserModel]): + async def handler(user: BodyParam[TestPostUserModel]) -> BaseResponse: parsed_user = user.parsed return BaseResponse( - status_code=200, data=PostUserResponse(name=parsed_user.name, age=parsed_user.age) + status_code=200, data=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age) ) app = web.Application() @@ -48,12 +48,12 @@ async def handler(user: BodyParam[PostUserModel]): assert data["age"] == 30 -class SearchQueryModel(BaseModel): +class TestSearchQueryModel(BaseModel): search: str page: Optional[int] = Field(default=1) -class SearchQueryResponse(BaseModel): +class TestSearchQueryResponse(BaseModel): search: str page: Optional[int] = Field(default=1) @@ -61,10 +61,10 @@ class SearchQueryResponse(BaseModel): @pytest.mark.asyncio async def test_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[SearchQueryModel]): + async def handler(query: QueryParam[TestSearchQueryModel]) -> BaseResponse: parsed_query = query.parsed return BaseResponse( - data=SearchQueryResponse(search=parsed_query.search, page=parsed_query.page) + data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page) ) app = web.Application() @@ -79,20 +79,20 @@ async def handler(query: QueryParam[SearchQueryModel]): assert data["page"] == 2 -class AuthHeaderModel(BaseModel): +class TestAuthHeaderModel(BaseModel): authorization: str -class AuthHeaderResponse(BaseModel): +class TestAuthHeaderResponse(BaseModel): authorization: str @pytest.mark.asyncio async def test_header_parameter(aiohttp_client): @pydantic_api_handler - async def handler(headers: HeaderParam[AuthHeaderModel]): + async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> BaseResponse: parsed_headers = headers.parsed - return BaseResponse(data=AuthHeaderResponse(authorization=parsed_headers.authorization)) + return BaseResponse(data=TestAuthHeaderResponse(authorization=parsed_headers.authorization)) app = web.Application() app.router.add_get("/test", handler) @@ -106,20 +106,20 @@ async def handler(headers: HeaderParam[AuthHeaderModel]): assert data["authorization"] == "Bearer token123" -class UserPathModel(BaseModel): +class TestUserPathModel(BaseModel): user_id: str -class UserPathResponse(BaseModel): +class TestUserPathResponse(BaseModel): user_id: str @pytest.mark.asyncio async def test_path_parameter(aiohttp_client): @pydantic_api_handler - async def handler(path: PathParam[UserPathModel]): + async def handler(path: PathParam[TestUserPathModel]) -> BaseResponse: parsed_path = path.parsed - return BaseResponse(data=UserPathResponse(user_id=parsed_path.user_id)) + return BaseResponse(data=TestUserPathResponse(user_id=parsed_path.user_id)) app = web.Application() app.router.add_get("/test/{user_id}", handler) @@ -132,7 +132,7 @@ async def handler(path: PathParam[UserPathModel]): assert data["user_id"] == "123" -class AuthInfo(MiddlewareParam): +class TestAuthInfo(MiddlewareParam): is_authorized: bool = Field(default=False) @classmethod @@ -140,15 +140,15 @@ def from_request(cls, request: web.Request) -> Self: return cls(is_authorized=request.get("is_authorized", False)) -class AuthResponse(BaseModel): +class TestAuthResponse(BaseModel): is_authorized: bool = Field(default=False) @pytest.mark.asyncio async def test_middleware_parameter(aiohttp_client): @pydantic_api_handler - async def handler(auth: AuthInfo): - return BaseResponse(data=AuthResponse(is_authorized=auth.is_authorized)) + async def handler(auth: TestAuthInfo) -> BaseResponse: + return BaseResponse(data=TestAuthResponse(is_authorized=auth.is_authorized)) @web.middleware async def auth_middleware(request, handler): @@ -170,8 +170,8 @@ async def auth_middleware(request, handler): @pytest.mark.asyncio async def test_middleware_parameter_invalid_type(aiohttp_client): @pydantic_api_handler - async def handler(auth: AuthInfo): - return BaseResponse(data=AuthResponse(is_authorized=auth.is_authorized)) + async def handler(auth: TestAuthInfo) -> BaseResponse: + return BaseResponse(data=TestAuthResponse(is_authorized=auth.is_authorized)) @web.middleware async def broken_auth_middleware(request, handler): @@ -191,7 +191,7 @@ async def broken_auth_middleware(request, handler): assert "Middleware parameter parsing failed" in error_data["title"] -class MiddlewareTestModel(MiddlewareParam): +class TestMiddlewareModel(MiddlewareParam): is_authorized: bool @classmethod @@ -199,15 +199,15 @@ def from_request(cls, request: web.Request) -> Self: return cls(is_authorized=request.get("is_authorized", False)) -class CreateUserModel(BaseModel): +class TestCreateUserModel(BaseModel): user_name: str -class SearchParamModel(BaseModel): +class TestSearchParamModel(BaseModel): query: str -class CombinedResponse(BaseModel): +class TestCombinedResponse(BaseModel): user_name: str query: str is_authorized: bool @@ -217,15 +217,15 @@ class CombinedResponse(BaseModel): async def test_multiple_parameters(aiohttp_client): @pydantic_api_handler async def handler( - body: BodyParam[CreateUserModel], - auth: MiddlewareTestModel, - query: QueryParam[SearchParamModel], - ): + body: BodyParam[TestCreateUserModel], + auth: TestMiddlewareModel, + query: QueryParam[TestSearchParamModel], + ) -> BaseResponse: parsed_body = body.parsed parsed_query = query.parsed return BaseResponse( - data=CombinedResponse( + data=TestCombinedResponse( user_name=parsed_body.user_name, query=parsed_query.query, is_authorized=auth.is_authorized, @@ -252,12 +252,12 @@ async def auth_middleware(request, handler): assert data["is_authorized"] -class RegisterUserModel(BaseModel): +class TestRegisterUserModel(BaseModel): name: str age: int -class RegisterUserResponse(BaseModel): +class TestRegisterUserResponse(BaseModel): name: str age: int @@ -265,9 +265,9 @@ class RegisterUserResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_body(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[RegisterUserModel]): + async def handler(user: BodyParam[TestRegisterUserModel]) -> BaseResponse: test_user = user.parsed - return BaseResponse(data=RegisterUserResponse(name=test_user.name, age=test_user.age)) + return BaseResponse(data=TestRegisterUserResponse(name=test_user.name, age=test_user.age)) app = web.Application() app.router.add_post("/test", handler) @@ -278,12 +278,12 @@ async def handler(user: BodyParam[RegisterUserModel]): assert error_response.status == 400 -class ProductSearchModel(BaseModel): +class TestProductSearchModel(BaseModel): search: str page: Optional[int] = Field(default=1) -class ProductSearchResponse(BaseModel): +class TestProductSearchResponse(BaseModel): search: str page: Optional[int] = Field(default=1) @@ -291,10 +291,10 @@ class ProductSearchResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[ProductSearchModel]): + async def handler(query: QueryParam[TestProductSearchModel]) -> BaseResponse: parsed_query = query.parsed return BaseResponse( - data=ProductSearchResponse(search=parsed_query.search, page=parsed_query.page) + data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page) ) app = web.Application() From 4a0d876514ef2f1456bb9448c30c2ff4ff3338c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 24 Jan 2025 17:50:56 +0900 Subject: [PATCH 08/17] doc: add changelog about new api handler decorator --- changes/3511.feature.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/3511.feature.md diff --git a/changes/3511.feature.md b/changes/3511.feature.md new file mode 100644 index 0000000000..8b8c102fb9 --- /dev/null +++ b/changes/3511.feature.md @@ -0,0 +1 @@ +Add new Pydantic handling api decorator for Request/Response validation \ No newline at end of file From e1740c14b5a0f732093e4a325d7e5df96ffe0c1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 10:53:31 +0900 Subject: [PATCH 09/17] refactor: change param parse value error to custom error --- src/ai/backend/common/exception.py | 5 +++++ src/ai/backend/common/pydantic_handlers.py | 15 ++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/ai/backend/common/exception.py b/src/ai/backend/common/exception.py index dca1d2392e..ab7f0a6cf4 100644 --- a/src/ai/backend/common/exception.py +++ b/src/ai/backend/common/exception.py @@ -191,3 +191,8 @@ class InvalidAPIParameters(BackendError, web.HTTPBadRequest): class MiddlewareParamParsingFailed(BackendError, web.HTTPInternalServerError): error_type = "https://api.backend.ai/probs/internal-server-error" error_title = "Middleware parameter parsing failed." + + +class ParameterNotParsedError(BackendError, web.HTTPInternalServerError): + error_type = "https://api.backend.ai/probs/internal-server-error" + error_title = "Parameter Not Parsed Error" diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 41ca67d2d9..5afb8e4ef9 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -11,7 +11,12 @@ from pydantic import BaseModel from pydantic_core._pydantic_core import ValidationError -from .exception import InvalidAPIParameters, MalformedRequestBody, MiddlewareParamParsingFailed +from .exception import ( + InvalidAPIParameters, + MalformedRequestBody, + MiddlewareParamParsingFailed, + ParameterNotParsedError, +) T = TypeVar("T", bound=BaseModel) @@ -24,7 +29,7 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ValueError("Data not yet parsed") + raise ParameterNotParsedError() return self._parsed def from_body(self, json_body: str) -> Self: @@ -40,7 +45,7 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ValueError("Data not yet parsed") + raise ParameterNotParsedError() return self._parsed def from_query(self, query: MultiMapping[str]) -> Self: @@ -56,7 +61,7 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ValueError("Data not yet parsed") + raise ParameterNotParsedError() return self._parsed def from_header(self, headers: CIMultiDictProxy[str]) -> Self: @@ -72,7 +77,7 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ValueError("Data not yet parsed") + raise ParameterNotParsedError() return self._parsed def from_path(self, match_info: UrlMappingMatchInfo) -> Self: From c32a4db4d228f8c0a3bc3c8eb17c6e5a00181cbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 11:02:52 +0900 Subject: [PATCH 10/17] refactor: remove default value from BaseResponse --- src/ai/backend/common/pydantic_handlers.py | 2 +- tests/common/test_pydantic_handlers.py | 27 +++++++++++++++------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 5afb8e4ef9..c4a1fc2d02 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -95,7 +95,7 @@ def from_request(cls, request: web.Request) -> Self: @dataclass class BaseResponse: data: BaseModel - status_code: int = 200 + status_code: int async def _extract_param_value(request: web.Request, input_param_type: Any) -> Optional[Any]: diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py index 24f5db3c82..2cd66268d5 100644 --- a/tests/common/test_pydantic_handlers.py +++ b/tests/common/test_pydantic_handlers.py @@ -64,7 +64,8 @@ async def test_query_parameter(aiohttp_client): async def handler(query: QueryParam[TestSearchQueryModel]) -> BaseResponse: parsed_query = query.parsed return BaseResponse( - data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page) + status_code=200, + data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page), ) app = web.Application() @@ -92,7 +93,9 @@ async def test_header_parameter(aiohttp_client): @pydantic_api_handler async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> BaseResponse: parsed_headers = headers.parsed - return BaseResponse(data=TestAuthHeaderResponse(authorization=parsed_headers.authorization)) + return BaseResponse( + status_code=200, data=TestAuthHeaderResponse(authorization=parsed_headers.authorization) + ) app = web.Application() app.router.add_get("/test", handler) @@ -119,7 +122,7 @@ async def test_path_parameter(aiohttp_client): @pydantic_api_handler async def handler(path: PathParam[TestUserPathModel]) -> BaseResponse: parsed_path = path.parsed - return BaseResponse(data=TestUserPathResponse(user_id=parsed_path.user_id)) + return BaseResponse(status_code=200, data=TestUserPathResponse(user_id=parsed_path.user_id)) app = web.Application() app.router.add_get("/test/{user_id}", handler) @@ -148,7 +151,9 @@ class TestAuthResponse(BaseModel): async def test_middleware_parameter(aiohttp_client): @pydantic_api_handler async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse(data=TestAuthResponse(is_authorized=auth.is_authorized)) + return BaseResponse( + status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + ) @web.middleware async def auth_middleware(request, handler): @@ -171,7 +176,9 @@ async def auth_middleware(request, handler): async def test_middleware_parameter_invalid_type(aiohttp_client): @pydantic_api_handler async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse(data=TestAuthResponse(is_authorized=auth.is_authorized)) + return BaseResponse( + status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + ) @web.middleware async def broken_auth_middleware(request, handler): @@ -225,11 +232,12 @@ async def handler( parsed_query = query.parsed return BaseResponse( + status_code=200, data=TestCombinedResponse( user_name=parsed_body.user_name, query=parsed_query.query, is_authorized=auth.is_authorized, - ) + ), ) @web.middleware @@ -267,7 +275,9 @@ async def test_invalid_body(aiohttp_client): @pydantic_api_handler async def handler(user: BodyParam[TestRegisterUserModel]) -> BaseResponse: test_user = user.parsed - return BaseResponse(data=TestRegisterUserResponse(name=test_user.name, age=test_user.age)) + return BaseResponse( + status_code=200, data=TestRegisterUserResponse(name=test_user.name, age=test_user.age) + ) app = web.Application() app.router.add_post("/test", handler) @@ -294,7 +304,8 @@ async def test_invalid_query_parameter(aiohttp_client): async def handler(query: QueryParam[TestProductSearchModel]) -> BaseResponse: parsed_query = query.parsed return BaseResponse( - data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page) + status_code=200, + data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page), ) app = web.Application() From ea2d96228087fcadebe42495357b5d78c3efb083 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 12:34:27 +0900 Subject: [PATCH 11/17] refactor: add type hints for param instance variables --- src/ai/backend/common/pydantic_handlers.py | 45 +++++++++++++++------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index c4a1fc2d02..b51990058a 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -22,8 +22,11 @@ class BodyParam(Generic[T]): + _model: Type[T] + _parsed: Optional[T] + def __init__(self, model: Type[T]) -> None: - self.model = model + self._model = model self._parsed: Optional[T] = None @property @@ -33,13 +36,16 @@ def parsed(self) -> T: return self._parsed def from_body(self, json_body: str) -> Self: - self._parsed = self.model.model_validate(json_body) + self._parsed = self._model.model_validate(json_body) return self class QueryParam(Generic[T]): + _model: Type[T] + _parsed: Optional[T] + def __init__(self, model: Type[T]) -> None: - self.model = model + self._model = model self._parsed: Optional[T] = None @property @@ -49,13 +55,16 @@ def parsed(self) -> T: return self._parsed def from_query(self, query: MultiMapping[str]) -> Self: - self._parsed = self.model.model_validate(query) + self._parsed = self._model.model_validate(query) return self class HeaderParam(Generic[T]): + _model: Type[T] + _parsed: Optional[T] + def __init__(self, model: Type[T]) -> None: - self.model = model + self._model = model self._parsed: Optional[T] = None @property @@ -65,13 +74,16 @@ def parsed(self) -> T: return self._parsed def from_header(self, headers: CIMultiDictProxy[str]) -> Self: - self._parsed = self.model.model_validate(headers) + self._parsed = self._model.model_validate(headers) return self class PathParam(Generic[T]): + _model: Type[T] + _parsed: Optional[T] + def __init__(self, model: Type[T]) -> None: - self.model = model + self._model = model self._parsed: Optional[T] = None @property @@ -81,7 +93,7 @@ def parsed(self) -> T: return self._parsed def from_path(self, match_info: UrlMappingMatchInfo) -> Self: - self._parsed = self.model.model_validate(match_info) + self._parsed = self._model.model_validate(match_info) return self @@ -98,7 +110,10 @@ class BaseResponse: status_code: int -async def _extract_param_value(request: web.Request, input_param_type: Any) -> Optional[Any]: +_ParamType = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam + + +async def _extract_param_value(request: web.Request, input_param_type: Any) -> _ParamType: try: # MiddlewareParam Type if get_origin(input_param_type) is None and issubclass(input_param_type, MiddlewareParam): @@ -144,15 +159,17 @@ async def _extract_param_value(request: web.Request, input_param_type: Any) -> O class _HandlerParameters: + _params: dict[str, _ParamType] + def __init__(self) -> None: - self.params: dict[str, Any] = {} + self._params: dict[str, _ParamType] = {} - def add(self, name: str, value: Any) -> None: + def add(self, name: str, value: _ParamType) -> None: if value is not None: - self.params[name] = value + self._params[name] = value - def get_all(self) -> dict[str, Any]: - return self.params + def get_all(self) -> dict[str, _ParamType]: + return self._params async def _pydantic_handler(request: web.Request, handler) -> web.Response: From e5af16e2aecbe8d78f1a743aa0488ee72a33c435 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 13:31:29 +0900 Subject: [PATCH 12/17] style: move decorator usage annotation into decorator func --- src/ai/backend/common/pydantic_handlers.py | 133 ++++++++++----------- 1 file changed, 66 insertions(+), 67 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index b51990058a..f0205f2d02 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -202,75 +202,74 @@ async def _pydantic_handler(request: web.Request, handler) -> web.Response: def pydantic_api_handler(handler): + """ + This decorator processes HTTP request parameters using Pydantic models. + + 1. Request Body: + @pydantic_api_handler + async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model + user = body.parsed # 'parsed' property gets pydantic model you defined + return BaseResponse(data=YourResponseModel(user=user.id)) + + 2. Query Parameters: + @pydantic_api_handler + async def handler(query: QueryParam[QueryPathModel]): + parsed_query = query.parsed + return BaseResponse(data=YourResponseModel(search=parsed_query.query)) + + 3. Headers: + @pydantic_api_handler + async def handler(headers: HeaderParam[HeaderModel]): + parsed_header = headers.parsed + return BaseResponse(data=YourResponseModel(data=parsed_header.token)) + + 4. Path Parameters: + @pydantic_api_handler + async def handler(path: PathModel = PathParam(PathModel)): + parsed_path = path.parsed + return BaseResponse(data=YourResponseModel(path=parsed_path)) + + 5. Middleware Parameters: + # Need to extend MiddlewareParam and implement 'from_request' + class AuthMiddlewareParam(MiddlewareParam): + user_id: str + user_email: str + @classmethod + def from_request(cls, request: web.Request) -> Self: + # Extract and validate data from request + user_id = request["user"]["uuid"] + user_email = request["user"]["email"] + return cls(user_id=user_id) + + @pydantic_api_handler + async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' + return BaseResponse(data=YourResponseModel(author_name=auth.name)) + + 6. Multiple Parameters: + @pydantic_api_handler + async def handler( + user: BodyParam[UserModel], # body + query: QueryParam[QueryModel], # query parameters + headers: HeaderParam[HeaderModel], # headers + auth: AuthMiddleware, # middleware parameter + ): + return BaseResponse(data=YourResponseModel( + user=user.parsed.user_id, + query=query.parsed.page, + headers=headers.parsed.auth, + user_id=auth.user_id + ) + ) + + Note: + - All parameters must have type hints or wrapped by Annotated + - Response class must be BaseResponse. put your response model in BaseResponse.data + - Request body is parsed must be json format + - MiddlewareParam classes must implement the from_request classmethod + """ + @functools.wraps(handler) async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: return await _pydantic_handler(request, handler) return wrapped - - -""" -This decorator processes HTTP request parameters using Pydantic models. - -1. Request Body: - @pydantic_api_handler - async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model - user = body.parsed # 'parsed' property gets pydantic model you defined - return BaseResponse(data=YourResponseModel(user=user.id)) - -2. Query Parameters: - @pydantic_api_handler - async def handler(query: QueryParam[QueryPathModel]): - parsed_query = query.parsed - return BaseResponse(data=YourResponseModel(search=parsed_query.query)) - -3. Headers: - @pydantic_api_handler - async def handler(headers: HeaderParam[HeaderModel]): - parsed_header = headers.parsed - return BaseResponse(data=YourResponseModel(data=parsed_header.token)) - -4. Path Parameters: - @pydantic_api_handler - async def handler(path: PathModel = PathParam(PathModel)): - parsed_path = path.parsed - return BaseResponse(data=YourResponseModel(path=parsed_path)) - -5. Middleware Parameters: - # Need to extend MiddlewareParam and implement 'from_request' - class AuthMiddlewareParam(MiddlewareParam): - user_id: str - user_email: str - @classmethod - def from_request(cls, request: web.Request) -> Self: - # Extract and validate data from request - user_id = request["user"]["uuid"] - user_email = request["user"]["email"] - return cls(user_id=user_id) - - @pydantic_api_handler - async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' - return BaseResponse(data=YourResponseModel(author_name=auth.name)) - -6. Multiple Parameters: - @pydantic_api_handler - async def handler( - user: BodyParam[UserModel], # body - query: QueryParam[QueryModel], # query parameters - headers: HeaderParam[HeaderModel], # headers - auth: AuthMiddleware, # middleware parameter - ): - return BaseResponse(data=YourResponseModel( - user=user.parsed.user_id, - query=query.parsed.page, - headers=headers.parsed.auth, - user_id=auth.user_id - ) - ) - -Note: -- All parameters must have type hints or wrapped by Annotated -- Response class must be BaseResponse. put your response model in BaseResponse.data -- Request body is parsed must be json format -- MiddlewareParam classes must implement the from_request classmethod -""" From 2809dd6ba52efd1949c974897de2613bbaaef4d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 13:35:26 +0900 Subject: [PATCH 13/17] style: fix annotation about status code defining method in BaseResponse --- src/ai/backend/common/pydantic_handlers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index f0205f2d02..bb409f49cf 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -209,25 +209,25 @@ def pydantic_api_handler(handler): @pydantic_api_handler async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model user = body.parsed # 'parsed' property gets pydantic model you defined - return BaseResponse(data=YourResponseModel(user=user.id)) + return BaseResponse(status_code=200, data=YourResponseModel(user=user.id)) 2. Query Parameters: @pydantic_api_handler async def handler(query: QueryParam[QueryPathModel]): parsed_query = query.parsed - return BaseResponse(data=YourResponseModel(search=parsed_query.query)) + return BaseResponse(status_code=200, data=YourResponseModel(search=parsed_query.query)) 3. Headers: @pydantic_api_handler async def handler(headers: HeaderParam[HeaderModel]): parsed_header = headers.parsed - return BaseResponse(data=YourResponseModel(data=parsed_header.token)) + return BaseResponse(status_code=200, data=YourResponseModel(data=parsed_header.token)) 4. Path Parameters: @pydantic_api_handler async def handler(path: PathModel = PathParam(PathModel)): parsed_path = path.parsed - return BaseResponse(data=YourResponseModel(path=parsed_path)) + return BaseResponse(status_code=200, data=YourResponseModel(path=parsed_path)) 5. Middleware Parameters: # Need to extend MiddlewareParam and implement 'from_request' @@ -243,7 +243,7 @@ def from_request(cls, request: web.Request) -> Self: @pydantic_api_handler async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' - return BaseResponse(data=YourResponseModel(author_name=auth.name)) + return BaseResponse(status_code=200, data=YourResponseModel(author_name=auth.name)) 6. Multiple Parameters: @pydantic_api_handler @@ -253,7 +253,9 @@ async def handler( headers: HeaderParam[HeaderModel], # headers auth: AuthMiddleware, # middleware parameter ): - return BaseResponse(data=YourResponseModel( + return BaseResponse( + status_code=200, + data=YourResponseModel( user=user.parsed.user_id, query=query.parsed.page, headers=headers.parsed.auth, From 93cd54aa54725f21e0ac7f8b86fd42dc256320e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 13:50:53 +0900 Subject: [PATCH 14/17] feat: enhance error messages with type information in Param Classes --- src/ai/backend/common/pydantic_handlers.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index bb409f49cf..04369038a5 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -32,7 +32,9 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ParameterNotParsedError() + raise ParameterNotParsedError( + f"Parameter of type {self._model.__name__} has not been parsed yet" + ) return self._parsed def from_body(self, json_body: str) -> Self: @@ -51,7 +53,9 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ParameterNotParsedError() + raise ParameterNotParsedError( + f"Parameter of type {self._model.__name__} has not been parsed yet" + ) return self._parsed def from_query(self, query: MultiMapping[str]) -> Self: @@ -70,7 +74,9 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ParameterNotParsedError() + raise ParameterNotParsedError( + f"Parameter of type {self._model.__name__} has not been parsed yet" + ) return self._parsed def from_header(self, headers: CIMultiDictProxy[str]) -> Self: @@ -89,7 +95,9 @@ def __init__(self, model: Type[T]) -> None: @property def parsed(self) -> T: if not self._parsed: - raise ParameterNotParsedError() + raise ParameterNotParsedError( + f"Parameter of type {self._model.__name__} has not been parsed yet" + ) return self._parsed def from_path(self, match_info: UrlMappingMatchInfo) -> Self: From 42900c10ae16893e5e479ad61ca9b5ecd53189e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 16:41:08 +0900 Subject: [PATCH 15/17] refactor: extend pydantic API decorator to support class-based handlers --- src/ai/backend/common/pydantic_handlers.py | 20 +++- tests/common/test_pydantic_handlers.py | 120 +++++++++++++++++++++ 2 files changed, 136 insertions(+), 4 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 04369038a5..a1b9bac04c 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -180,11 +180,11 @@ def get_all(self) -> dict[str, _ParamType]: return self._params -async def _pydantic_handler(request: web.Request, handler) -> web.Response: - signature = inspect.signature(handler) +async def _pydantic_handler(request: web.Request, handler, signature) -> web.Response: handler_params = _HandlerParameters() for name, param in signature.parameters.items(): - # Raise error when parameter has no type hint or not wrapped by 'Annotated' + # If handler has no parameter, for loop is skipped + # Raise error when parameter exists and has no type hint or not wrapped by 'Annotated' if param.annotation is inspect.Parameter.empty: raise InvalidAPIParameters( f"Type hint or Annotated must be added in API handler signature: {param.name}" @@ -278,8 +278,20 @@ async def handler( - MiddlewareParam classes must implement the from_request classmethod """ + original_signature = inspect.signature(handler) # 원본 시그니처 저장 + @functools.wraps(handler) async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: - return await _pydantic_handler(request, handler) + if isinstance(request, web.Request): + return await _pydantic_handler(request, handler, original_signature) + # 클래스의 인스턴스 메서드인 경우 + self = request + return await _pydantic_handler( + args[0], + lambda *a, **kw: handler(self, *a, **kw), + original_signature.replace( + parameters=list(original_signature.parameters.values())[1:] + ), # self 제외 + ) return wrapped diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py index 2cd66268d5..b2280278fc 100644 --- a/tests/common/test_pydantic_handlers.py +++ b/tests/common/test_pydantic_handlers.py @@ -15,6 +15,126 @@ ) +class TestEmptyResponseModel(BaseModel): + status: str + version: str + + +class TestEmptyHandlerClass: + @pydantic_api_handler + async def handle_empty(self) -> BaseResponse: + return BaseResponse( + status_code=200, data=TestEmptyResponseModel(status="success", version="1.0.0") + ) + + +@pytest.mark.asyncio +async def test_empty_parameter_handler_in_class(aiohttp_client): + handler = TestEmptyHandlerClass() + app = web.Application() + app.router.add_get("/system", handler.handle_empty) + + client = await aiohttp_client(app) + resp = await client.get("/system") + + assert resp.status == 200 + data = await resp.json() + assert data["status"] == "success" + assert data["version"] == "1.0.0" + + +class TestUserRequestModel(BaseModel): + username: str + email: str + age: int + + +class TestSearchParamsModel(BaseModel): + keyword: str + category: Optional[str] = Field(default="all") + limit: Optional[int] = Field(default=10) + + +class TestCombinedResponseModel(BaseModel): + user_info: dict + search_info: dict + timestamp: str + + +class CombinedParamsHandlerClass: + @pydantic_api_handler + async def handle_combined( + self, user: BodyParam[TestUserRequestModel], search: QueryParam[TestSearchParamsModel] + ) -> BaseResponse: + parsed_user = user.parsed + parsed_search = search.parsed + + return BaseResponse( + status_code=200, + data=TestCombinedResponseModel( + user_info={ + "username": parsed_user.username, + "email": parsed_user.email, + "age": parsed_user.age, + }, + search_info={ + "keyword": parsed_search.keyword, + "category": parsed_search.category, + "limit": parsed_search.limit, + }, + timestamp="2024-01-31T00:00:00Z", + ), + ) + + +@pytest.mark.asyncio +async def test_combined_parameters_handler_in_class(aiohttp_client): + handler = CombinedParamsHandlerClass() + app = web.Application() + app.router.add_post("/users/search", handler.handle_combined) + + client = await aiohttp_client(app) + + test_user_data = {"username": "john_doe", "email": "john@example.com", "age": 30} + + resp = await client.post( + "/users/search?keyword=python&category=programming&limit=20", json=test_user_data + ) + + assert resp.status == 200 + data = await resp.json() + + assert data["user_info"]["username"] == "john_doe" + assert data["user_info"]["email"] == "john@example.com" + assert data["user_info"]["age"] == 30 + + assert data["search_info"]["keyword"] == "python" + assert data["search_info"]["category"] == "programming" + assert data["search_info"]["limit"] == 20 + + +class TestMessageResponse(BaseModel): + message: str + + +@pytest.mark.asyncio +async def test_empty_parameter(aiohttp_client): + @pydantic_api_handler + async def handler() -> BaseResponse: + return BaseResponse(status_code=200, data=TestMessageResponse(message="test")) + + app = web.Application() + app.router.add_route("GET", "/test", handler) + + client = await aiohttp_client(app) + + resp = await client.get("/test") + + assert resp.status == 200 + data = await resp.json() + assert data["message"] == "test" + + class TestPostUserModel(BaseModel): name: str age: int From 576ed8d5eb0382e50fec9a1926aa027ad9da16c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 16:55:03 +0900 Subject: [PATCH 16/17] style: replace annotation language into English --- src/ai/backend/common/pydantic_handlers.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index a1b9bac04c..515c6a491c 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -278,20 +278,23 @@ async def handler( - MiddlewareParam classes must implement the from_request classmethod """ - original_signature = inspect.signature(handler) # 원본 시그니처 저장 + original_signature = inspect.signature(handler) @functools.wraps(handler) async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: if isinstance(request, web.Request): return await _pydantic_handler(request, handler, original_signature) - # 클래스의 인스턴스 메서드인 경우 + + # If handler is method defined in class + # Remove 'self' in parameters self = request + sanitized_signature = original_signature.replace( + parameters=list(original_signature.parameters.values())[1:] + ) return await _pydantic_handler( - args[0], - lambda *a, **kw: handler(self, *a, **kw), - original_signature.replace( - parameters=list(original_signature.parameters.values())[1:] - ), # self 제외 + request=args[0], + handler=lambda *a, **kw: handler(self, *a, **kw), + signature=sanitized_signature, ) return wrapped From 9519d2974f9513ce4f33099e18fa21d3335e72f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?BoGyum=20Kim=20=7C=20=EA=B9=80=EB=B3=B4=EA=B2=B8?= Date: Fri, 31 Jan 2025 18:41:44 +0900 Subject: [PATCH 17/17] refactor: improve http response abstraction --- src/ai/backend/common/pydantic_handlers.py | 46 ++++++--- tests/common/test_pydantic_handlers.py | 105 ++++++++++++--------- 2 files changed, 90 insertions(+), 61 deletions(-) diff --git a/src/ai/backend/common/pydantic_handlers.py b/src/ai/backend/common/pydantic_handlers.py index 515c6a491c..2ed279a53b 100644 --- a/src/ai/backend/common/pydantic_handlers.py +++ b/src/ai/backend/common/pydantic_handlers.py @@ -112,10 +112,22 @@ def from_request(cls, request: web.Request) -> Self: pass +class BaseResponseModel(BaseModel): + pass + + @dataclass -class BaseResponse: - data: BaseModel - status_code: int +class ApiResponse: + _status_code: int + _data: Optional[BaseResponseModel] + + @classmethod + def build(cls, status_code: int, response_model: BaseResponseModel) -> Self: + return cls(_status_code=status_code, _data=response_model) + + @classmethod + def no_content(cls, status_code: int): + return cls(_status_code=status_code, _data=None) _ParamType = BodyParam | QueryParam | PathParam | HeaderParam | MiddlewareParam @@ -201,12 +213,15 @@ async def _pydantic_handler(request: web.Request, handler, signature) -> web.Res response = await handler(**handler_params.get_all()) - if not isinstance(response, BaseResponse): + if not isinstance(response, ApiResponse): raise InvalidAPIParameters( - f"Only Response wrapped by BaseResponse Class can be handle: {type(response)}" + f"Only Response wrapped by ApiResponse Class can be handle: {type(response)}" ) - return web.json_response(response.data.model_dump(mode="json"), status=response.status_code) + return web.json_response( + response._data.model_dump(mode="json") if response._data else {}, + status=response._status_code, + ) def pydantic_api_handler(handler): @@ -217,25 +232,26 @@ def pydantic_api_handler(handler): @pydantic_api_handler async def handler(body: BodyParam[UserModel]): # UserModel is a Pydantic model user = body.parsed # 'parsed' property gets pydantic model you defined - return BaseResponse(status_code=200, data=YourResponseModel(user=user.id)) + # Response model should inherit BaseResponseModel + return ApiResponse.build(status_code=200, response_model=YourResponseModel(user=user.id)) 2. Query Parameters: @pydantic_api_handler async def handler(query: QueryParam[QueryPathModel]): parsed_query = query.parsed - return BaseResponse(status_code=200, data=YourResponseModel(search=parsed_query.query)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(search=parsed_query.query)) 3. Headers: @pydantic_api_handler async def handler(headers: HeaderParam[HeaderModel]): parsed_header = headers.parsed - return BaseResponse(status_code=200, data=YourResponseModel(data=parsed_header.token)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(data=parsed_header.token)) 4. Path Parameters: @pydantic_api_handler async def handler(path: PathModel = PathParam(PathModel)): parsed_path = path.parsed - return BaseResponse(status_code=200, data=YourResponseModel(path=parsed_path)) + return ApiResponse.build(status_code=200, response_model=YourResponseModel(path=parsed_path)) 5. Middleware Parameters: # Need to extend MiddlewareParam and implement 'from_request' @@ -251,7 +267,7 @@ def from_request(cls, request: web.Request) -> Self: @pydantic_api_handler async def handler(auth: AuthMiddlewareParam): # No generic, so no need to call 'parsed' - return BaseResponse(status_code=200, data=YourResponseModel(author_name=auth.name)) + return ApiResponse(status_code=200, response_model=YourResponseModel(author_name=auth.name)) 6. Multiple Parameters: @pydantic_api_handler @@ -261,9 +277,9 @@ async def handler( headers: HeaderParam[HeaderModel], # headers auth: AuthMiddleware, # middleware parameter ): - return BaseResponse( + return ApiResponse( status_code=200, - data=YourResponseModel( + response_model=YourResponseModel( user=user.parsed.user_id, query=query.parsed.page, headers=headers.parsed.auth, @@ -273,7 +289,7 @@ async def handler( Note: - All parameters must have type hints or wrapped by Annotated - - Response class must be BaseResponse. put your response model in BaseResponse.data + - Response class must be ApiResponse and your response model should inherit BaseResponseModel - Request body is parsed must be json format - MiddlewareParam classes must implement the from_request classmethod """ @@ -281,7 +297,7 @@ async def handler( original_signature = inspect.signature(handler) @functools.wraps(handler) - async def wrapped(request: web.Request, *args, **kwargs) -> web.Response: + async def wrapped(request: Any, *args, **kwargs) -> web.Response: if isinstance(request, web.Request): return await _pydantic_handler(request, handler, original_signature) diff --git a/tests/common/test_pydantic_handlers.py b/tests/common/test_pydantic_handlers.py index b2280278fc..12c61e2dcb 100644 --- a/tests/common/test_pydantic_handlers.py +++ b/tests/common/test_pydantic_handlers.py @@ -5,7 +5,8 @@ from pydantic import BaseModel, Field from ai.backend.common.pydantic_handlers import ( - BaseResponse, + ApiResponse, + BaseResponseModel, BodyParam, HeaderParam, MiddlewareParam, @@ -15,16 +16,17 @@ ) -class TestEmptyResponseModel(BaseModel): +class TestEmptyResponseModel(BaseResponseModel): status: str version: str class TestEmptyHandlerClass: @pydantic_api_handler - async def handle_empty(self) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestEmptyResponseModel(status="success", version="1.0.0") + async def handle_empty(self) -> ApiResponse: + return ApiResponse.build( + status_code=200, + response_model=TestEmptyResponseModel(status="success", version="1.0.0"), ) @@ -55,7 +57,7 @@ class TestSearchParamsModel(BaseModel): limit: Optional[int] = Field(default=10) -class TestCombinedResponseModel(BaseModel): +class TestCombinedResponseModel(BaseResponseModel): user_info: dict search_info: dict timestamp: str @@ -65,13 +67,13 @@ class CombinedParamsHandlerClass: @pydantic_api_handler async def handle_combined( self, user: BodyParam[TestUserRequestModel], search: QueryParam[TestSearchParamsModel] - ) -> BaseResponse: + ) -> ApiResponse: parsed_user = user.parsed parsed_search = search.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestCombinedResponseModel( + response_model=TestCombinedResponseModel( user_info={ "username": parsed_user.username, "email": parsed_user.email, @@ -113,15 +115,17 @@ async def test_combined_parameters_handler_in_class(aiohttp_client): assert data["search_info"]["limit"] == 20 -class TestMessageResponse(BaseModel): +class TestMessageResponse(BaseResponseModel): message: str @pytest.mark.asyncio async def test_empty_parameter(aiohttp_client): @pydantic_api_handler - async def handler() -> BaseResponse: - return BaseResponse(status_code=200, data=TestMessageResponse(message="test")) + async def handler() -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestMessageResponse(message="test") + ) app = web.Application() app.router.add_route("GET", "/test", handler) @@ -140,7 +144,7 @@ class TestPostUserModel(BaseModel): age: int -class TestPostUserResponse(BaseModel): +class TestPostUserResponse(BaseResponseModel): name: str age: int @@ -148,10 +152,11 @@ class TestPostUserResponse(BaseModel): @pytest.mark.asyncio async def test_body_parameter(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[TestPostUserModel]) -> BaseResponse: + async def handler(user: BodyParam[TestPostUserModel]) -> ApiResponse: parsed_user = user.parsed - return BaseResponse( - status_code=200, data=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age) + return ApiResponse.build( + status_code=200, + response_model=TestPostUserResponse(name=parsed_user.name, age=parsed_user.age), ) app = web.Application() @@ -173,7 +178,7 @@ class TestSearchQueryModel(BaseModel): page: Optional[int] = Field(default=1) -class TestSearchQueryResponse(BaseModel): +class TestSearchQueryResponse(BaseResponseModel): search: str page: Optional[int] = Field(default=1) @@ -181,11 +186,13 @@ class TestSearchQueryResponse(BaseModel): @pytest.mark.asyncio async def test_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[TestSearchQueryModel]) -> BaseResponse: + async def handler(query: QueryParam[TestSearchQueryModel]) -> ApiResponse: parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestSearchQueryResponse(search=parsed_query.search, page=parsed_query.page), + response_model=TestSearchQueryResponse( + search=parsed_query.search, page=parsed_query.page + ), ) app = web.Application() @@ -204,17 +211,18 @@ class TestAuthHeaderModel(BaseModel): authorization: str -class TestAuthHeaderResponse(BaseModel): +class TestAuthHeaderResponse(BaseResponseModel): authorization: str @pytest.mark.asyncio async def test_header_parameter(aiohttp_client): @pydantic_api_handler - async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> BaseResponse: + async def handler(headers: HeaderParam[TestAuthHeaderModel]) -> ApiResponse: parsed_headers = headers.parsed - return BaseResponse( - status_code=200, data=TestAuthHeaderResponse(authorization=parsed_headers.authorization) + return ApiResponse.build( + status_code=200, + response_model=TestAuthHeaderResponse(authorization=parsed_headers.authorization), ) app = web.Application() @@ -233,16 +241,18 @@ class TestUserPathModel(BaseModel): user_id: str -class TestUserPathResponse(BaseModel): +class TestUserPathResponse(BaseResponseModel): user_id: str @pytest.mark.asyncio async def test_path_parameter(aiohttp_client): @pydantic_api_handler - async def handler(path: PathParam[TestUserPathModel]) -> BaseResponse: + async def handler(path: PathParam[TestUserPathModel]) -> ApiResponse: parsed_path = path.parsed - return BaseResponse(status_code=200, data=TestUserPathResponse(user_id=parsed_path.user_id)) + return ApiResponse.build( + status_code=200, response_model=TestUserPathResponse(user_id=parsed_path.user_id) + ) app = web.Application() app.router.add_get("/test/{user_id}", handler) @@ -263,16 +273,16 @@ def from_request(cls, request: web.Request) -> Self: return cls(is_authorized=request.get("is_authorized", False)) -class TestAuthResponse(BaseModel): +class TestAuthResponse(BaseResponseModel): is_authorized: bool = Field(default=False) @pytest.mark.asyncio async def test_middleware_parameter(aiohttp_client): @pydantic_api_handler - async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + async def handler(auth: TestAuthInfo) -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) @web.middleware @@ -295,9 +305,9 @@ async def auth_middleware(request, handler): @pytest.mark.asyncio async def test_middleware_parameter_invalid_type(aiohttp_client): @pydantic_api_handler - async def handler(auth: TestAuthInfo) -> BaseResponse: - return BaseResponse( - status_code=200, data=TestAuthResponse(is_authorized=auth.is_authorized) + async def handler(auth: TestAuthInfo) -> ApiResponse: + return ApiResponse.build( + status_code=200, response_model=TestAuthResponse(is_authorized=auth.is_authorized) ) @web.middleware @@ -334,7 +344,7 @@ class TestSearchParamModel(BaseModel): query: str -class TestCombinedResponse(BaseModel): +class TestCombinedResponse(BaseResponseModel): user_name: str query: str is_authorized: bool @@ -347,13 +357,13 @@ async def handler( body: BodyParam[TestCreateUserModel], auth: TestMiddlewareModel, query: QueryParam[TestSearchParamModel], - ) -> BaseResponse: + ) -> ApiResponse: parsed_body = body.parsed parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestCombinedResponse( + response_model=TestCombinedResponse( user_name=parsed_body.user_name, query=parsed_query.query, is_authorized=auth.is_authorized, @@ -385,7 +395,7 @@ class TestRegisterUserModel(BaseModel): age: int -class TestRegisterUserResponse(BaseModel): +class TestRegisterUserResponse(BaseResponseModel): name: str age: int @@ -393,10 +403,11 @@ class TestRegisterUserResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_body(aiohttp_client): @pydantic_api_handler - async def handler(user: BodyParam[TestRegisterUserModel]) -> BaseResponse: + async def handler(user: BodyParam[TestRegisterUserModel]) -> ApiResponse: test_user = user.parsed - return BaseResponse( - status_code=200, data=TestRegisterUserResponse(name=test_user.name, age=test_user.age) + return ApiResponse.build( + status_code=200, + response_model=TestRegisterUserResponse(name=test_user.name, age=test_user.age), ) app = web.Application() @@ -413,7 +424,7 @@ class TestProductSearchModel(BaseModel): page: Optional[int] = Field(default=1) -class TestProductSearchResponse(BaseModel): +class TestProductSearchResponse(BaseResponseModel): search: str page: Optional[int] = Field(default=1) @@ -421,11 +432,13 @@ class TestProductSearchResponse(BaseModel): @pytest.mark.asyncio async def test_invalid_query_parameter(aiohttp_client): @pydantic_api_handler - async def handler(query: QueryParam[TestProductSearchModel]) -> BaseResponse: + async def handler(query: QueryParam[TestProductSearchModel]) -> ApiResponse: parsed_query = query.parsed - return BaseResponse( + return ApiResponse.build( status_code=200, - data=TestProductSearchResponse(search=parsed_query.search, page=parsed_query.page), + response_model=TestProductSearchResponse( + search=parsed_query.search, page=parsed_query.page + ), ) app = web.Application()