diff --git a/lightx2v/api_server_dist.py b/lightx2v/api_server_dist.py
new file mode 100644
index 0000000..3653e20
--- /dev/null
+++ b/lightx2v/api_server_dist.py
@@ -0,0 +1,666 @@
+import asyncio
+import argparse
+from fastapi import FastAPI, UploadFile, File
+from fastapi.responses import StreamingResponse
+from pydantic import BaseModel
+from loguru import logger
+import uvicorn
+import threading
+import ctypes
+import gc
+import torch
+import os
+import sys
+import time
+import torch.multiprocessing as mp
+import queue
+import torch.distributed as dist
+import random
+import uuid
+
+from lightx2v.utils.set_config import set_config
+from lightx2v.infer import init_runner
+from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager
+import httpx
+from pathlib import Path
+from urllib.parse import urlparse
+
+# =========================
+# FastAPI Related Code
+# =========================
+
+runner = None
+thread = None
+
+app = FastAPI()
+
+CACHE_DIR = Path(__file__).parent.parent / "cache"
+INPUT_IMAGE_DIR = CACHE_DIR / "inputs" / "imgs"
+OUTPUT_VIDEO_DIR = CACHE_DIR / "outputs"
+
+for directory in [INPUT_IMAGE_DIR, OUTPUT_VIDEO_DIR]:
+    directory.mkdir(parents=True, exist_ok=True)
+
+
+class Message(BaseModel):
+    task_id: str
+    task_id_must_unique: bool = False
+
+    prompt: str
+    use_prompt_enhancer: bool = False
+    negative_prompt: str = ""
+    image_path: str = ""
+    num_fragments: int = 1
+    save_video_path: str
+
+    def get(self, key, default=None):
+        return getattr(self, key, default)
+
+
+class ApiServerServiceStatus(BaseServiceStatus):
+    pass
+
+
+def download_image(image_url: str):
+    with httpx.Client(verify=False) as client:
+        response = client.get(image_url)
+
+    image_name = Path(urlparse(image_url).path).name
+    if not image_name:
+        raise ValueError(f"Invalid image URL: {image_url}")
+
+    image_path = INPUT_IMAGE_DIR / image_name
+    image_path.parent.mkdir(parents=True, exist_ok=True)
+
+    if response.status_code == 200:
+        with open(image_path, "wb") as f:
+            f.write(response.content)
+        return image_path
+    else:
+        raise ValueError(f"Failed to download image from {image_url}")
+
+
+stop_generation_event = threading.Event()
+
+
+def local_video_generate(message: Message, stop_event: threading.Event):
+    try:
+        global input_queues, output_queues
+
+        if not input_queues or not output_queues:
+            logger.error("The distributed inference service is not running")
+            ApiServerServiceStatus.record_failed_task(message, error="The distributed inference service is not running")
+            return
+
+        logger.info(f"Submitting task to the distributed inference service: {message.task_id}")
+
+        # Convert the task message into a plain dict
+        task_data = {
+            "task_id": message.task_id,
+            "prompt": message.prompt,
+            "use_prompt_enhancer": message.use_prompt_enhancer,
+            "negative_prompt": message.negative_prompt,
+            "image_path": message.image_path,
+            "num_fragments": message.num_fragments,
+            "save_video_path": message.save_video_path,
+        }
+
+        if message.image_path.startswith("http"):
+            image_path = download_image(message.image_path)
+            task_data["image_path"] = str(image_path)
+
+        save_video_path = Path(message.save_video_path)
+        if not save_video_path.is_absolute():
+            task_data["save_video_path"] = str(OUTPUT_VIDEO_DIR / message.save_video_path)
+
+        # Put the task into every worker's input queue
+        for input_queue in input_queues:
+            input_queue.put(task_data)
+
+        # Wait for the result
+        timeout = 300  # 5-minute timeout
+        start_time = time.time()
+
+        while time.time() - start_time < timeout:
+            if stop_event.is_set():
+                logger.info(f"Task {message.task_id} received a stop signal, terminating")
+                ApiServerServiceStatus.record_failed_task(message, error="Task was stopped")
+                return
+
+            try:
+                result = output_queues[0].get(timeout=1.0)
+
+                # Check whether this result belongs to the current task
+                if result.get("task_id") == message.task_id:
+                    if result.get("status") == "success":
+                        logger.info(f"Task {message.task_id} inference succeeded")
+                        ApiServerServiceStatus.complete_task(message)
+                    else:
+                        error_msg = result.get("error", "Inference failed")
+                        logger.error(f"Task {message.task_id} inference failed: {error_msg}")
+                        ApiServerServiceStatus.record_failed_task(message, error=error_msg)
+                    return
+                else:
+                    # The result belongs to another task, so put it back.
+                    # Note: with many concurrent tasks this can delay the current
+                    # task's result; robust concurrent result handling needs a
+                    # richer design, e.g. a dedicated output queue per task.
+                    output_queues[0].put(result)
+                    time.sleep(0.1)
+
+            except queue.Empty:
+                # Queue is empty, keep waiting
+                continue
+
+        # Timed out
+        logger.error(f"Task {message.task_id} timed out")
+        ApiServerServiceStatus.record_failed_task(message, error="Processing timed out")
+
+    except Exception as e:
+        logger.error(f"Task {message.task_id} failed: {str(e)}")
+        ApiServerServiceStatus.record_failed_task(message, error=str(e))
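The put-back loop above shares one output queue across all requests; as the inline note says, a busy server can keep deferring the task that is actually waiting. A dedicated reader thread that fans results out to per-task queues removes the put-back entirely. Below is a minimal sketch of that idea; `ResultRouter` and its API are illustrative, not part of this PR:

```python
# Minimal sketch of the per-task output queue design suggested in the
# comment above. ResultRouter is a hypothetical name, not part of this PR.
import queue
import threading


class ResultRouter:
    """Single consumer thread that routes results to per-task queues."""

    def __init__(self, source):
        self._source = source          # the shared worker output queue
        self._waiters = {}             # task_id -> queue.Queue
        self._lock = threading.Lock()
        threading.Thread(target=self._pump, daemon=True).start()

    def register(self, task_id):
        # Call this *before* submitting the task so no result can be missed.
        with self._lock:
            return self._waiters.setdefault(task_id, queue.Queue())

    def _pump(self):
        while True:
            result = self._source.get()  # sole consumer: no put-back races
            with self._lock:
                waiter = self._waiters.pop(result.get("task_id"), None)
            if waiter is not None:
                waiter.put(result)
```

With this in place, `local_video_generate` would block on its own `register(...)` queue with the same 300-second deadline instead of filtering the shared one.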
+
+
+@app.post("/v1/local/video/generate")
+async def v1_local_video_generate(message: Message):
+    try:
+        task_id = ApiServerServiceStatus.start_task(message)
+        # Run the long-lived generation in a background thread
+        global thread, stop_generation_event
+        stop_generation_event.clear()
+        thread = threading.Thread(
+            target=local_video_generate,
+            args=(
+                message,
+                stop_generation_event,
+            ),
+            daemon=True,
+        )
+        thread.start()
+        return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
+    except RuntimeError as e:
+        return {"error": str(e)}
+
+
+@app.post("/v1/local/video/generate_form")
+async def v1_local_video_generate_form(
+    task_id: str,
+    prompt: str,
+    save_video_path: str,
+    task_id_must_unique: bool = False,
+    use_prompt_enhancer: bool = False,
+    negative_prompt: str = "",
+    num_fragments: int = 1,
+    image_file: UploadFile = File(None),
+):
+    # Handle the uploaded image file
+    image_path = ""
+    if image_file and image_file.filename:
+        # Generate a unique filename
+        file_extension = Path(image_file.filename).suffix
+        unique_filename = f"{uuid.uuid4()}{file_extension}"
+        image_path = INPUT_IMAGE_DIR / unique_filename
+
+        # Save the uploaded file
+        with open(image_path, "wb") as buffer:
+            content = await image_file.read()
+            buffer.write(content)
+
+        image_path = str(image_path)
+
+    message = Message(
+        task_id=task_id,
+        task_id_must_unique=task_id_must_unique,
+        prompt=prompt,
+        use_prompt_enhancer=use_prompt_enhancer,
+        negative_prompt=negative_prompt,
+        image_path=image_path,
+        num_fragments=num_fragments,
+        save_video_path=save_video_path,
+    )
+    try:
+        task_id = ApiServerServiceStatus.start_task(message)
+        global thread, stop_generation_event
+        stop_generation_event.clear()
+        thread = threading.Thread(
+            target=local_video_generate,
+            args=(
+                message,
+                stop_generation_event,
+            ),
+            daemon=True,
+        )
+        thread.start()
+        return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
+    except RuntimeError as e:
+        return {"error": str(e)}
+
+
+@app.get("/v1/local/video/generate/service_status")
+async def get_service_status():
+    return ApiServerServiceStatus.get_status_service()
+
+
+@app.get("/v1/local/video/generate/get_all_tasks")
+async def get_all_tasks():
+    return ApiServerServiceStatus.get_all_tasks()
+
+
+@app.post("/v1/local/video/generate/task_status")
+async def get_task_status(message: TaskStatusMessage):
+    return ApiServerServiceStatus.get_status_task_id(message.task_id)
+
+
+@app.get("/v1/local/video/generate/get_task_result")
+async def get_task_result(message: TaskStatusMessage):
+    result = ApiServerServiceStatus.get_status_task_id(message.task_id)
+    # Expose save_video_path to the caller
+    save_video_path = result.get("save_video_path")
+
+    if save_video_path and Path(save_video_path).is_absolute() and Path(save_video_path).exists():
+        file_path = Path(save_video_path)
+        relative_path = file_path.relative_to(OUTPUT_VIDEO_DIR.resolve()) if str(file_path).startswith(str(OUTPUT_VIDEO_DIR.resolve())) else file_path.name
+        return {
+            "status": "success",
+            "task_status": result.get("status", "unknown"),
+            "filename": file_path.name,
+            "file_size": file_path.stat().st_size,
+            "download_url": f"/v1/file/download/{relative_path}",
+            "message": "Task result is ready",
+        }
+    elif save_video_path and not Path(save_video_path).is_absolute():
+        video_path = OUTPUT_VIDEO_DIR / save_video_path
+        if video_path.exists():
+            return {
+                "status": "success",
+                "task_status": result.get("status", "unknown"),
+                "filename": video_path.name,
+                "file_size": video_path.stat().st_size,
+                "download_url": f"/v1/file/download/{save_video_path}",
+                "message": "Task result is ready",
+            }
+
+    return {"status": "not_found", "message": "Task result not found", "task_status": result.get("status", "unknown")}
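The branches above encode a small path contract: a relative `save_video_path` is rooted under `OUTPUT_VIDEO_DIR`, an absolute path inside the output directory is served by its relative part, and any other absolute path falls back to the bare filename. A worked illustration with hypothetical paths:

```python
# Worked example of the path mapping in get_task_result (paths hypothetical).
from pathlib import Path

OUTPUT_VIDEO_DIR = Path("cache/outputs").resolve()

# Relative input: "a.mp4" -> cache/outputs/a.mp4 -> "/v1/file/download/a.mp4"
# Absolute input inside the output dir: served by its part below the root.
file_path = OUTPUT_VIDEO_DIR / "runs" / "a.mp4"
relative_path = file_path.relative_to(OUTPUT_VIDEO_DIR) if str(file_path).startswith(str(OUTPUT_VIDEO_DIR)) else file_path.name
print(f"/v1/file/download/{relative_path}")  # /v1/file/download/runs/a.mp4
```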
+
+
+def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
+    """File streaming generator that reads the file chunk by chunk."""
+    with open(file_path, "rb") as file:
+        while chunk := file.read(chunk_size):
+            yield chunk
+
+
+@app.get(
+    "/v1/file/download/{file_path:path}",
+    response_class=StreamingResponse,
+    summary="Download a file",
+    description="Stream the requested file to the client",
+    responses={200: {"description": "File downloaded successfully", "content": {"application/octet-stream": {}}}, 404: {"description": "File not found"}, 500: {"description": "Server error"}},
+)
+async def download_file(file_path: str):
+    try:
+        full_path = OUTPUT_VIDEO_DIR / file_path
+        resolved_path = full_path.resolve()
+
+        # Security check: make sure the file stays inside the allowed directory
+        if not str(resolved_path).startswith(str(OUTPUT_VIDEO_DIR.resolve())):
+            return {"status": "forbidden", "message": "Access to this file is not allowed"}
+
+        if resolved_path.exists() and resolved_path.is_file():
+            file_size = resolved_path.stat().st_size
+            filename = resolved_path.name
+
+            # Pick an appropriate MIME type
+            mime_type = "application/octet-stream"
+            if filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
+                mime_type = "video/mp4"
+            elif filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
+                mime_type = "image/jpeg"
+
+            headers = {
+                "Content-Disposition": f'attachment; filename="{filename}"',
+                "Content-Length": str(file_size),
+                "Accept-Ranges": "bytes",
+            }
+
+            return StreamingResponse(file_stream_generator(str(resolved_path)), media_type=mime_type, headers=headers)
+        else:
+            return {"status": "not_found", "message": f"File not found: {file_path}"}
+    except Exception as e:
+        logger.error(f"Error while handling the file download request: {e}")
+        return {"status": "error", "message": "File download failed"}
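On the client side the same chunked transfer can be consumed with `httpx`'s streaming API. A short sketch; the URL and output filename are placeholders:

```python
# Client-side counterpart of the streaming download endpoint above.
# The URL and output filename are placeholders.
import httpx

url = "http://localhost:8000/v1/file/download/a.mp4"
with httpx.stream("GET", url) as resp:
    resp.raise_for_status()
    with open("a.mp4", "wb") as f:
        # Mirror file_stream_generator's 1 MiB chunks on the receiving end
        for chunk in resp.iter_bytes(chunk_size=1024 * 1024):
            f.write(chunk)
```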
task found."} + + +# 使用多进程队列进行通信 +input_queues = [] +output_queues = [] +distributed_runners = [] + + +def distributed_inference_worker(rank, world_size, master_addr, master_port, args, input_queue, output_queue): + """分布式推理服务工作进程""" + try: + # 设置环境变量 + os.environ["TOKENIZERS_PARALLELISM"] = "false" + os.environ["ENABLE_PROFILING_DEBUG"] = "true" + os.environ["ENABLE_GRAPH_MODE"] = "false" + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["CUDA_VISIBLE_DEVICES"] = str(rank) + + logger.info(f"进程 {rank}/{world_size - 1} 正在初始化分布式推理服务...") + + dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size) + + config = set_config(args) + config["mode"] = "server" + logger.info(f"config: {config}") + runner = init_runner(config) + + logger.info(f"进程 {rank}/{world_size - 1} 分布式推理服务初始化完成,等待任务...") + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + while True: + try: + task_data = input_queue.get(timeout=1.0) # 1秒超时 + if task_data is None: # 停止信号 + logger.info(f"进程 {rank}/{world_size - 1} 收到停止信号,退出推理服务") + break + logger.info(f"进程 {rank}/{world_size - 1} 收到推理任务: {task_data['task_id']}") + + runner.set_inputs(task_data) + + # 运行推理,复用已创建的事件循环 + try: + loop.run_until_complete(runner.run_pipeline()) + + # 只有 Rank 0 负责将结果放入输出队列,避免重复 + if rank == 0: + result = {"task_id": task_data["task_id"], "status": "success", "save_video_path": task_data["save_video_path"], "message": "推理完成"} + output_queue.put(result) + logger.info(f"任务 {task_data['task_id']} 处理完成 (由 Rank 0 报告)") + if dist.is_initialized(): + dist.barrier() + + except Exception as e: + # 只有 Rank 0 负责报告错误 + if rank == 0: + result = {"task_id": task_data["task_id"], "status": "failed", "error": str(e), "message": f"推理失败: {str(e)}"} + output_queue.put(result) + logger.error(f"任务 {task_data['task_id']} 推理失败: {str(e)} (由 Rank 0 报告)") + if dist.is_initialized(): + dist.barrier() + + except queue.Empty: + # 队列为空,继续等待 + continue + except KeyboardInterrupt: + logger.info(f"进程 {rank}/{world_size - 1} 收到 KeyboardInterrupt,优雅退出") + break + except Exception as e: + logger.error(f"进程 {rank}/{world_size - 1} 处理任务时发生错误: {str(e)}") + # 只有 Rank 0 负责发送错误结果 + task_data = task_data if "task_data" in locals() else {} + if rank == 0: + error_result = { + "task_id": task_data.get("task_id", "unknown"), + "status": "error", + "error": str(e), + "message": f"处理任务时发生错误: {str(e)}", + } + try: + output_queue.put(error_result) + except: # noqa: E722 + pass + if dist.is_initialized(): + try: + dist.barrier() + except: # noqa: E722 + pass + + except KeyboardInterrupt: + logger.info(f"进程 {rank}/{world_size - 1} 主循环收到 KeyboardInterrupt,正在退出") + except Exception as e: + logger.error(f"分布式推理服务进程 {rank}/{world_size - 1} 启动失败: {str(e)}") + # 只有 Rank 0 负责报告启动失败 + if rank == 0: + try: + error_result = {"task_id": "startup", "status": "startup_failed", "error": str(e), "message": f"推理服务启动失败: {str(e)}"} + output_queue.put(error_result) + except: # noqa: E722 + pass + # 在进程最终退出时关闭事件循环和销毁分布式组 + finally: + try: + if "loop" in locals() and loop and not loop.is_closed(): + loop.close() + except: # noqa: E722 + pass + try: + if dist.is_initialized(): + dist.destroy_process_group() + except: # noqa: E722 + pass + + +def start_distributed_inference_with_queue(args): + """使用队列启动分布式推理服务,并模拟torchrun的多进程模式""" + global input_queues, output_queues, distributed_runners + + nproc_per_node = args.nproc_per_node + + if nproc_per_node <= 0: + logger.error("nproc_per_node 必须大于0") + return 
+
+
+def start_distributed_inference_with_queue(args):
+    """Start the distributed inference service with queues, mimicking torchrun's multi-process mode."""
+    global input_queues, output_queues, distributed_runners
+
+    nproc_per_node = args.nproc_per_node
+
+    if nproc_per_node <= 0:
+        logger.error("nproc_per_node must be greater than 0")
+        return False
+
+    try:
+        master_addr = "127.0.0.1"
+        master_port = str(random.randint(20000, 29999))
+        logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
+
+        processes = []
+
+        ctx = mp.get_context("spawn")
+        for rank in range(nproc_per_node):
+            input_queue = ctx.Queue()
+            output_queue = ctx.Queue()
+            p = ctx.Process(target=distributed_inference_worker, args=(rank, nproc_per_node, master_addr, master_port, args, input_queue, output_queue), daemon=True)
+
+            p.start()
+            processes.append(p)
+            input_queues.append(input_queue)
+            output_queues.append(output_queue)
+
+        distributed_runners = processes
+        return True
+
+    except Exception as e:
+        logger.exception(f"Error while starting the distributed inference service: {str(e)}")
+        stop_distributed_inference_with_queue()
+        return False
+
+
+def stop_distributed_inference_with_queue():
+    """Stop the distributed inference service."""
+    global input_queues, output_queues, distributed_runners
+
+    try:
+        if distributed_runners:
+            logger.info(f"Stopping {len(distributed_runners)} distributed inference service processes...")
+
+            # Send the stop signal to every worker process
+            if input_queues:
+                for input_queue in input_queues:
+                    try:
+                        input_queue.put(None)
+                    except:  # noqa: E722
+                        pass
+
+            # Wait for all processes to finish
+            for p in distributed_runners:
+                try:
+                    p.join(timeout=10)
+                except:  # noqa: E722
+                    pass
+
+            # Force-terminate any process that has not finished
+            for p in distributed_runners:
+                try:
+                    if p.is_alive():
+                        logger.warning(f"Inference service process {p.pid} did not finish in time, terminating it...")
+                        p.terminate()
+                        p.join(timeout=5)
+                except:  # noqa: E722
+                    pass
+
+            logger.info("All distributed inference service processes have stopped")
+
+        # Drain the queues
+        if input_queues:
+            try:
+                for input_queue in input_queues:
+                    try:
+                        while not input_queue.empty():
+                            input_queue.get_nowait()
+                    except:  # noqa: E722
+                        pass
+            except:  # noqa: E722
+                pass
+
+        if output_queues:
+            try:
+                for output_queue in output_queues:
+                    try:
+                        while not output_queue.empty():
+                            output_queue.get_nowait()
+                    except:  # noqa: E722
+                        pass
+            except:  # noqa: E722
+                pass
+
+        distributed_runners = []
+        input_queues = []
+        output_queues = []
+
+    except Exception as e:
+        logger.error(f"Error while stopping the distributed inference service: {str(e)}")
+    except KeyboardInterrupt:
+        logger.info("Received KeyboardInterrupt while stopping the distributed inference service, forcing cleanup")
+
+
+# =========================
+# Main Entry
+# =========================
+
+if __name__ == "__main__":
+    ProcessManager.register_signal_handler()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"])
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
+    parser.add_argument("--model_path", type=str, required=True)
+    parser.add_argument("--config_json", type=str, required=True)
+    parser.add_argument("--split", action="store_true")
+    parser.add_argument("--port", type=int, default=8000)
+    parser.add_argument("--seed", type=int, default=42)
+
+    parser.add_argument("--start_inference", action="store_true", help="Start the distributed inference service before starting the API server")
+    parser.add_argument("--nproc_per_node", type=int, default=4, help="Number of processes per node for distributed inference")
+
+    args = parser.parse_args()
+    logger.info(f"args: {args}")
+
+    # Keep the startup arguments around for the restart feature
+    startup_args = args
+
+    if args.start_inference:
+        logger.info("Starting the distributed inference service...")
+        success = start_distributed_inference_with_queue(args)
+        if not success:
+            logger.error("Failed to start the distributed inference service, exiting")
+            sys.exit(1)
+
+        # Register a cleanup function to run at program exit
+        import atexit
+
+        atexit.register(stop_distributed_inference_with_queue)
+
+    # Register signal handlers for graceful shutdown
+    import signal
+
+    def signal_handler(signum, frame):
+        logger.info(f"Received signal {signum}, shutting down gracefully...")
+        try:
+            stop_distributed_inference_with_queue()
+        except:  # noqa: E722
+            logger.error("Error while shutting down the distributed inference service")
+        finally:
+            sys.exit(0)
+
+    signal.signal(signal.SIGTERM, signal_handler)
+    signal.signal(signal.SIGINT, signal_handler)
+
+    try:
+        logger.info(f"Starting the FastAPI server on port {args.port}")
+        uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False, workers=1)
+    except KeyboardInterrupt:
+        logger.info("Received KeyboardInterrupt, shutting down the service...")
+    except Exception as e:
+        logger.error(f"FastAPI server error: {str(e)}")
+    finally:
+        # Make sure the inference service is stopped when the program exits
+        if args.start_inference:
+            stop_distributed_inference_with_queue()
+
+"""
+curl -X 'POST' \
+  'http://localhost:8000/v1/local/video/generate_form?task_id=abc&prompt=%E8%B7%B3%E8%88%9E&save_video_path=a.mp4&task_id_must_unique=false&use_prompt_enhancer=false&num_fragments=1' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: multipart/form-data' \
+  -F 'image_file=@图片1.png;type=image/png'
+
+curl -X 'POST' \
+  'http://localhost:8000/v1/local/video/generate' \
+  -H 'accept: application/json' \
+  -H 'Content-Type: application/json' \
+  -d '{
+    "task_id": "abcde",
+    "task_id_must_unique": false,
+    "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline'\''s intricate details and the refreshing atmosphere of the seaside.",
+    "use_prompt_enhancer": false,
+    "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    "image_path": "/mnt/aigc/users/gaopeng1/ComfyUI-Lightx2vWrapper/lightx2v/assets/inputs/imgs/img_0.jpg",
+    "num_fragments": 1,
+    "save_video_path": "/mnt/aigc/users/lijiaqi2/ComfyUI/custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v/save_results/img_0.mp4"
+}'
+
+"""
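To complement the curl examples above, here is a small Python client covering the full lifecycle against these endpoints: submit, poll, then fetch the result. The payload and status strings are assumptions (the exact states come from `BaseServiceStatus`), and note that `get_task_result` reads a JSON body even though the route is declared as a GET:

```python
# End-to-end client sketch for the endpoints defined in this file.
# Payload values are illustrative; status strings are assumptions.
import time

import httpx

BASE = "http://localhost:8000"

task = {
    "task_id": "demo-task",
    "prompt": "A white cat surfing at the beach",
    "image_path": "",               # empty for t2v; a URL or local path for i2v
    "save_video_path": "demo.mp4",  # relative, so it lands under cache/outputs
}
print(httpx.post(f"{BASE}/v1/local/video/generate", json=task).json())

# Poll until the task leaves its running state.
while True:
    status = httpx.post(f"{BASE}/v1/local/video/generate/task_status", json={"task_id": task["task_id"]}).json()
    if status.get("status") not in ("processing", "pending"):
        break
    time.sleep(5)

# The route is a GET that expects a JSON body, hence httpx.request().
result = httpx.request("GET", f"{BASE}/v1/local/video/generate/get_task_result", json={"task_id": task["task_id"]}).json()
print(result)  # includes download_url on success; see the download sketch above
```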
diff --git a/lightx2v/infer.py b/lightx2v/infer.py
index 471e704..76592d7 100755
--- a/lightx2v/infer.py
+++ b/lightx2v/infer.py
@@ -25,7 +25,8 @@ def init_runner(config):
     seed_all(config.seed)
 
     if config.parallel_attn_type:
-        dist.init_process_group(backend="nccl")
+        if not dist.is_initialized():
+            dist.init_process_group(backend="nccl")
 
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
diff --git a/scripts/start_api_with_dist_inference.sh b/scripts/start_api_with_dist_inference.sh
new file mode 100755
index 0000000..0fa098a
--- /dev/null
+++ b/scripts/start_api_with_dist_inference.sh
@@ -0,0 +1,47 @@
+#!/bin/bash
+
+# Set paths
+lightx2v_path=/mnt/aigc/users/lijiaqi2/ComfyUI/custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v
+model_path=/mnt/aigc/users/lijiaqi2/wan_model/Wan2.1-I2V-14B-720P-cfg
+
+# Check arguments
+if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
+    cuda_devices=2,3
+    echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}"
+    export CUDA_VISIBLE_DEVICES=${cuda_devices}
+fi
+
+if [ -z "${lightx2v_path}" ]; then
+    echo "Error: lightx2v_path is not set. Please set this variable first."
+    exit 1
+fi
+
+if [ -z "${model_path}" ]; then
+    echo "Error: model_path is not set. Please set this variable first."
+    exit 1
+fi
+
+# Set environment variables
+export TOKENIZERS_PARALLELISM=false
+export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
+export ENABLE_PROFILING_DEBUG=true
+export ENABLE_GRAPH_MODE=false
+
+echo "=========================================="
+echo "Starting the distributed inference API server"
+echo "Model path: $model_path"
+echo "CUDA devices: $CUDA_VISIBLE_DEVICES"
+echo "API port: 8000"
+echo "=========================================="
+
+# Start the API server together with the distributed inference service
+python -m lightx2v.api_server_dist \
+--model_cls wan2.1 \
+--task i2v \
+--model_path $model_path \
+--config_json ${lightx2v_path}/configs/wan_i2v_dist.json \
+--port 8000 \
+--start_inference \
+--nproc_per_node 2
+
+echo "Service stopped"