[New feature] Add inference load balance controller for fastdeploy llm #2276

Open
wants to merge 9 commits into base: llm
add fastdeploy ic for llm
rainyfly committed Nov 20, 2023
commit 5b60c202550e8045ee1ce69d2cc160576f6605da
37 changes: 31 additions & 6 deletions llm_ic/fastdeploy_ic/server/api.py
@@ -29,7 +29,12 @@ async def ModelStreamInfer(self, request, context):
"""
try:
model_id = request.model_id
req_id = json.loads(request.input)['req_id']
input_dict = json.loads(request.input)
if 'req_id' not in input_dict:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "ModelStreamInfer: there is no req_id in request")
if 'ic_req_data' not in input_dict:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "ModelStreamInfer: there is no ic_req_data in request")
req_id = input_dict['req_id']
# Check whether req_id is repeated
# Warning: we only check whether the same req_id is already registered;
# we cannot prevent two requests with the same req_id from arriving simultaneously.
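For reference, a request that passes the two checks above must carry both req_id and ic_req_data in its serialized input. A minimal sketch of such a payload follows; the model_id value and the text field are illustrative, not taken from this PR:

import json

request_input = json.dumps({
    "req_id": "req-0001",                        # unique id for this in-flight request
    "ic_req_data": {"model_id": "llm-model-a"},  # routing data read back on the response path
    "text": "Hello, world",                      # illustrative model input
})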
@@ -54,10 +59,18 @@ async def ModelStreamInfer(self, request, context):
if data is None:
await asyncio.sleep(1)
continue
logger.info("ModelStreamInfer: req_id {}: response data: {}".format(req_id, data))
yield data
try:
if json.loads(data.output)['is_end'] == 1: # this request is done
output_dict = json.loads(data.output)
if time.time() - output_dict['ic_timestamp_tag'] > global_config.resonpse_timeout: # the response has timed out and is invalid; it may even belong to a previous request with the same req_id
continue
del output_dict['ic_timestamp_tag']
data.output = json.dumps(output_dict)
logger.info("ModelStreamInfer: req_id {}: response data: {}".format(req_id, data))
yield data
# two cases indicate the request is done:
# 1. the server returned an error instead of a normal result (no is_end field)
# 2. is_end is 1
if ('is_end' not in output_dict) or (output_dict['is_end'] == 1):
# clean up resources for this request; only the req_id entry in the map needs to be removed
await data_manager.remove_req_id_from_map(model_id, req_id)
return
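To make the termination rule above concrete, both of the following outputs end the stream. This is a sketch; the error payload shape is illustrative, and the only assumption taken from this diff is that an error result carries no is_end field:

import json

normal_final = json.dumps({"req_id": "req-0001", "is_end": 1})
error_result = json.dumps({"req_id": "req-0001", "error_msg": "inference failed"})  # no is_end key

for raw in (normal_final, error_result):
    output_dict = json.loads(raw)
    request_done = ("is_end" not in output_dict) or (output_dict["is_end"] == 1)
    print(request_done)  # True in both cases, so the stream ends either way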
@@ -116,12 +129,18 @@ async def ModelSendResponse(self, response_iterator, context):
res = json.loads(response.output)
model_id = res['ic_req_data']['model_id']
req_id = res['req_id']
# add a timestamp to the response
res['ic_timestamp_tag'] = time.time() # added to prevent the client from receiving
# a response meant for a previous request, which can happen when:
# 1. the same req_id is reused by mistake
# 2. the client of the previous request did not receive all of its responses for some reason
response.output = json.dumps(res)
except:
logger.info("ModelSendResponse: req_id {}: Failed to read response data from inference server".format(req_id))
await context.abort(grpc.StatusCode.INTERNAL, "ModelSendResponse: req_id {}: Failed to read response data from inference server".format(req_id))
await data_manager.enque_response(model_id, req_id, response)
logger.info("ModelSendResponse: req_id {}: response data: {}".format(req_id, res))
if res['is_end'] == 1:
if ('is_end' not in res) or (res['is_end'] == 1):
return ic_pb2.ModelSendResponseResult()
if time.time() - response_start_time > global_config.resonpse_timeout:
await data_manager.clear_response(model_id, req_id)
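Taken together with the consumer side in ModelStreamInfer, the tag written here round-trips as sketched below; the helper names and the response_timeout parameter are illustrative, the latter standing in for global_config.resonpse_timeout:

import json
import time

def tag_response(output_json: str) -> str:
    # producer side (ModelSendResponse): stamp the response before it is queued
    res = json.loads(output_json)
    res["ic_timestamp_tag"] = time.time()
    return json.dumps(res)

def accept_response(output_json: str, response_timeout: float) -> bool:
    # consumer side (ModelStreamInfer): keep only responses whose tag is within the window
    output_dict = json.loads(output_json)
    return time.time() - output_dict["ic_timestamp_tag"] <= response_timeout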
@@ -146,12 +165,18 @@ async def ModelSendResponseList(self, response_list_iterator, context):
res = json.loads(response.output)
model_id = res['ic_req_data']['model_id']
req_id = res['req_id']
# add a timestamp to the response
res['ic_timestamp_tag'] = time.time() # added to prevent the client from receiving
# a response meant for a previous request, which can happen when:
# 1. the same req_id is reused by mistake
# 2. the client of the previous request did not receive all of its responses for some reason
response.output = json.dumps(res)
except:
logger.info("ModelSendResponseList: req_id {}: Failed to read response data from inference server".format(req_id))
await context.abort(grpc.StatusCode.INTERNAL, "ModelSendResponseList: req_id {}: Failed to read response data from inference server".format(req_id))
await data_manager.enque_response(model_id, req_id, response)
logger.info("ModelSendResponseList: req_id {}: response data: {}".format(req_id, res))
if res['is_end'] == 1:
if ('is_end' not in res) or (res['is_end'] == 1):
break
if time.time() - response_start_time > global_config.resonpse_timeout:
await data_manager.clear_response(model_id, req_id)