Skip to content

Commit 349ce69

Browse files
committed
stash
1 parent 4275459 commit 349ce69

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

pkg/api/http/service/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def create_llm_model(self, model_data: dict) -> None:
3535
**model_data
3636
)
3737
)
38-
await self.ap.model_mgr.load_model(model_data)
38+
await self.ap.model_mgr.load_llm_model(model_data)
3939

4040
async def get_llm_model(self, model_uuid: str) -> dict | None:
4141
result = await self.ap.persistence_mgr.execute_async(
@@ -54,12 +54,12 @@ async def update_llm_model(self, model_uuid: str, model_data: dict) -> None:
5454
sqlalchemy.update(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid).values(**model_data)
5555
)
5656

57-
await self.ap.model_mgr.remove_model(model_uuid)
58-
await self.ap.model_mgr.load_model(model_data)
57+
await self.ap.model_mgr.remove_llm_model(model_uuid)
58+
await self.ap.model_mgr.load_llm_model(model_data)
5959

6060
async def delete_llm_model(self, model_uuid: str) -> None:
6161
await self.ap.persistence_mgr.execute_async(
6262
sqlalchemy.delete(persistence_model.LLMModel).where(persistence_model.LLMModel.uuid == model_uuid)
6363
)
6464

65-
await self.ap.model_mgr.remove_model(model_uuid)
65+
await self.ap.model_mgr.remove_llm_model(model_uuid)

pkg/provider/modelmgr/modelmgr.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

3-
import aiohttp
3+
import typing
44
import sqlalchemy
55

66
from . import entities, requester
77
from ...core import app
8+
from ...core import entities as core_entities
9+
from .. import entities as llm_entities
10+
from ..tools import entities as tools_entities
811
from ...discover import engine
912
from . import token
1013
from ...entity.persistence import model as persistence_model
@@ -58,14 +61,6 @@ def __init__(self, ap: app.Application):
5861
self.llm_models = []
5962
self.requester_components = []
6063
self.requester_dict = {}
61-
62-
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo:
63-
"""通过名称获取模型
64-
"""
65-
for model in self.model_list:
66-
if model.name == name:
67-
return model
68-
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
6964

7065
async def initialize(self):
7166
self.requester_components = self.ap.discover.get_components_by_kind('LLMAPIRequester')
@@ -92,9 +87,9 @@ async def load_model_from_db(self):
9287

9388
# load models
9489
for llm_model in llm_models:
95-
await self.load_model(llm_model)
90+
await self.load_llm_model(llm_model)
9691

97-
async def load_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
92+
async def load_llm_model(self, model_info: persistence_model.LLMModel | sqlalchemy.Row[persistence_model.LLMModel] | dict):
9893
"""加载模型"""
9994

10095
if isinstance(model_info, sqlalchemy.Row):
@@ -113,10 +108,24 @@ async def load_model(self, model_info: persistence_model.LLMModel | sqlalchemy.R
113108
config=model_info.requester_config
114109
)
115110
)
116-
print(runtime_llm_model, runtime_llm_model.model_entity.name, "loaded")
117111
self.llm_models.append(runtime_llm_model)
118112

119-
async def remove_model(self, model_uuid: str):
113+
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
114+
"""通过名称获取模型
115+
"""
116+
for model in self.model_list:
117+
if model.name == name:
118+
return model
119+
raise ValueError(f"无法确定模型 {name} 的信息,请在元数据中配置")
120+
121+
async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
122+
"""通过uuid获取模型"""
123+
for model in self.llm_models:
124+
if model.model_entity.uuid == uuid:
125+
return model
126+
raise ValueError(f"model {uuid} not found")
127+
128+
async def remove_llm_model(self, model_uuid: str):
120129
"""移除模型"""
121130
for model in self.llm_models:
122131
if model.model_entity.uuid == model_uuid:
@@ -136,3 +145,13 @@ def get_available_requester_info_by_name(self, name: str) -> dict | None:
136145
if component.metadata.name == name:
137146
return component.to_plain_dict()
138147
return None
148+
149+
async def invoke_llm(
150+
self,
151+
query: core_entities.Query,
152+
model_uuid: str,
153+
messages: list[llm_entities.Message],
154+
funcs: list[tools_entities.LLMFunction] = None,
155+
) -> llm_entities.Message:
156+
pass
157+

0 commit comments

Comments
 (0)