[New feature] Add inference load balance controller for fastdeploy llm #2276

Open · wants to merge 9 commits into base: llm
27 changes: 27 additions & 0 deletions llm_ic/README.md
@@ -0,0 +1,27 @@
# Load Balancing Component for Large Language Model Serving

## Requirements

- python >= 3.7
- A running Redis service, used as the database backing the load balancer

## Environment Variables
For the currently supported environment variables, see `config.py` in `fastdeploy_ic`.

| Environment Variable | Description |
| -------- | ------- |
| REDIS_HOST | IP address of the Redis service |
| REDIS_PORT | Port of the Redis service |
| REDIS_USERNAME | Redis username |
| REDIS_PASSWORD | Redis password |
| RESPONSE_TIMEOUT | Timeout (in seconds) for fetching streamed tokens from the inference service |


## Launch Example

```shell
export REDIS_HOST="localhost"
export REDIS_PORT="6379"
python main.py
```
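
Once the controller is up, an upstream dispatcher talks to it over the gRPC service defined in `fastdeploy_ic/proto/ic.proto`. Below is a minimal, illustrative `ModelStreamInfer` client sketch; it assumes the generated stubs are importable and that the controller listens on `localhost:9000` (the default `--grpc-port`), and the model id and prompt fields are placeholders.

```python
import json
import uuid

import grpc

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc

channel = grpc.insecure_channel("localhost:9000")  # default --grpc-port; adjust as needed
stub = ic_pb2_grpc.GRPCInferenceServiceStub(channel)

req_id = str(uuid.uuid4())
request = ic_pb2.ModelInferRequest(
    model_id="my-model",          # placeholder model id
    request_id=req_id,
    trace_id="trace-001",         # placeholder trace id
    # The controller requires 'req_id' and 'ic_req_data' inside the JSON input.
    input=json.dumps({
        "req_id": req_id,
        "ic_req_data": {"model_id": "my-model"},
        "text": "Hello",          # placeholder prompt field
        "text_words_num": 1,      # consulted by the ByToken fetch strategy
    }),
)

# Responses stream back until the inference side marks is_end == 1.
for response in stub.ModelStreamInfer(request):
    print(response.sentence_id, response.output)
```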

Empty file.
30 changes: 30 additions & 0 deletions llm_ic/fastdeploy_ic/config.py
@@ -0,0 +1,30 @@
import os
import multiprocessing
import json

class GlobalConfig():
""" global config """

def __init__(self):
"""init
Args:
None
Returns:
None
"""
# Redis
self.redis_host = os.getenv('REDIS_HOST', default="localhost")
self.redis_port = int(os.getenv('REDIS_PORT', default="6379"))
self.redis_db = int(os.getenv('REDIS_DB', default="0"))
self.redis_username = os.getenv('REDIS_USERNAME', default=None)
self.redis_password = os.getenv('REDIS_PASSWORD', default=None)

# Response
self.response_timeout = int(os.getenv('RESPONSE_TIMEOUT', default="120"))

# Server
self.num_process = int(os.getenv('NUM_PROCESS', default=multiprocessing.cpu_count()))

# Logger
self.log_dir = os.getenv('IC_LOG_DIR', default='ic_logs')
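
As a quick illustration of how this configuration picks up environment variables, here is a minimal, hypothetical usage sketch (the override value is illustrative):

```python
import os

os.environ["REDIS_HOST"] = "10.0.0.5"   # illustrative override; set before constructing

from fastdeploy_ic.config import GlobalConfig

cfg = GlobalConfig()
print(cfg.redis_host, cfg.redis_port)   # -> 10.0.0.5 6379
```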

Empty file.
139 changes: 139 additions & 0 deletions llm_ic/fastdeploy_ic/data/manager.py
@@ -0,0 +1,139 @@

import json
import math
import asyncio

import aioredis

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
from fastdeploy_ic.utils import get_logger

logger = get_logger("data_manager", "ic_data_manager.log")

__retry_times = 5 # the Redis client may raise transient errors; retry a bounded number of times for specific error types
def retry_wrapper(f):
async def wrapper(*args, **kwargs):
for i in range(__retry_times):
try:
return await f(*args, **kwargs)
except asyncio.CancelledError:
logger.info("{} occured asyncio.CancelledError, retry times: {}".format(f.__name__, i+1))
continue
except aioredis.ConnectionError:
args[0].renew_client()
logger.info("{} occured aioredis.ConnectionError, retry times: {}".format(f.__name__, i+1))
continue
except aioredis.TimeoutError:
args[0].renew_client()
logger.info("{} occured aioredis.TimeoutError, retry times: {}".format(f.__name__, i+1))
continue
return wrapper



class DataManager:
def __init__(self, redis_conf) -> None:
self.redis_conf = redis_conf
self.client = aioredis.Redis(**redis_conf)
self.internal_check_key_prefix = '__keymap_'

def renew_client(self):
self.client = aioredis.Redis(**self.redis_conf)

@retry_wrapper
async def check_req_id_exist(self, model_id, req_id):
key = '{}{}'.format(self.internal_check_key_prefix, model_id)
logger.info("check_req_id_exist: key: {} value: {}".format(key, req_id))
is_exist = await self.client.sismember(key, req_id)
return is_exist

@retry_wrapper
async def add_req_id_to_map(self, model_id, req_id):
key = '{}{}'.format(self.internal_check_key_prefix, model_id)
logger.info("add_req_id_to_map: key: {} value: {}".format(key, req_id))
await self.client.sadd(key, req_id)

@retry_wrapper
async def remove_req_id_from_map(self, model_id, req_id):
key = '{}{}'.format(self.internal_check_key_prefix, model_id)
logger.info("remove_req_id_from_map: key: {} value: {}".format(key, req_id))
await self.client.srem(key, req_id)

@retry_wrapper
async def enque_request(self, model_id, req, to_end=True):
serialized_req = req.SerializeToString()
# key = model_id
logger.info("enque_request: key: {} value: {}".format(model_id, req))
if to_end:
await self.client.rpush(model_id, serialized_req)
else:
await self.client.lpush(model_id, serialized_req)

@retry_wrapper
async def deque_request(self, model_id):
data = await self.client.lpop(model_id)
if data is not None:
data = ic_pb2.ModelInferRequest.FromString(data)
logger.info("deque_request: key: {} value: {}".format(model_id, data))
return data

@retry_wrapper
async def remove_request(self, model_id, req):
serialized_req = req.SerializeToString()
logger.info("remove_request: key: {} value: {}".format(model_id, req))
await self.client.lrem(model_id, 1, serialized_req)

@retry_wrapper
async def enque_response(self, model_id, req_id, res, to_end=True):
serialized_res = res.SerializeToString()
key = '{}/{}'.format(model_id, req_id)
logger.info("enque_response: key: {} value: {}".format(key, res))
if to_end:
await self.client.rpush(key, serialized_res)
else:
await self.client.lpush(key, serialized_res)

@retry_wrapper
async def deque_response(self, model_id, req_id):
key = '{}/{}'.format(model_id, req_id)
data = await self.client.lpop(key)
if data is not None:
data = ic_pb2.ModelInferResponse.FromString(data)
logger.info("deque_response: key: {} value: {}".format(key, data))
return data

@retry_wrapper
async def clear_response(self, model_id, req_id):
key = '{}/{}'.format(model_id, req_id)
logger.info("clear_response: key: {}".format(key))
await self.client.delete(key)

async def get_requests_by_number(self, model_id, max_request_num):
# return requests by ByRequest strategy
requests = []
for i in range(max_request_num):
request = await self.deque_request(model_id)
if request is not None:
requests.append(request)
else:
break
logger.info("get_requests_by_number: model_id: {} length: {}".format(model_id, len(requests)))
return requests

async def get_requests_by_block(self, model_id, max_request_num, block_num, block_size, dec_token_num):
# return requests by ByToken strategy
requests = []
left_block_num = block_num
for i in range(max_request_num):
request = await self.deque_request(model_id)
if request is not None:
text_words_num = json.loads(request.input)['text_words_num']
need_block_num = math.ceil((text_words_num + dec_token_num)/block_size)
if need_block_num < left_block_num:
requests.append(request)
left_block_num -= need_block_num
else:
await self.enque_request(model_id, request, to_end=False)
break
logger.info("get_requests_by_block: model_id: {} length: {}".format(model_id, len(requests)))
return requests
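
For reference, a small worked example of the block budgeting performed by `get_requests_by_block` above (all numbers are illustrative):

```python
import math

block_size = 128       # tokens per block
dec_token_num = 64     # tokens reserved for each query's output
text_words_num = 200   # prompt length of a queued request
left_block_num = 8     # remaining block budget for this fetch

need_block_num = math.ceil((text_words_num + dec_token_num) / block_size)
print(need_block_num)                   # 3 blocks needed for this request
print(need_block_num < left_block_num)  # True -> admitted; remaining budget drops to 5
```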
3 changes: 3 additions & 0 deletions llm_ic/fastdeploy_ic/proto/__init__.py
@@ -0,0 +1,3 @@
import sys
import os
sys.path.append(os.path.dirname(__file__))
99 changes: 99 additions & 0 deletions llm_ic/fastdeploy_ic/proto/ic.proto
@@ -0,0 +1,99 @@
syntax = "proto3";
package language_inference;

// Inference Server GRPC endpoints.
service GRPCInferenceService
{
// Entry point for model inference requests; called by the upstream dispatcher.
// Takes one request and streams back multiple responses.
rpc ModelStreamInfer(ModelInferRequest) returns (stream ModelInferResponse) {}

// Pull a request; called by the inference server.
rpc ModelFetchRequest(ModelFetchRequestParams) returns (ModelFetchRequestResult) {}

// Send back results for a request; called by the inference server.
// Responses are sent as a stream.
rpc ModelSendResponse(stream ModelInferResponse) returns (ModelSendResponseResult) {}

// Send back results for a batch of requests; called by the inference server.
// Responses are sent as a stream.
rpc ModelSendResponseList(stream ModelInferResponseList) returns (ModelSendResponseResult) {}
}

message ModelFetchRequestParams
{
// Globally unique model id
repeated string model_id = 1;

// Maximum number of requests returned per call
int32 max_request_num = 2;

FetchStrategy strategy = 3;

ByTokenParams by_token_params = 4;
}
// Formula used by the ByToken fetch strategy:
// blocks needed per query: block_num = ceil((text_words_num + dec_token_num) / block_size)

enum FetchStrategy {
// Fetch requests by request count
ByRequest = 0; // default

// Fetch requests by token count
ByToken = 1;
}

message ByTokenParams
{
// Number of available blocks
int32 block_num = 1;

// Number of tokens each block can hold
int32 block_size = 2;

// Number of tokens reserved for each query's output
int32 dec_token_num = 3;
}

message ModelFetchRequestResult
{
// The fetched requests
repeated ModelInferRequest requests = 1;
}

// The return value of SendResponse does not need to be inspected
message ModelSendResponseResult {
}

message ModelInferRequest
{
// Unique model id
string model_id = 1;

// Request id; must be globally unique
string request_id = 2;

// Trace id linking upstream and downstream logs, used for troubleshooting
string trace_id = 3;

// Input to the language model
string input = 4;
}

message ModelInferResponseList{
repeated ModelInferResponse response_list = 1;
}

message ModelInferResponse
{
// Unique request id
string request_id = 1;

// Sentence id of the response, indicating its position in the stream; used for deduplication and ordering
int32 sentence_id = 2;

// Output of the language model
string output = 3;
}
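
To make the fetch side concrete, here is a rough, illustrative sketch of how an inference server might poll the controller with the `ByToken` strategy, assuming the Python stubs generated from this proto; the endpoint, model id, and parameter values are placeholders:

```python
import json

import grpc

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc

channel = grpc.insecure_channel("localhost:9000")  # placeholder controller address
stub = ic_pb2_grpc.GRPCInferenceServiceStub(channel)

params = ic_pb2.ModelFetchRequestParams(
    model_id=["my-model"],                         # placeholder model id(s)
    max_request_num=4,
    strategy=ic_pb2.FetchStrategy.ByToken,
    by_token_params=ic_pb2.ByTokenParams(
        block_num=8, block_size=128, dec_token_num=64,
    ),
)

result = stub.ModelFetchRequest(params)
for req in result.requests:
    words = json.loads(req.input).get("text_words_num")
    print(req.request_id, words)
```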


41 changes: 41 additions & 0 deletions llm_ic/fastdeploy_ic/proto/ic_pb2.py

Some generated files are not rendered by default.

175 changes: 175 additions & 0 deletions llm_ic/fastdeploy_ic/proto/ic_pb2_grpc.py
@@ -0,0 +1,175 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""
import grpc

import ic_pb2 as ic__pb2


class GRPCInferenceServiceStub(object):
"""Inference Server GRPC endpoints.
"""

def __init__(self, channel):
"""Constructor.
Args:
channel: A grpc.Channel.
"""
self.ModelStreamInfer = channel.unary_stream(
'/language_inference.GRPCInferenceService/ModelStreamInfer',
request_serializer=ic__pb2.ModelInferRequest.SerializeToString,
response_deserializer=ic__pb2.ModelInferResponse.FromString,
)
self.ModelFetchRequest = channel.unary_unary(
'/language_inference.GRPCInferenceService/ModelFetchRequest',
request_serializer=ic__pb2.ModelFetchRequestParams.SerializeToString,
response_deserializer=ic__pb2.ModelFetchRequestResult.FromString,
)
self.ModelSendResponse = channel.stream_unary(
'/language_inference.GRPCInferenceService/ModelSendResponse',
request_serializer=ic__pb2.ModelInferResponse.SerializeToString,
response_deserializer=ic__pb2.ModelSendResponseResult.FromString,
)
self.ModelSendResponseList = channel.stream_unary(
'/language_inference.GRPCInferenceService/ModelSendResponseList',
request_serializer=ic__pb2.ModelInferResponseList.SerializeToString,
response_deserializer=ic__pb2.ModelSendResponseResult.FromString,
)


class GRPCInferenceServiceServicer(object):
"""Inference Server GRPC endpoints.
"""

def ModelStreamInfer(self, request, context):
"""模型推理请求入口,给上层dispatch调用
输入一个请求,流式返回多个response
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ModelFetchRequest(self, request, context):
"""拉取一个请求,给inference server调用
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ModelSendResponse(self, request_iterator, context):
"""发送请求的返回结果,给inference server调用
response是流式的发送
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def ModelSendResponseList(self, request_iterator, context):
"""批量发送请求的返回结果,给inference server调用
response是流式的发送
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_GRPCInferenceServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
'ModelStreamInfer': grpc.unary_stream_rpc_method_handler(
servicer.ModelStreamInfer,
request_deserializer=ic__pb2.ModelInferRequest.FromString,
response_serializer=ic__pb2.ModelInferResponse.SerializeToString,
),
'ModelFetchRequest': grpc.unary_unary_rpc_method_handler(
servicer.ModelFetchRequest,
request_deserializer=ic__pb2.ModelFetchRequestParams.FromString,
response_serializer=ic__pb2.ModelFetchRequestResult.SerializeToString,
),
'ModelSendResponse': grpc.stream_unary_rpc_method_handler(
servicer.ModelSendResponse,
request_deserializer=ic__pb2.ModelInferResponse.FromString,
response_serializer=ic__pb2.ModelSendResponseResult.SerializeToString,
),
'ModelSendResponseList': grpc.stream_unary_rpc_method_handler(
servicer.ModelSendResponseList,
request_deserializer=ic__pb2.ModelInferResponseList.FromString,
response_serializer=ic__pb2.ModelSendResponseResult.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'language_inference.GRPCInferenceService', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))


# This class is part of an EXPERIMENTAL API.
class GRPCInferenceService(object):
"""Inference Server GRPC endpoints.
"""

@staticmethod
def ModelStreamInfer(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_stream(request, target, '/language_inference.GRPCInferenceService/ModelStreamInfer',
ic__pb2.ModelInferRequest.SerializeToString,
ic__pb2.ModelInferResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def ModelFetchRequest(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/language_inference.GRPCInferenceService/ModelFetchRequest',
ic__pb2.ModelFetchRequestParams.SerializeToString,
ic__pb2.ModelFetchRequestResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def ModelSendResponse(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/language_inference.GRPCInferenceService/ModelSendResponse',
ic__pb2.ModelInferResponse.SerializeToString,
ic__pb2.ModelSendResponseResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def ModelSendResponseList(request_iterator,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.stream_unary(request_iterator, target, '/language_inference.GRPCInferenceService/ModelSendResponseList',
ic__pb2.ModelInferResponseList.SerializeToString,
ic__pb2.ModelSendResponseResult.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
Empty file.
203 changes: 203 additions & 0 deletions llm_ic/fastdeploy_ic/server/api.py
@@ -0,0 +1,203 @@
import time

import grpc
import json
import asyncio
from aioredis import RedisError

import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc
import fastdeploy_ic.proto.ic_pb2 as ic_pb2
from fastdeploy_ic.data.manager import DataManager
from fastdeploy_ic.config import GlobalConfig
from fastdeploy_ic.utils import get_logger

logger = get_logger("ic_server", "ic_server.log")

global_config = GlobalConfig()
redis_config = {
'host': global_config.redis_host,
'port': global_config.redis_port,
'db': global_config.redis_db,
'username': global_config.redis_username,
'password': global_config.redis_password
}

class GRPCInferenceServiceServicer(ic_pb2_grpc.GRPCInferenceServiceServicer):
def __init__(self):
self.data_manager = DataManager(redis_config)

async def ModelStreamInfer(self, request, context):
"""
Provided for the request sender (upstream dispatcher).
"""
try:
model_id = request.model_id
input_dict = json.loads(request.input)
if 'req_id' not in input_dict:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "ModelStreamInfer: there is no req_id in request")
if 'ic_req_data' not in input_dict:
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "ModelStreamInfer: there is no ic_req_data in request")
req_id = input_dict['req_id']
# Check whether req_id is repeated
# Warning: we only check whether the same req_id already exists;
# we cannot prevent two requests with the same req_id from arriving at the same time.
# Preventing that would require locking the check-and-insert against Redis, which would hurt performance.
# For now, we assume that users guarantee unique req_ids.
# if await self.data_manager.check_req_id_exist(model_id, req_id):
# logger.info("ModelStreamInfer: req_id {}: has existed in other task".format(req_id))
# await context.abort(grpc.StatusCode.INVALID_ARGUMENT, "ModelStreamInfer: req_id {}: has existed in other task".format(req_id))
# 1. push request to redis
await self.data_manager.add_req_id_to_map(model_id, req_id)
await self.data_manager.enque_request(model_id, request)
logger.info("ModelStreamInfer: req_id {}: enqued request".format(req_id))
# 2. response stream results
response_start_time = time.time()
while True:
if time.time() - response_start_time > global_config.response_timeout:
if await self.data_manager.check_req_id_exist(model_id, req_id): # clear resource about this req
await self.data_manager.remove_request(model_id, request)
await self.data_manager.clear_response(model_id, req_id)
await self.data_manager.remove_req_id_from_map(model_id, req_id)
logger.info("ModelStreamInfer: req_id {}: Get response from inference server timeout".format(req_id))
await context.abort(grpc.StatusCode.DEADLINE_EXCEEDED, "ModelStreamInfer: req_id {}: Get response from inference server timeout".format(req_id))
data = await self.data_manager.deque_response(model_id, req_id)
if data is None:
await asyncio.sleep(1)
continue
try:
output_dict = json.loads(data.output)
if 'ic_timestamp_tag' in output_dict:
if time.time() - output_dict['ic_timestamp_tag'] > global_config.response_timeout: # the response is stale (timed out), possibly left over from a previous request with the same req_id
continue
del output_dict['ic_timestamp_tag']
data.output = json.dumps(output_dict)
logger.info("ModelStreamInfer: req_id {}: response data: {}".format(req_id, output_dict))
yield data
# Two cases indicate the request is done:
# 1. the server returned an error payload instead of a normal result (no is_end field)
# 2. is_end is 1
if ('is_end' not in output_dict) or (output_dict['is_end'] == 1):
# clear resource about this req, only req_id in map should be removed
await self.data_manager.remove_req_id_from_map(model_id, req_id)
return

except Exception as e:
if await self.data_manager.check_req_id_exist(model_id, req_id): # clear resource about this req
await self.data_manager.clear_response(model_id, req_id)
await self.data_manager.remove_req_id_from_map(model_id, req_id)
logger.info("ModelStreamInfer: req_id {}: Failed to read response data from inference server, exception {}".format(req_id, e))
await context.abort(grpc.StatusCode.INTERNAL, "ModelStreamInfer: req_id {}: Failed to read response data from inference server".format(req_id))
except RedisError as e:
# If a Redis operation fails, we end up here.
# Log the error and report a generic internal error (we must not expose the raw Redis error to users).
logger.info("ModelStreamInfer: exception: {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "Internal error happened")

except Exception as e:
logger.info("ModelStreamInfer: exception: type {}: {}".format(type(e), e))
await context.abort(grpc.StatusCode.INTERNAL, "Internal error happened")

async def ModelFetchRequest(self, request, context):
"""
Provided for the inference server.
"""
# Two strategies are supported for handing out tasks:
# 1. ByRequest
# 2. ByToken
try:
model_ids = request.model_id
strategy = request.strategy
requests = []
for model_id in model_ids:
if strategy == ic_pb2.FetchStrategy.ByRequest:
requests.extend(await self.data_manager.get_requests_by_number(model_id, request.max_request_num))

else:
by_token_params = request.by_token_params
requests.extend(await self.data_manager.get_requests_by_block(model_id, request.max_request_num,
by_token_params.block_num, by_token_params.block_size, by_token_params.dec_token_num))

fetch_request_result = ic_pb2.ModelFetchRequestResult()
fetch_request_result.requests.extend(requests)
logger.info("ModelFetchRequest: return requests: {}".format(requests))
except RedisError as e:
# If a Redis operation fails, we end up here.
# Log the error and report a generic internal error to the caller.
logger.info("ModelFetchRequest: exception: {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "Internal error happened")
return fetch_request_result


async def ModelSendResponse(self, response_iterator, context):
"""
Provided for the inference server.
"""
# Get response from inference server
try:
response_start_time = time.time()
async for response in response_iterator:
try:
res = json.loads(response.output)
model_id = res['ic_req_data']['model_id']
req_id = res['req_id']
# add timestamp for response
res['ic_timestamp_tag'] = time.time() # added so the client does not receive a response
# meant for a previous request, which can happen when:
# 1. the same req_id is reused by mistake
# 2. the client of the previous request did not receive all of its responses for some reason
response.output = json.dumps(res)
except Exception as e: # req_id may not be set yet if parsing failed, so do not reference it here
logger.info("ModelSendResponse: Failed to parse response data from inference server, exception {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "ModelSendResponse: Failed to parse response data from inference server")
await self.data_manager.enque_response(model_id, req_id, response)
logger.info("ModelSendResponse: req_id {}: response data: {}".format(req_id, res))
if ('is_end' not in res) or (res['is_end'] == 1):
return ic_pb2.ModelSendResponseResult()
if time.time() - response_start_time > global_config.response_timeout:
await self.data_manager.clear_response(model_id, req_id)
logger.info("ModelSendResponse: req_id {}: Get response from inference server timeout".format(req_id))
await context.abort(grpc.StatusCode.DEADLINE_EXCEEDED, "ModelSendResponse: req_id {}: Get response from inference server timeout".format(req_id))
except RedisError as e:
# If a Redis operation fails, we end up here.
# Log the error and report a generic internal error to the caller.
logger.info("ModelSendResponse: exception: {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "Internal error happened")

async def ModelSendResponseList(self, response_list_iterator, context):
"""
Provided for the inference server.
"""
# Get response from inference server
try:
response_start_time = time.time()
async for response_list in response_list_iterator:
for response in response_list.response_list:
try:
res = json.loads(response.output)
model_id = res['ic_req_data']['model_id']
req_id = res['req_id']
# add timestamp for response
res['ic_timestamp_tag'] = time.time() # added so the client does not receive a response
# meant for a previous request, which can happen when:
# 1. the same req_id is reused by mistake
# 2. the client of the previous request did not receive all of its responses for some reason
response.output = json.dumps(res)
except Exception as e: # req_id may not be set yet if parsing failed, so do not reference it here
logger.info("ModelSendResponseList: Failed to parse response data from inference server, exception {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "ModelSendResponseList: Failed to parse response data from inference server")
await self.data_manager.enque_response(model_id, req_id, response)
logger.info("ModelSendResponseList: req_id {}: response data: {}".format(req_id, res))
if ('is_end' not in res) or (res['is_end'] == 1):
break
if time.time() - response_start_time > global_config.response_timeout:
await self.data_manager.clear_response(model_id, req_id)
logger.info("ModelSendResponseList: req_id {}: Get response from inference server timeout".format(req_id))
await context.abort(grpc.StatusCode.DEADLINE_EXCEEDED, "ModelSendResponseList: req_id {}: Get response from inference server timeout".format(req_id))
except RedisError as e:
# If a Redis operation fails, we end up here.
# Log the error and report a generic internal error to the caller.
logger.info("ModelSendResponseList: exception: {}".format(e))
await context.abort(grpc.StatusCode.INTERNAL, "Internal error happened")
return ic_pb2.ModelSendResponseResult()
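
For completeness, a rough sketch of the sending side that pairs with the handlers above: after fetching work, an inference server streams results back via `ModelSendResponse`, including the `req_id`, `ic_req_data`, and `is_end` fields these handlers expect. The stub setup, endpoint, and payload fields below are illustrative:

```python
import json

import grpc

import fastdeploy_ic.proto.ic_pb2 as ic_pb2
import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc

def make_responses(req, pieces):
    """Yield streamed responses for one fetched request; the last one sets is_end=1."""
    req_id = json.loads(req.input)["req_id"]
    for i, text in enumerate(pieces):
        output = {
            "req_id": req_id,
            "ic_req_data": {"model_id": req.model_id},
            "token": text,                                   # illustrative payload field
            "is_end": 1 if i == len(pieces) - 1 else 0,
        }
        yield ic_pb2.ModelInferResponse(
            request_id=req.request_id, sentence_id=i, output=json.dumps(output))

channel = grpc.insecure_channel("localhost:9000")            # placeholder controller address
stub = ic_pb2_grpc.GRPCInferenceServiceStub(channel)

# 'req' is assumed to be a ModelInferRequest previously obtained via ModelFetchRequest.
# stub.ModelSendResponse(make_responses(req, ["Hello", " world"]))
```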


74 changes: 74 additions & 0 deletions llm_ic/fastdeploy_ic/server/launcher.py
@@ -0,0 +1,74 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from concurrent import futures
import contextlib
import multiprocessing
import socket
import sys
import asyncio

import grpc

import fastdeploy_ic.proto.ic_pb2_grpc as ic_pb2_grpc
from fastdeploy_ic.config import GlobalConfig
from .api import GRPCInferenceServiceServicer

global_config = GlobalConfig()
_PROCESS_COUNT = global_config.num_process
_THREAD_CONCURRENCY = _PROCESS_COUNT


async def _run_server(bind_address):
"""Start a server in a subprocess."""
options = (("grpc.so_reuseport", 1),)
server = grpc.aio.server(futures.ThreadPoolExecutor(
max_workers=_THREAD_CONCURRENCY,
),
options=options)
ic_pb2_grpc.add_GRPCInferenceServiceServicer_to_server(GRPCInferenceServiceServicer(), server)
server.add_insecure_port(bind_address)
await server.start()
await server.wait_for_termination()

def run(bind_address):
asyncio.run(_run_server(bind_address))




@contextlib.contextmanager
def _reserve_port(port):
"""Create a socket for all subprocesses to use."""
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 0:
raise RuntimeError("Failed to set SO_REUSEPORT.")
sock.bind(("", port))
try:
yield sock.getsockname()[1]
finally:
sock.close()


def serve(args):
with _reserve_port(args.grpc_port) as port:
bind_address = "0.0.0.0:{}".format(port)
print("Binding to '%s'", bind_address)
sys.stdout.flush()
workers = []
for _ in range(_PROCESS_COUNT):
# NOTE: It is imperative that the worker subprocesses be forked before
# any gRPC servers start up. See
# https://github.com/grpc/grpc/issues/16001 for more details.
worker = multiprocessing.Process(
target=run, args=(bind_address,)
)
worker.start()
workers.append(worker)
for worker in workers:
worker.join()



155 changes: 155 additions & 0 deletions llm_ic/fastdeploy_ic/utils.py
@@ -0,0 +1,155 @@
import os
import contextlib
import logging
import threading
import time
from typing import (Any, Generator, Optional, Union)
from logging.handlers import TimedRotatingFileHandler

import colorlog

from fastdeploy_ic.config import GlobalConfig

global_config = GlobalConfig()

__all__ = ['get_logger']

_LOG_CONFIG = {
'DEBUG': {
'color': 'purple'
},
'INFO': {
'color': 'green'
},
'WARNING': {
'color': 'yellow'
},
'ERROR': {
'color': 'red'
},
'CRITICAL': {
'color': 'bold_red'
},
}

class Logger(object):
_DEFAULT_NAME: str = 'fastdeploy_ic'

def __init__(self,
name: Optional[str]=None,
log_file=None,
time_rotation=7,
level=logging.INFO) -> None:
"""Initialize the instance based on a given name.
Args:
name: Logger name.
"""
super().__init__()
if name is None:
name = self._DEFAULT_NAME
self.logger = logging.getLogger(name)

self.format = colorlog.ColoredFormatter(
"%(log_color)s[%(asctime)-15s] [%(levelname)8s]%(reset)s - %(message)s",
log_colors={
key: conf['color']
for key, conf in _LOG_CONFIG.items()
}, )

if log_file is not None:
self.handler = TimedRotatingFileHandler(
log_file,
when="midnight",
backupCount=time_rotation,
encoding="utf-8")
else:
self.handler = logging.StreamHandler()
self.handler.setFormatter(self.format)

self.logger.addHandler(self.handler)
self.logger.setLevel(level)
self.logger.propagate = False
self._is_enabled = True

def __call__(self,
log_level: int,
msg: object,
*args: object,
**kwargs: Any) -> None:
if not self.is_enabled:
return

self.logger.log(log_level, msg, *args, **kwargs)

def debug(self, msg: object, *args: object, **kwargs: Any) -> None:
return self(logging.getLevelName('DEBUG'), msg, *args, **kwargs)

def info(self, msg: object, *args: object, **kwargs: Any) -> None:
return self(logging.getLevelName('INFO'), msg, *args, **kwargs)

def warning(self, msg: object, *args: object, **kwargs: Any) -> None:
return self(logging.getLevelName('WARNING'), msg, *args, **kwargs)

def error(self, msg: object, *args: object, **kwargs: Any) -> None:
return self(logging.getLevelName('ERROR'), msg, *args, **kwargs)

def critical(self, msg: object, *args: object, **kwargs: Any) -> None:
return self(logging.getLevelName('CRITICAL'), msg, *args, **kwargs)

def disable(self) -> None:
self._is_enabled = False

def enable(self) -> None:
self._is_enabled = True

@property
def is_enabled(self) -> bool:
return self._is_enabled

def set_level(self, log_level: Union[int, str]) -> None:
self.logger.setLevel(log_level)

@contextlib.contextmanager
def processing(self, msg: str,
interval: float=0.1) -> Generator[None, None, None]:
"""Display a message with spinners.
Args:
msg: Message to display.
interval: Spinning interval.
"""
end = False

def _printer() -> None:
index = 0
flags = ['\\', '|', '/', '-']
while not end:
flag = flags[index % len(flags)]
with self.use_terminator('\r'):
self.info(f"{msg}: {flag}")
time.sleep(interval)
index += 1

t = threading.Thread(target=_printer)
t.start()
yield
end = True

@contextlib.contextmanager
def use_terminator(self, terminator: str) -> Generator[None, None, None]:
old_terminator = self.handler.terminator
self.handler.terminator = terminator
yield
self.handler.terminator = old_terminator

def get_logger(name, file_name):
"""
Get logger
"""
if not os.path.exists(global_config.log_dir):
os.makedirs(global_config.log_dir, exist_ok=True)
file_path = os.path.join(global_config.log_dir, file_name)
logger = Logger(name=name, log_file=file_path)
return logger
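
A minimal usage sketch of the helper above (the component and file names are hypothetical); log files are written under the directory given by `IC_LOG_DIR` (default `ic_logs`):

```python
from fastdeploy_ic.utils import get_logger

# Writes to <IC_LOG_DIR>/my_component.log with timed midnight rotation.
logger = get_logger("my_component", "my_component.log")
logger.info("component started")
logger.warning("queue length is %s", 42)
```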

9 changes: 9 additions & 0 deletions llm_ic/main.py
@@ -0,0 +1,9 @@
import argparse

from fastdeploy_ic.server.launcher import serve

if __name__ == "__main__":
parser = argparse.ArgumentParser("Inference load balance controller launcher")
parser.add_argument("--grpc-port", type=int, default=9000)
args = parser.parse_args()
serve(args)
4 changes: 4 additions & 0 deletions llm_ic/requirements.txt
@@ -0,0 +1,4 @@
aioredis
colorlog
grpcio
protobuf
33 changes: 33 additions & 0 deletions llm_ic/setup.py
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import setuptools

setuptools.setup(
name="fastdeploy_ic",
version="0.0.9",
author="fastdeploy",
author_email="fastdeploy@baidu.com",
description="FastDeploy for Large Language Model",
long_description_content_type="text/plain",
url="https://github.com/PaddlePaddle/FastDeploy",
packages=setuptools.find_packages(),
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
install_requires=["colorlog", "aioredis", "grpcio", "protobuf"],
extras_require={"client": ['grpcio', 'tritonclient']},
license='Apache 2.0')