
Implement distributed inference API server with FastAPI #60


Merged

merged 5 commits into ModelTC:main on Jun 9, 2025

Conversation

@GACLove (Contributor) commented Jun 6, 2025

  • Added a new file `api_server_dist.py` to handle video generation tasks using distributed inference.
  • Introduced endpoints for task submission, status checking, and result retrieval.
  • Implemented image downloading and task management with error handling.
  • Enhanced `infer.py` to ensure proper initialization of distributed processes.
  • Created a shell script `start_api_with_dist_inference.sh` for easy server startup with environment setup.

This commit establishes a robust framework for managing video generation tasks in a distributed manner.

Contributing Guidelines

We have prepared a pre-commit hook to enforce consistent code formatting across the project. If your code complies with the standards, you should not see any errors. You can clean up your code by following the steps below:

  1. Install the required dependencies:
    pip install ruff pre-commit
  2. Then, run the following command before committing:
    pre-commit run --all-files
  3. Finally, please double-check your code to ensure it complies with the following additional specifications as much as possible:
  • Avoid hard-coding local paths: Make sure your submissions do not include hard-coded local paths, as these paths are specific to individual development environments and can cause compatibility issues. Use relative paths or configuration files instead.
  • Clear error handling: Implement clear error-handling mechanisms in your code so that error messages can accurately indicate the location of the problem, possible causes, and suggested solutions, facilitating quick debugging.
  • Detailed comments and documentation: Add comments to complex code sections and provide comprehensive documentation to explain the functionality of the code, input-output requirements, and potential error scenarios.

Thank you for your contributions!

@gemini-code-assist bot left a comment

Hello @GACLove, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

Hello team, gemini-code-assist here to provide a summary of this pull request. This PR introduces a new FastAPI-based API server designed to handle video generation tasks using distributed inference. The core idea is to provide a web interface for submitting tasks, managing their lifecycle, and retrieving results, leveraging multiple processes for the actual inference work. It includes a new main script (api_server_dist.py), a small but necessary adjustment to the existing inference initialization logic (infer.py), and a convenience shell script to launch the server with distributed inference enabled.

Highlights

  • New Distributed Inference API Server: A new Python script (lightx2v/api_server_dist.py) is added, implementing a FastAPI application. This server acts as the entry point for submitting video generation tasks.
  • FastAPI Endpoints: The new API server exposes several endpoints:
  • /v1/local/video/generate (POST): Submit a new video generation task.
  • /v1/local/video/generate/service_status (GET): Check the overall status of the API service.
  • /v1/local/video/generate/get_all_tasks (GET): Retrieve the status of all submitted tasks.
  • /v1/local/video/generate/task_status (POST): Get the status of a specific task by ID.
  • /v1/local/video/generate/get_task_result (GET): Download the resulting video file for a completed task.
  • /v1/file/download (GET): A general endpoint to download files (potentially for results).
  • /v1/local/video/generate/stop_running_task (GET): Attempt to stop the currently running task.
  • Multiprocessing for Distributed Inference: The API server utilizes torch.multiprocessing with queues (input_queues, output_queues) to communicate with separate worker processes. Each worker process initializes a distributed inference runner (init_runner) and processes tasks received from the input queue, sending results or errors back via the output queue. This mimics the torchrun multi-process setup.
  • Task Management and Status Tracking: The server implements basic task management, tracking the status (processing, success, failed) of submitted tasks using a BaseServiceStatus utility (presumably from lightx2v.utils.service_utils). It also handles downloading remote images specified in task requests.
  • Inference Initialization Fix: A small change in lightx2v/infer.py ensures that dist.init_process_group is only called if a process group is not already initialized. This is crucial when the distributed processes are managed externally by the new API server's multiprocessing setup; a minimal sketch of this guard follows this list.
  • Startup Script: A new shell script (scripts/start_api_with_dist_inference.sh) is added to simplify launching the API server with distributed inference workers. It sets necessary environment variables and calls the api_server_dist.py script with appropriate arguments.
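
As a rough sketch of the initialization guard described for `lightx2v/infer.py` above — written under assumptions, not copied from the PR; the function name is illustrative and the real `init_runner` presumably derives the backend, rank, and world size from its config or environment:

```python
import torch.distributed as dist


def maybe_init_process_group(backend: str = "nccl") -> None:
    """Create a process group only if one does not already exist.

    When the API server's worker processes have already called
    init_process_group, the runner must not call it a second time.
    """
    # In the real init_runner, the backend, init_method, rank, and world_size
    # presumably come from the config or from environment variables.
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)
```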

Changelog

  • lightx2v/api_server_dist.py
    • Added a new file implementing a FastAPI server.
    • Defined a Message Pydantic model for task submission payload (lines 42-54).
    • Implemented download_image function to handle image URLs (lines 61-78).
    • Implemented local_video_generate function to submit tasks to distributed workers via queues and wait for results (lines 80-148).
    • Added FastAPI endpoints for task submission, status checking, result retrieval, and file download (lines 150-204).
    • Added an endpoint to stop the currently running task using thread signaling (lines 218-235).
    • Implemented distributed_inference_worker function for the multiprocessing worker processes, handling distributed initialization and task processing loops (lines 242-339).
    • Implemented start_distributed_inference_with_queue to launch worker processes using torch.multiprocessing and queues (lines 341-380).
    • Implemented stop_distributed_inference_with_queue for graceful shutdown of worker processes (lines 383-431).
    • Added main execution block to parse arguments, optionally start distributed inference workers, and run the FastAPI server (lines 438-495).
    • Included signal handling and atexit registration for cleanup (lines 470-482, 492-495).
  • lightx2v/infer.py
    • Modified init_runner to check if a distributed process group is already initialized before calling dist.init_process_group (lines 28-29).
  • scripts/start_api_with_dist_inference.sh
    • Added a new shell script to set up environment variables (CUDA_VISIBLE_DEVICES, PYTHONPATH, etc.) (lines 1-28).
    • Added checks for required path variables (lightx2v_path, model_path) (lines 14-22).
    • Added commands to execute the lightx2v.api_server_dist module with specific arguments to start the API server and distributed inference (lines 38-45).

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a new FastAPI server (api_server_dist.py) designed to manage distributed inference tasks for video generation. It includes endpoints for task submission, status checking, and result retrieval, leveraging torch.multiprocessing for distributed execution. A shell script (start_api_with_dist_inference.sh) is provided for easy startup. The changes in infer.py correctly handle the distributed process group initialization.

The overall structure for integrating FastAPI with distributed PyTorch inference via queues is a reasonable approach. However, the implementation has several areas that require attention, particularly regarding error handling, task management logic, and security.

Summary of Findings

  • Security Vulnerabilities (Critical): The /v1/file/download endpoint and the result retrieval logic in /v1/local/video/generate/get_task_result are vulnerable to path traversal, allowing potential access to arbitrary files on the server. The use of verify=False in download_image is also a security risk.
  • Unsafe Thread Termination (Critical): The _async_raise function and the /v1/local/video/generate/stop_running_task endpoint use ctypes to force thread termination, which is unsafe and can lead to instability, resource leaks, or crashes.
  • Task Result Handling Logic (High): The logic for retrieving task results from the output queue in local_video_generate is flawed and can lead to unprocessed results or race conditions if multiple tasks are active.
  • Hardcoded Paths (Medium): The startup script contains hardcoded local paths, violating the contributing guidelines and reducing portability.
  • Inefficient/Problematic Asyncio Usage (Medium): Creating and setting a new asyncio event loop for each task within the worker process is inefficient and not the standard way to use asyncio.
  • Broad Exception Handling (Medium): Bare except: blocks are used in the cleanup logic, which can hide unexpected errors.

Merge Readiness

This pull request introduces critical security vulnerabilities (path traversal, unsafe thread termination) and has significant issues in the task result handling logic. These issues must be addressed before the code can be considered for merging. Additionally, there are medium severity issues related to hardcoded paths, asyncio usage, and broad exception handling that should also be fixed to improve code quality and maintainability. I am unable to approve this pull request; please address the critical and high severity issues and request another review.

Comment on lines 218 to 235
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
global thread
if thread and thread.is_alive():
try:
_async_raise(thread.ident, SystemExit)
thread.join()

# Clean up the thread reference
thread = None
ApiServerServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
return {"stop_status": "success", "reason": "Task stopped successfully."}
except Exception as e:
return {"stop_status": "error", "reason": str(e)}
else:
return {"stop_status": "do_nothing", "reason": "No running task found."}
critical

This endpoint uses the unsafe _async_raise function to terminate the background task thread. As noted in the comment on _async_raise, this can lead to unpredictable behavior, resource leaks, and crashes. The thread.join() call after _async_raise might hang indefinitely if the thread is stuck in a state where it cannot process the injected exception.

This approach to stopping a task is fundamentally unsafe and should be replaced with a cooperative cancellation mechanism.

Comment on lines 177 to 190
@app.get("/v1/local/video/generate/get_task_result")
async def get_task_result(message: TaskStatusMessage):
result = ApiServerServiceStatus.get_status_task_id(message.task_id)
# 传输save_video_path内容到外部
save_video_path = result.get("save_video_path")

if save_video_path and Path(save_video_path).is_absolute() and Path(save_video_path).exists():
return FileResponse(save_video_path)
elif save_video_path and not Path(save_video_path).is_absolute():
video_path = OUTPUT_VIDEO_DIR / save_video_path
if video_path.exists():
return FileResponse(video_path)

return {"status": "not_found", "message": "Task result not found"}
critical

The /v1/local/video/generate/get_task_result endpoint takes a TaskStatusMessage which only contains a task_id. It then retrieves the task status and attempts to return a FileResponse based on the save_video_path stored in the task result.

This endpoint is vulnerable to path traversal if the save_video_path stored in the task result can be controlled by the user or contains relative path components (..). An attacker could potentially craft a save_video_path that points to arbitrary files on the server's filesystem, allowing them to download sensitive files.

While the code checks Path(save_video_path).is_absolute() and prepends OUTPUT_VIDEO_DIR for relative paths, this doesn't fully mitigate the risk if the stored path contains .. components within the relative path or if the OUTPUT_VIDEO_DIR itself is not sufficiently restricted.

To mitigate this, you should:

  1. Strictly validate the save_video_path format and content when the task is submitted.
  2. Ensure that the resolved path (OUTPUT_VIDEO_DIR / save_video_path for relative paths) is strictly contained within the OUTPUT_VIDEO_DIR using pathlib.Path.resolve() and checking if it's a child of the output directory.
  3. Consider returning the video content directly or using a more secure method for serving files, rather than relying solely on user-provided paths stored in the task result.

Comment on lines 192 to 204
@app.get("/v1/file/download")
async def download_file(file_path: str):
if Path(file_path).is_absolute():
if Path(file_path).exists():
return FileResponse(file_path)
else:
return {"status": "not_found", "message": f"File not found: {file_path}"}
else:
file_path_obj = OUTPUT_VIDEO_DIR / file_path
if file_path_obj.exists():
return FileResponse(file_path_obj)
else:
return {"status": "not_found", "message": f"File not found: {file_path}"}
critical

The /v1/file/download endpoint takes an arbitrary file_path string and directly serves the file using FileResponse if it exists. This is a severe path traversal vulnerability.

An attacker can provide a file_path like /etc/passwd or ../../../../etc/passwd to download arbitrary files from the server's filesystem, regardless of whether the path is absolute or relative to OUTPUT_VIDEO_DIR.

This endpoint should be removed or heavily restricted. If the intention is only to serve files from the OUTPUT_VIDEO_DIR, the endpoint should:

  1. Only accept relative paths.
  2. Strictly validate the path to prevent .. or other traversal attempts.
  3. Resolve the path relative to OUTPUT_VIDEO_DIR and verify that the final resolved path is a child of OUTPUT_VIDEO_DIR before serving the file.

Without these checks, this endpoint exposes the entire filesystem to anyone who can access the API.
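
A minimal sketch of the containment check described above, assuming `OUTPUT_VIDEO_DIR` is a `pathlib.Path` and Python 3.9+ (for `Path.is_relative_to`); the helper name is illustrative, and the same check applies to the `save_video_path` handling in `get_task_result`:

```python
from pathlib import Path

from fastapi import HTTPException
from fastapi.responses import FileResponse

OUTPUT_VIDEO_DIR = Path("output_videos")  # assumed; the server defines its own location


def safe_file_response(file_path: str) -> FileResponse:
    """Serve a file only if it resolves to a location inside OUTPUT_VIDEO_DIR."""
    base = OUTPUT_VIDEO_DIR.resolve()
    candidate = (base / file_path).resolve()
    # Reject absolute inputs and any ".." traversal that escapes the output directory.
    if Path(file_path).is_absolute() or not candidate.is_relative_to(base):
        raise HTTPException(status_code=400, detail="Invalid file path")
    if not candidate.is_file():
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(candidate)
```

The `/v1/file/download` endpoint would then return `safe_file_response(file_path)` instead of building a `FileResponse` from the raw parameter.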

@GACLove (author) left a comment

/fix

Comment on lines 208 to 215
def _async_raise(tid, exctype):
    """Force thread tid to raise exception exctype"""
    res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
    if res == 0:
        raise ValueError("Invalid thread ID")
    elif res > 1:
        ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
        raise SystemError("PyThreadState_SetAsyncExc failed")
critical

The _async_raise function uses ctypes to inject an exception into another thread. This is a highly dangerous and unreliable method for thread termination.

Injecting exceptions can leave resources (like locks, file handles, network connections) in an inconsistent state, leading to deadlocks, resource leaks, or crashes. Python's threading model is not designed for external, forceful termination like this.

Python's recommended way to stop a thread is to signal it to stop gracefully (e.g., by setting an event or flag that the thread periodically checks) and let the thread clean up and exit on its own. While this requires cooperation from the thread's target function, it's the only safe way.

Using _async_raise for the stop_running_task endpoint is a critical issue due to the potential for instability and crashes.
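
A minimal sketch of the cooperative approach suggested here, assuming the inference work can be split into steps that periodically check a shared flag; the names below (`stop_event`, `run_task`) are illustrative, not from the PR:

```python
import threading
import time

stop_event = threading.Event()  # shared flag checked by the worker thread


def run_task(task_id: str) -> dict:
    """Long-running task that exits cleanly when asked to stop."""
    for _ in range(100):  # stands in for chunks of inference work
        if stop_event.is_set():
            # Release resources here (GPU memory, temp files) before returning.
            return {"task_id": task_id, "status": "cancelled"}
        time.sleep(0.1)
    return {"task_id": task_id, "status": "success"}


# The stop endpoint would then signal and wait, rather than injecting SystemExit:
#   stop_event.set()
#   thread.join(timeout=30)
```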

Comment on lines +62 to +63
with httpx.Client(verify=False) as client:
    response = client.get(image_url)
high

Using verify=False in httpx.Client disables SSL certificate verification. This is a significant security risk as it makes the application vulnerable to man-in-the-middle attacks when downloading images from external URLs. It should only be used in development/testing environments, never in production. Consider adding a configuration option to control this or removing it entirely and ensuring proper certificate handling.
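
One possible way to make certificate verification configurable, as suggested; the environment variable name is hypothetical and the signature is simplified relative to the PR's `download_image`:

```python
import os

import httpx

# Verified TLS by default; disabling it is an explicit, opt-in development setting.
VERIFY_TLS = os.environ.get("LIGHTX2V_VERIFY_TLS", "1") != "0"  # hypothetical setting name


def download_image(image_url: str) -> bytes:
    with httpx.Client(verify=VERIFY_TLS) as client:
        response = client.get(image_url)
        response.raise_for_status()
        return response.content
```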

Comment on lines 335 to 336
if 'loop' in locals() and loop and not loop.is_closed():
    loop.close()
medium

The finally block attempts to close the loop if it exists and is not closed. However, the loop variable is defined inside the while True loop (lines 274-275). If an exception occurs before the loop is created (e.g., during initialization), the loop variable won't exist, and this check will fail with a NameError. The if 'loop' in locals() check helps, but it's still tied to the inner loop's variable.

The event loop management should be outside the task processing loop, created once when the worker starts and closed when the worker exits.
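
A sketch of the suggested structure, with the loop created once when the worker starts and closed on exit; `run_pipeline` stands in for the runner's coroutine and the function name is illustrative:

```python
import asyncio
import queue


def worker_loop(input_queue, run_pipeline):
    """Reuse one event loop for every task handled by this worker process."""
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        while True:
            try:
                task_data = input_queue.get(timeout=1.0)
            except queue.Empty:
                continue
            if task_data is None:  # stop signal
                break
            loop.run_until_complete(run_pipeline(task_data))
    finally:
        loop.close()
```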

Comment on lines 415 to 416
except:
    pass
medium

The except: block here is too broad. Catching all exceptions (except:) without specifying the exception type is generally bad practice as it can hide unexpected errors (like KeyboardInterrupt, SystemExit, NameError, etc.) and make debugging difficult. It's better to catch specific exceptions you expect (e.g., queue.Empty if that's the only expected error during cleanup) or at least log the exception type and traceback.
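
For example, if the cleanup is draining the queues, only `queue.Empty` needs to be caught (a sketch, assuming the standard-library exception raised by `torch.multiprocessing` queues):

```python
import queue


def drain(q) -> None:
    """Discard any remaining items during shutdown."""
    while True:
        try:
            q.get_nowait()
        except queue.Empty:
            break
```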

Comment on lines 423 to 424
except:
    pass
medium

Similar to the previous comment, using a bare except: block here is too broad and hides potential issues during output queue cleanup.

Comment on lines 242 to 339
def distributed_inference_worker(rank, world_size, master_addr, master_port, args, input_queue, output_queue):
    """Worker process for the distributed inference service."""
    try:
        # Set environment variables
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["ENABLE_PROFILING_DEBUG"] = "true"
        os.environ["ENABLE_GRAPH_MODE"] = "false"
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

        logger.info(f"Process {rank}/{world_size-1} is initializing the distributed inference service...")

        dist.init_process_group(backend='nccl', init_method=f'tcp://{master_addr}:{master_port}', rank=rank, world_size=world_size)

        config = set_config(args)
        config["mode"] = "server"
        logger.info(f"config: {config}")
        runner = init_runner(config)

        logger.info(f"Process {rank}/{world_size-1} finished initializing the distributed inference service, waiting for tasks...")

        while True:
            try:
                task_data = input_queue.get(timeout=1.0)  # 1-second timeout
                if task_data is None:  # stop signal
                    logger.info(f"Process {rank}/{world_size-1} received the stop signal, exiting the inference service")
                    break
                logger.info(f"Process {rank}/{world_size-1} received inference task: {task_data['task_id']}")

                runner.set_inputs(task_data)
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                # Run inference, reusing the event loop created above
                try:
                    loop.run_until_complete(runner.run_pipeline())

                    # Only rank 0 puts the result into the output queue, to avoid duplicates
                    if rank == 0:
                        result = {
                            "task_id": task_data["task_id"],
                            "status": "success",
                            "save_video_path": task_data["save_video_path"],
                            "message": "Inference completed"
                        }
                        output_queue.put(result)
                        logger.info(f"Task {task_data['task_id']} completed (reported by rank 0)")
                    dist.barrier()

                except Exception as e:
                    # Only rank 0 reports the error
                    if rank == 0:
                        result = {
                            "task_id": task_data["task_id"],
                            "status": "failed",
                            "error": str(e),
                            "message": f"Inference failed: {str(e)}"
                        }
                        output_queue.put(result)
                        logger.error(f"Task {task_data['task_id']} inference failed: {str(e)} (reported by rank 0)")
                    dist.barrier()

            except queue.Empty:
                # Queue is empty, keep waiting
                continue
            except Exception as e:
                logger.error(f"Process {rank}/{world_size-1} hit an error while handling a task: {str(e)}")
                # Only rank 0 sends the error result
                if rank == 0:
                    error_result = {
                        "task_id": task_data.get("task_id", "unknown") if 'task_data' in locals() else "unknown",
                        "status": "error",
                        "error": str(e),
                        "message": f"Error while handling the task: {str(e)}"
                    }
                    output_queue.put(error_result)
                dist.barrier()

    except Exception as e:
        logger.error(f"Distributed inference service process {rank}/{world_size-1} failed to start: {str(e)}")
        # Only rank 0 reports the startup failure
        if rank == 0:
            error_result = {
                "task_id": "startup",
                "status": "startup_failed",
                "error": str(e),
                "message": f"Inference service failed to start: {str(e)}"
            }
            output_queue.put(error_result)
    # Close the event loop and destroy the process group when the process finally exits
    finally:
        if 'loop' in locals() and loop and not loop.is_closed():
            loop.close()
        if dist.is_initialized():
            dist.destroy_process_group()

medium

The distributed_inference_worker function sets CUDA_VISIBLE_DEVICES to str(rank). This assumes a specific mapping where each process corresponds to a single, distinct GPU index equal to its rank. This might not be true in all distributed setups, especially if CUDA_VISIBLE_DEVICES is already set externally to specify a subset of available GPUs.

A more robust approach would be to use torch.cuda.set_device(rank) after dist.init_process_group if the process group is initialized across the devices specified by the original CUDA_VISIBLE_DEVICES environment variable. If torchrun or a similar launcher is used, it typically handles setting LOCAL_RANK and RANK and managing device assignment. If launching manually, the script or launcher should manage CUDA_VISIBLE_DEVICES or the worker should use os.environ['LOCAL_RANK'] to determine its assigned device index within the visible devices.
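
A sketch of the suggested device handling, using `LOCAL_RANK` when a launcher provides it and falling back to the process rank; the function name is illustrative and argument handling is simplified:

```python
import os

import torch
import torch.distributed as dist


def setup_worker(rank: int, world_size: int, master_addr: str, master_port: int) -> None:
    """Join the process group and bind to a GPU without overwriting CUDA_VISIBLE_DEVICES."""
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{master_port}",
        rank=rank,
        world_size=world_size,
    )
    # LOCAL_RANK indexes into whatever devices the launcher made visible.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
```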

Comment on lines +8 to +12
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=2,3
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}"
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
medium

This block checks if CUDA_VISIBLE_DEVICES is set and provides a default value (2,3) if it's not. While helpful, hardcoding specific device IDs like 2,3 might not be appropriate for all environments. Consider making the default configurable or providing instructions on how users should set this variable externally.

GACLove added 4 commits June 6, 2025 15:04
… handling

- Introduced a new endpoint `/v1/local/video/generate_form` for video generation that accepts image uploads.
- Implemented unique filename generation for uploaded images to prevent conflicts.
- Enhanced directory management for input and output paths.
- Improved file download response with detailed status and size information.
- Added error handling for distributed inference processes and graceful shutdown procedures.
@helloyongyang merged commit f471bb9 into ModelTC:main on Jun 9, 2025
1 check passed