Implement distributed inference API server with FastAPI #60
- Added a new file `api_server_dist.py` to handle video generation tasks using distributed inference.
- Introduced endpoints for task submission, status checking, and result retrieval.
- Implemented image downloading and task management with error handling.
- Enhanced `infer.py` to ensure proper initialization of distributed processes.
- Created a shell script `start_api_with_dist_inference.sh` for easy server startup with environment setup.

This commit establishes a robust framework for managing video generation tasks in a distributed manner.
Hello @GACLove, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
Summary of Changes
Hello team, gemini-code-assist here to provide a summary of this pull request. This PR introduces a new FastAPI-based API server designed to handle video generation tasks using distributed inference. The core idea is to provide a web interface for submitting tasks, managing their lifecycle, and retrieving results, leveraging multiple processes for the actual inference work. It includes a new main script (`api_server_dist.py`), a small but necessary adjustment to the existing inference initialization logic (`infer.py`), and a convenience shell script to launch the server with distributed inference enabled.
Highlights
- New Distributed Inference API Server: A new Python script (`lightx2v/api_server_dist.py`) is added, implementing a FastAPI application. This server acts as the entry point for submitting video generation tasks.
- FastAPI Endpoints: The new API server exposes several endpoints (see the client-side sketch after this list):
  - `/v1/local/video/generate` (POST): Submit a new video generation task.
  - `/v1/local/video/generate/service_status` (GET): Check the overall status of the API service.
  - `/v1/local/video/generate/get_all_tasks` (GET): Retrieve the status of all submitted tasks.
  - `/v1/local/video/generate/task_status` (POST): Get the status of a specific task by ID.
  - `/v1/local/video/generate/get_task_result` (GET): Download the resulting video file for a completed task.
  - `/v1/file/download` (GET): A general endpoint to download files (potentially for results).
  - `/v1/local/video/generate/stop_running_task` (GET): Attempt to stop the currently running task.
- Multiprocessing for Distributed Inference: The API server uses `torch.multiprocessing` with queues (`input_queues`, `output_queues`) to communicate with separate worker processes. Each worker process initializes a distributed inference runner (`init_runner`) and processes tasks received from the input queue, sending results or errors back via the output queue. This mimics the `torchrun` multi-process setup.
- Task Management and Status Tracking: The server implements basic task management, tracking the status (processing, success, failed) of submitted tasks using a `BaseServiceStatus` utility (presumably from `lightx2v.utils.service_utils`). It also handles downloading remote images specified in task requests.
- Inference Initialization Fix: A small change in `lightx2v/infer.py` ensures that `dist.init_process_group` is only called if a process group is not already initialized. This is crucial when the distributed processes are managed externally by the new API server's multiprocessing setup.
- Startup Script: A new shell script (`scripts/start_api_with_dist_inference.sh`) is added to simplify launching the API server with distributed inference workers. It sets necessary environment variables and calls the `api_server_dist.py` script with appropriate arguments.
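For orientation only, here is a rough sketch of how a client might drive these endpoints. The payload fields below are hypothetical; the real request schema is defined by the `Message` Pydantic model in `api_server_dist.py` and is not shown in this summary.

```python
import httpx

BASE = "http://localhost:8000"  # hypothetical host/port for the API server

# Hypothetical payload; the actual fields come from the Message model.
payload = {
    "task_id": "demo-001",
    "prompt": "a corgi running on the beach",
    "image_url": "https://example.com/input.jpg",
    "save_video_path": "demo-001.mp4",
}

with httpx.Client(timeout=30.0) as client:
    # Submit a generation task.
    client.post(f"{BASE}/v1/local/video/generate", json=payload)

    # Poll the status of that task.
    status = client.post(
        f"{BASE}/v1/local/video/generate/task_status",
        json={"task_id": payload["task_id"]},
    ).json()
    print(status)
```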
Changelog
- `lightx2v/api_server_dist.py`
  - Added a new file implementing a FastAPI server.
  - Defined a `Message` Pydantic model for the task submission payload (lines 42-54).
  - Implemented a `download_image` function to handle image URLs (lines 61-78).
  - Implemented a `local_video_generate` function to submit tasks to distributed workers via queues and wait for results (lines 80-148).
  - Added FastAPI endpoints for task submission, status checking, result retrieval, and file download (lines 150-204).
  - Added an endpoint to stop the currently running task using thread signaling (lines 218-235).
  - Implemented a `distributed_inference_worker` function for the multiprocessing worker processes, handling distributed initialization and task processing loops (lines 242-339).
  - Implemented `start_distributed_inference_with_queue` to launch worker processes using `torch.multiprocessing` and queues (lines 341-380).
  - Implemented `stop_distributed_inference_with_queue` for graceful shutdown of worker processes (lines 383-431).
  - Added a main execution block to parse arguments, optionally start distributed inference workers, and run the FastAPI server (lines 438-495).
  - Included signal handling and atexit registration for cleanup (lines 470-482, 492-495).
- `lightx2v/infer.py`
  - Modified `init_runner` to check if a distributed process group is already initialized before calling `dist.init_process_group` (lines 28-29; see the sketch after this changelog).
- `scripts/start_api_with_dist_inference.sh`
  - Added a new shell script to set up environment variables (`CUDA_VISIBLE_DEVICES`, `PYTHONPATH`, etc.) (lines 1-28).
  - Added checks for required path variables (`lightx2v_path`, `model_path`) (lines 14-22).
  - Added commands to execute the `lightx2v.api_server_dist` module with specific arguments to start the API server and distributed inference (lines 38-45).
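The guard described for `infer.py` is of this general form (a minimal sketch; the actual argument handling inside `init_runner` may differ):

```python
import torch.distributed as dist

def ensure_process_group(backend: str = "nccl") -> None:
    """Initialize the default process group only if it isn't already set up.

    When the API server's worker processes have already called
    dist.init_process_group, the runner must not call it a second time.
    """
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)
```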
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
Feature | Command | Description
---|---|---
Code Review | `/gemini review` | Performs a code review for the current pull request in its current state.
Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state.
Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | `/gemini help` | Displays a list of available commands.
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces a new FastAPI server (`api_server_dist.py`) designed to manage distributed inference tasks for video generation. It includes endpoints for task submission, status checking, and result retrieval, leveraging `torch.multiprocessing` for distributed execution. A shell script (`start_api_with_dist_inference.sh`) is provided for easy startup. The changes in `infer.py` correctly handle distributed process group initialization.
The overall structure for integrating FastAPI with distributed PyTorch inference via queues is a reasonable approach. However, the implementation has several areas that require attention, particularly regarding error handling, task management logic, and security.
Summary of Findings
- Security Vulnerabilities (Critical): The `/v1/file/download` endpoint and the result retrieval logic in `/v1/local/video/generate/get_task_result` are vulnerable to path traversal, allowing potential access to arbitrary files on the server. The use of `verify=False` in `download_image` is also a security risk.
- Unsafe Thread Termination (Critical): The `_async_raise` function and the `/v1/local/video/generate/stop_running_task` endpoint use `ctypes` to force thread termination, which is unsafe and can lead to instability, resource leaks, or crashes.
- Task Result Handling Logic (High): The logic for retrieving task results from the output queue in `local_video_generate` is flawed and can lead to unprocessed results or race conditions if multiple tasks are active.
- Hardcoded Paths (Medium): The startup script contains hardcoded local paths, violating the contributing guidelines and reducing portability.
- Inefficient/Problematic Asyncio Usage (Medium): Creating and setting a new asyncio event loop for each task within the worker process is inefficient and not the standard way to use asyncio.
- Broad Exception Handling (Medium): Bare `except:` blocks are used in the cleanup logic, which can hide unexpected errors.
Merge Readiness
This pull request introduces critical security vulnerabilities (path traversal, unsafe thread termination) and has significant issues in the task result handling logic. These issues must be addressed before the code can be considered for merging. Additionally, there are medium severity issues related to hardcoded paths, asyncio usage, and broad exception handling that should also be fixed to improve code quality and maintainability. I am unable to approve this pull request; please address the critical and high severity issues and request another review.
lightx2v/api_server_dist.py
```python
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
    global thread
    if thread and thread.is_alive():
        try:
            _async_raise(thread.ident, SystemExit)
            thread.join()

            # Clean up the thread reference
            thread = None
            ApiServerServiceStatus.clean_stopped_task()
            gc.collect()
            torch.cuda.empty_cache()
            return {"stop_status": "success", "reason": "Task stopped successfully."}
        except Exception as e:
            return {"stop_status": "error", "reason": str(e)}
    else:
        return {"stop_status": "do_nothing", "reason": "No running task found."}
```
This endpoint uses the unsafe `_async_raise` function to terminate the background task thread. As noted in the comment on `_async_raise`, this can lead to unpredictable behavior, resource leaks, and crashes. The `thread.join()` call after `_async_raise` might hang indefinitely if the thread is stuck in a state where it cannot process the injected exception.

This approach to stopping a task is fundamentally unsafe and should be replaced with a cooperative cancellation mechanism.
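A minimal sketch of the cooperative alternative on the endpoint side, assuming the generation thread periodically checks a shared `threading.Event` (the `app`, `stop_event`, and `thread` names here are illustrative stand-ins, not part of the PR):

```python
import threading
from typing import Optional

from fastapi import FastAPI

app = FastAPI()  # stand-in for the app defined in api_server_dist.py
stop_event = threading.Event()              # shared with the generation thread
thread: Optional[threading.Thread] = None   # set when a task is started

@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
    global thread
    if thread and thread.is_alive():
        stop_event.set()         # ask the worker to stop at its next checkpoint
        thread.join(timeout=30)  # bounded wait instead of a potentially infinite join
        if thread.is_alive():
            return {"stop_status": "error", "reason": "Task did not stop within the timeout."}
        thread = None
        stop_event.clear()
        return {"stop_status": "success", "reason": "Task stopped cooperatively."}
    return {"stop_status": "do_nothing", "reason": "No running task found."}
```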
lightx2v/api_server_dist.py
```python
@app.get("/v1/local/video/generate/get_task_result")
async def get_task_result(message: TaskStatusMessage):
    result = ApiServerServiceStatus.get_status_task_id(message.task_id)
    # Return the contents of save_video_path to the caller
    save_video_path = result.get("save_video_path")

    if save_video_path and Path(save_video_path).is_absolute() and Path(save_video_path).exists():
        return FileResponse(save_video_path)
    elif save_video_path and not Path(save_video_path).is_absolute():
        video_path = OUTPUT_VIDEO_DIR / save_video_path
        if video_path.exists():
            return FileResponse(video_path)

    return {"status": "not_found", "message": "Task result not found"}
```
The `/v1/local/video/generate/get_task_result` endpoint takes a `TaskStatusMessage`, which only contains a `task_id`. It then retrieves the task status and attempts to return a `FileResponse` based on the `save_video_path` stored in the task result.

This endpoint is vulnerable to path traversal if the `save_video_path` stored in the task result can be controlled by the user or contains relative path components (`..`). An attacker could potentially craft a `save_video_path` that points to arbitrary files on the server's filesystem, allowing them to download sensitive files.

While the code checks `Path(save_video_path).is_absolute()` and prepends `OUTPUT_VIDEO_DIR` for relative paths, this does not fully mitigate the risk if the stored path contains `..` components within the relative path or if `OUTPUT_VIDEO_DIR` itself is not sufficiently restricted.

To mitigate this, you should:
- Strictly validate the `save_video_path` format and content when the task is submitted.
- Ensure that the resolved path (`OUTPUT_VIDEO_DIR / save_video_path` for relative paths) is strictly contained within `OUTPUT_VIDEO_DIR`, using `pathlib.Path.resolve()` and checking that it is a child of the output directory (see the sketch after this list).
- Consider returning the video content directly or using a more secure method for serving files, rather than relying solely on user-provided paths stored in the task result.
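A sketch of the containment check described above (the directory path is illustrative; use the server's configured output directory):

```python
from pathlib import Path
from typing import Optional

OUTPUT_VIDEO_DIR = Path("/data/output_videos")  # illustrative

def resolve_output_path(save_video_path: str) -> Optional[Path]:
    """Resolve a stored save_video_path and ensure it stays inside OUTPUT_VIDEO_DIR."""
    base = OUTPUT_VIDEO_DIR.resolve()
    candidate = (base / save_video_path).resolve()
    # Absolute inputs replace `base` in the join, and ".." components escape it;
    # both cases fail the containment check below.
    if base not in candidate.parents:
        return None
    return candidate if candidate.is_file() else None
```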
lightx2v/api_server_dist.py
```python
@app.get("/v1/file/download")
async def download_file(file_path: str):
    if Path(file_path).is_absolute():
        if Path(file_path).exists():
            return FileResponse(file_path)
        else:
            return {"status": "not_found", "message": f"File not found: {file_path}"}
    else:
        file_path_obj = OUTPUT_VIDEO_DIR / file_path
        if file_path_obj.exists():
            return FileResponse(file_path_obj)
        else:
            return {"status": "not_found", "message": f"File not found: {file_path}"}
```
The `/v1/file/download` endpoint takes an arbitrary `file_path` string and directly serves the file using `FileResponse` if it exists. This is a severe path traversal vulnerability.

An attacker can provide a `file_path` like `/etc/passwd` or `../../../../etc/passwd` to download arbitrary files from the server's filesystem, regardless of whether the path is absolute or relative to `OUTPUT_VIDEO_DIR`.

This endpoint should be removed or heavily restricted. If the intention is only to serve files from `OUTPUT_VIDEO_DIR`, the endpoint should:
- Only accept relative paths.
- Strictly validate the path to prevent `..` or other traversal attempts.
- Resolve the path relative to `OUTPUT_VIDEO_DIR` and verify that the final resolved path is a child of `OUTPUT_VIDEO_DIR` before serving the file (a sketch follows below).

Without these checks, this endpoint exposes the entire filesystem to anyone who can access the API.
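One possible hardened version, under the assumption that only files below `OUTPUT_VIDEO_DIR` should ever be served (the `app` and directory path are illustrative stand-ins):

```python
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse

app = FastAPI()  # stand-in for the app in api_server_dist.py
OUTPUT_VIDEO_DIR = Path("/data/output_videos")  # illustrative

@app.get("/v1/file/download")
async def download_file(file_path: str):
    # Refuse absolute paths outright; only serve files below OUTPUT_VIDEO_DIR.
    if Path(file_path).is_absolute():
        raise HTTPException(status_code=400, detail="Only relative paths are allowed")
    base = OUTPUT_VIDEO_DIR.resolve()
    resolved = (base / file_path).resolve()
    if base not in resolved.parents or not resolved.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved)
```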
/fix
lightx2v/api_server_dist.py
```python
def _async_raise(tid, exctype):
    """Force thread tid to raise exception exctype"""
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("Invalid thread ID")
    elif res > 1:
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
        raise SystemError("PyThreadState_SetAsyncExc failed")
```
The `_async_raise` function uses `ctypes` to inject an exception into another thread. This is a highly dangerous and unreliable method for thread termination.

Injecting exceptions can leave resources (like locks, file handles, and network connections) in an inconsistent state, leading to deadlocks, resource leaks, or crashes. Python's threading model is not designed for external, forceful termination like this.

Python's recommended way to stop a thread is to signal it to stop gracefully (e.g., by setting an event or flag that the thread periodically checks) and let the thread clean up and exit on its own. While this requires cooperation from the thread's target function, it is the only safe way.

Using `_async_raise` for the `stop_running_task` endpoint is a critical issue due to the potential for instability and crashes.
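On the thread side, cooperation means checking the flag at safe points. A sketch, assuming the pipeline can be broken into discrete steps (the step loop and names below are illustrative, not the actual runner API):

```python
import threading
from typing import Callable, Iterable

class TaskCancelled(Exception):
    """Raised when the server asks the running task to stop."""

def run_generation(steps: Iterable[Callable[[], None]], stop_event: threading.Event) -> None:
    """Run pipeline steps, checking the shared stop flag between steps."""
    for step in steps:
        if stop_event.is_set():
            # Clean up partial state (buffers, temp files) here before leaving.
            raise TaskCancelled("Stop requested via /stop_running_task")
        step()
```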
```python
with httpx.Client(verify=False) as client:
    response = client.get(image_url)
```
Using `verify=False` in `httpx.Client` disables SSL certificate verification. This is a significant security risk, as it makes the application vulnerable to man-in-the-middle attacks when downloading images from external URLs. It should only be used in development/testing environments, never in production. Consider adding a configuration option to control this, or removing it entirely and ensuring proper certificate handling.
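A sketch of the suggested opt-in configuration: verification stays on by default and can only be disabled explicitly for local testing. The environment variable name is hypothetical.

```python
import os

import httpx

# Hypothetical opt-out flag; certificate verification remains enabled unless set to "1".
VERIFY_TLS = os.environ.get("LIGHTX2V_INSECURE_DOWNLOADS", "0") != "1"

def download_image(image_url: str) -> bytes:
    """Fetch an image with TLS certificate verification enabled by default."""
    with httpx.Client(verify=VERIFY_TLS, timeout=30.0) as client:
        response = client.get(image_url)
        response.raise_for_status()
        return response.content
```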
lightx2v/api_server_dist.py
```python
if 'loop' in locals() and loop and not loop.is_closed():
    loop.close()
```
The `finally` block attempts to close the `loop` if it exists and is not closed. However, the `loop` variable is defined inside the `while True` loop (lines 274-275). If an exception occurs before the loop is created (e.g., during initialization), the `loop` variable won't exist, and this check would fail with a `NameError`. The `if 'loop' in locals()` check helps, but it is still tied to the inner loop's variable.

The event loop management should be outside the task processing loop: created once when the worker starts and closed when the worker exits.
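A sketch of that structure, assuming `run_task` is an async callable provided by the worker (names are illustrative):

```python
import asyncio
import queue

def worker_main(input_queue: "queue.Queue", run_task) -> None:
    """Create one event loop per worker process and reuse it for every task."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                task_data = input_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            if task_data is None:  # stop signal
                break
            loop.run_until_complete(run_task(task_data))
    finally:
        loop.close()  # the loop always exists here, so no locals() check is needed
```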
lightx2v/api_server_dist.py
```python
except:
    pass
```
The `except:` block here is too broad. Catching all exceptions without specifying the exception type is generally bad practice, as it can hide unexpected errors (like `KeyboardInterrupt`, `SystemExit`, `NameError`, etc.) and make debugging difficult. It is better to catch the specific exceptions you expect (e.g., `queue.Empty`, if that is the only expected error during cleanup) or, at the very least, log the exception type and traceback.
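A small sketch of what the narrower cleanup handling could look like (the helper name is illustrative and uses the standard `logging` module):

```python
import logging
import queue

logger = logging.getLogger(__name__)

def drain_queue(q: "queue.Queue") -> None:
    """Drain remaining items during cleanup, catching only the expected exception."""
    while True:
        try:
            q.get_nowait()
        except queue.Empty:
            break  # expected: the queue has been emptied
        except Exception:
            logger.exception("Unexpected error while draining queue")
            break
```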
lightx2v/api_server_dist.py
```python
def distributed_inference_worker(rank, world_size, master_addr, master_port, args, input_queue, output_queue):
    """Worker process for the distributed inference service."""
    try:
        # Set environment variables
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["ENABLE_PROFILING_DEBUG"] = "true"
        os.environ["ENABLE_GRAPH_MODE"] = "false"
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

        logger.info(f"Process {rank}/{world_size-1} is initializing the distributed inference service...")

        dist.init_process_group(backend='nccl', init_method=f'tcp://{master_addr}:{master_port}', rank=rank, world_size=world_size)

        config = set_config(args)
        config["mode"] = "server"
        logger.info(f"config: {config}")
        runner = init_runner(config)

        logger.info(f"Process {rank}/{world_size-1} finished initializing the distributed inference service, waiting for tasks...")

        while True:
            try:
                task_data = input_queue.get(timeout=1.0)  # 1-second timeout
                if task_data is None:  # stop signal
                    logger.info(f"Process {rank}/{world_size-1} received the stop signal, exiting the inference service")
                    break
                logger.info(f"Process {rank}/{world_size-1} received inference task: {task_data['task_id']}")

                runner.set_inputs(task_data)
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                # Run inference, reusing the event loop created above
                try:
                    loop.run_until_complete(runner.run_pipeline())

                    # Only rank 0 puts the result into the output queue, to avoid duplicates
                    if rank == 0:
                        result = {
                            "task_id": task_data["task_id"],
                            "status": "success",
                            "save_video_path": task_data["save_video_path"],
                            "message": "Inference finished"
                        }
                        output_queue.put(result)
                        logger.info(f"Task {task_data['task_id']} finished (reported by rank 0)")
                    dist.barrier()

                except Exception as e:
                    # Only rank 0 reports errors
                    if rank == 0:
                        result = {
                            "task_id": task_data["task_id"],
                            "status": "failed",
                            "error": str(e),
                            "message": f"Inference failed: {str(e)}"
                        }
                        output_queue.put(result)
                        logger.error(f"Task {task_data['task_id']} inference failed: {str(e)} (reported by rank 0)")
                    dist.barrier()

            except queue.Empty:
                # Queue is empty, keep waiting
                continue
            except Exception as e:
                logger.error(f"Process {rank}/{world_size-1} encountered an error while handling a task: {str(e)}")
                # Only rank 0 sends the error result
                if rank == 0:
                    error_result = {
                        "task_id": task_data.get("task_id", "unknown") if 'task_data' in locals() else "unknown",
                        "status": "error",
                        "error": str(e),
                        "message": f"Error while handling the task: {str(e)}"
                    }
                    output_queue.put(error_result)
                dist.barrier()

    except Exception as e:
        logger.error(f"Distributed inference service process {rank}/{world_size-1} failed to start: {str(e)}")
        # Only rank 0 reports startup failures
        if rank == 0:
            error_result = {
                "task_id": "startup",
                "status": "startup_failed",
                "error": str(e),
                "message": f"Inference service failed to start: {str(e)}"
            }
            output_queue.put(error_result)
    # Close the event loop and destroy the process group when the process finally exits
    finally:
        if 'loop' in locals() and loop and not loop.is_closed():
            loop.close()
        if dist.is_initialized():
            dist.destroy_process_group()
```
The `distributed_inference_worker` function sets `CUDA_VISIBLE_DEVICES` to `str(rank)`. This assumes a specific mapping in which each process corresponds to a single, distinct GPU index equal to its rank, which might not hold in all distributed setups, especially if `CUDA_VISIBLE_DEVICES` is already set externally to specify a subset of available GPUs.

A more robust approach would be to use `torch.cuda.set_device(rank)` after `dist.init_process_group`, if the process group is initialized across the devices specified by the original `CUDA_VISIBLE_DEVICES` environment variable. If `torchrun` or a similar launcher is used, it typically handles setting `LOCAL_RANK` and `RANK` and managing device assignment. If launching manually, the script or launcher should manage `CUDA_VISIBLE_DEVICES`, or the worker should use `os.environ['LOCAL_RANK']` to determine its assigned device index within the visible devices.
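A sketch of that device-selection approach (the helper name is illustrative):

```python
import os

import torch

def select_device(rank: int) -> torch.device:
    """Pick the GPU for this worker without overwriting CUDA_VISIBLE_DEVICES.

    Respects an externally provided LOCAL_RANK (e.g. set by torchrun); otherwise
    maps the rank onto the GPUs that are already visible to the process.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", rank % max(torch.cuda.device_count(), 1)))
    torch.cuda.set_device(local_rank)
    return torch.device("cuda", local_rank)
```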
```bash
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
    cuda_devices=2,3
    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}"
    export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
```
This block checks whether `CUDA_VISIBLE_DEVICES` is set and provides a default value (`2,3`) if it is not. While helpful, hardcoding specific device IDs like `2,3` might not be appropriate for all environments. Consider making the default configurable, or providing instructions on how users should set this variable externally.
… handling
- Introduced a new endpoint `/v1/local/video/generate_form` for video generation that accepts image uploads.
- Implemented unique filename generation for uploaded images to prevent conflicts.
- Enhanced directory management for input and output paths.
- Improved the file download response with detailed status and size information.
- Added error handling for distributed inference processes and graceful shutdown procedures.
Contributing Guidelines
We have prepared a `pre-commit` hook to enforce consistent code formatting across the project. If your code complies with the standards, you should not see any errors; otherwise, you can clean up your code following the steps below. Thank you for your contributions!