|
13 | 13 | import uvicorn |
14 | 14 | from aiohttp import ClientConnectorError |
15 | 15 | from fastapi import FastAPI, Request |
16 | | -from fastapi.responses import JSONResponse, StreamingResponse |
| 16 | +from fastapi.responses import JSONResponse, Response, StreamingResponse |
17 | 17 |
|
18 | 18 | from swift.llm import AdapterRequest, DeployArguments, InferArguments |
19 | 19 | from swift.llm.infer.protocol import EmbeddingRequest, MultiModalRequestMixin |
@@ -42,6 +42,9 @@ def get_infer_engine(args: InferArguments, template=None, **kwargs): |
42 | 42 | return SwiftInfer.get_infer_engine(args, template, **kwargs) |
43 | 43 |
|
44 | 44 | def _register_app(self): |
| 45 | + self.app.get('/health')(self.health) |
| 46 | + self.app.get('/ping')(self.ping) |
| 47 | + self.app.post('/ping')(self.ping) |
45 | 48 | self.app.get('/v1/models')(self.get_available_models) |
46 | 49 | self.app.post('/v1/chat/completions')(self.create_chat_completion) |
47 | 50 | self.app.post('/v1/completions')(self.create_completion) |
@@ -85,6 +88,17 @@ def _get_model_list(self): |
85 | 88 | model_list += [name for name in args.adapter_mapping.keys()] |
86 | 89 | return model_list |
87 | 90 |
|
async def health(self) -> Response:
    """Health check endpoint.

    Returns:
        An empty ``Response`` whose status code is 200 when
        ``self.infer_engine`` is not ``None``, and 503 otherwise.
    """
    # Only the presence of the engine object is checked here; no request
    # is sent to the engine itself.
    engine_present = self.infer_engine is not None
    status = 200 if engine_present else 503
    return Response(status_code=status)
| 97 | + |
async def ping(self) -> Response:
    """Ping check endpoint. Required for SageMaker compatibility.

    Returns:
        Whatever ``self.health()`` returns; this method adds no logic of
        its own.
    """
    health_response = await self.health()
    return health_response
| 101 | + |
88 | 102 | async def get_available_models(self): |
89 | 103 | model_list = self._get_model_list() |
90 | 104 | data = [Model(id=model_id, owned_by=self.args.owned_by) for model_id in model_list] |
|
0 commit comments