diff --git a/app.py b/app.py index 2447929..2a7c788 100644 --- a/app.py +++ b/app.py @@ -8,55 +8,22 @@ import tempfile import torch from diffusers import FluxTransformer2DModel, FluxFillPipeline +import subprocess -import shutil +subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True) +dtype = torch.bfloat16 +device = "cuda" if torch.cuda.is_available() else "cpu" -def find_cuda(): - # Check if CUDA_HOME or CUDA_PATH environment variables are set - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') - - if cuda_home and os.path.exists(cuda_home): - return cuda_home - - # Search for the nvcc executable in the system's PATH - nvcc_path = shutil.which('nvcc') - - if nvcc_path: - # Remove the 'bin/nvcc' part to get the CUDA installation path - cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) - return cuda_path - - return None - -cuda_path = find_cuda() - -if cuda_path: - print(f"CUDA installation found at: {cuda_path}") -else: - print("CUDA installation not found") - -device = torch.device('cuda') - -print("Start loading LoRA weights") -state_dict, network_alphas = FluxFillPipeline.lora_state_dict( - pretrained_model_name_or_path_or_dict="xiaozaa/catvton-flux-lora-alpha", ## The tryon Lora weights - weight_name="pytorch_lora_weights.safetensors", - return_alphas=True -) -is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys()) -if not is_correct_format: - raise ValueError("Invalid LoRA checkpoint.") print('Loading diffusion model ...') +transformer = FluxTransformer2DModel.from_pretrained( + "xiaozaa/catvton-flux-alpha", + torch_dtype=dtype +) pipe = FluxFillPipeline.from_pretrained( - "black-forest-labs/FLUX.1-Fill-dev", - torch_dtype=torch.bfloat16 + "black-forest-labs/FLUX.1-dev", + transformer=transformer, + torch_dtype=dtype ).to(device) -FluxFillPipeline.load_lora_into_transformer( - state_dict=state_dict, - network_alphas=network_alphas, - transformer=pipe.transformer, -) - 
print('Loading Finished!') @spaces.GPU @@ -109,7 +76,7 @@ def gradio_inference( with gr.Blocks() as demo: gr.Markdown(""" - # CATVTON FLUX Virtual Try-On Demo (by using LoRA weights) + # CATVTON FLUX Virtual Try-On Demo Upload a model image, draw a mask, and a garment image to generate virtual try-on results. [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/xiaozaa/catvton-flux-alpha) @@ -222,4 +189,4 @@ def gradio_inference( ) -demo.launch() \ No newline at end of file +demo.launch()