From 2c821a470c9dca04534bea6e3b8804113ef5855d Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 09:26:53 +0200
Subject: [PATCH 01/28] Updated gradio interface and models
---
.gitignore | 5 ++
CKPT_PTH.py | 8 +-
gradio_demo.py | 186 ++++++++++++++++++++++++------------------
options/SUPIR_v0.yaml | 6 +-
requirements.txt | 3 -
5 files changed, 117 insertions(+), 91 deletions(-)
create mode 100644 .gitignore
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3c8528f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+models
+__pycache__
+venv
+.idea
+.vs
diff --git a/CKPT_PTH.py b/CKPT_PTH.py
index 4ff9ebf..c02baf5 100644
--- a/CKPT_PTH.py
+++ b/CKPT_PTH.py
@@ -1,4 +1,4 @@
-LLAVA_CLIP_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/clip-vit-large-patch14-336'
-LLAVA_MODEL_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b'
-SDXL_CLIP1_PATH = '/opt/data/private/AIGC_pretrain/clip-vit-large-patch14'
-SDXL_CLIP2_CKPT_PTH = '/opt/data/private/AIGC_pretrain/CLIP-ViT-bigG-14-laion2B-39B-b160k/open_clip_pytorch_model.bin'
\ No newline at end of file
+LLAVA_CLIP_PATH = 'openai/clip-vit-large-patch14-336'
+LLAVA_MODEL_PATH = 'liuhaotian/llava-v1.5-13b'
+SDXL_CLIP1_PATH = 'openai/clip-vit-large-patch14'
+SDXL_CLIP2_CKPT_PTH = 'models/open_clip_pytorch_model.bin'
diff --git a/gradio_demo.py b/gradio_demo.py
index da3de5c..8bef227 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -1,9 +1,9 @@
import os
import gradio as gr
from gradio_imageslider import ImageSlider
import argparse
-from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
+from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype, Tensor2PIL
import numpy as np
import torch
from SUPIR.util import create_SUPIR_model, load_QF_ckpt
@@ -12,19 +13,17 @@
from CKPT_PTH import LLAVA_MODEL_PATH
import einops
import copy
+import datetime
import time
+
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='127.0.0.1')
-parser.add_argument("--port", type=int, default='6688')
-parser.add_argument("--no_llava", action='store_true', default=False)
+parser.add_argument("--share", type=str, default=False)
+parser.add_argument("--port", type=int, default='7860')
+parser.add_argument("--no_llava", action='store_true', default=True)
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
-parser.add_argument("--loading_half_params", action='store_true', default=False)
-parser.add_argument("--use_tile_vae", action='store_true', default=False)
-parser.add_argument("--encoder_tile_size", type=int, default=512)
-parser.add_argument("--decoder_tile_size", type=int, default=64)
-parser.add_argument("--load_8bit_llava", action='store_true', default=False)
args = parser.parse_args()
server_ip = args.ip
server_port = args.port
@@ -40,24 +39,17 @@
raise ValueError('Currently support CUDA only.')
# load SUPIR
-model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
-if args.loading_half_params:
- model = model.half()
-if args.use_tile_vae:
- model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
-model = model.to(SUPIR_device)
+model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
model.current_model = 'v0-Q'
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
-
# load LLaVA
if use_llava:
- llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device)
else:
llava_agent = None
def stage1_process(input_image, gamma_correction):
- torch.cuda.set_device(SUPIR_device)
LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512)
# stage1
@@ -73,7 +65,6 @@ def stage1_process(input_image, gamma_correction):
return LQ
def llave_process(input_image, temperature, top_p, qs=None):
- torch.cuda.set_device(LLaVA_device)
if use_llava:
LQ = HWC3(input_image)
LQ = Image.fromarray(LQ.astype('uint8'))
@@ -84,8 +75,7 @@ def llave_process(input_image, temperature, top_p, qs=None):
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
- torch.cuda.set_device(SUPIR_device)
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,random_seed, progress=gr.Progress()):
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
@@ -119,15 +109,39 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.ae_dtype = convert_dtype(ae_dtype)
model.model.dtype = convert_dtype(diff_dtype)
- samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
- s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
- num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
- use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
- cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
+ output_dir = os.path.join("outputs")
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
+ all_results = []
+ counter = 1
+ progress(0 / num_images, desc="Generating images")
+ for _ in range(num_images):
+ if random_seed or num_images>1:
+ seed = np.random.randint(0, 2147483647)
+ start_time = time.time() # Track the start time
+ samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
+ s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
+ num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
+ use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
+ cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
+
+ x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
+ 0, 255).astype(np.uint8)
+ results = [x_samples[i] for i in range(num_samples)]
+ image_generation_time = time.time() - start_time
+ desc=f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
+ counter=counter+1
+ progress(counter / num_images, desc=desc)
+ print(desc) # Print the progress
+ start_time = time.time() # Reset the start time for the next image
+
+ for i, result in enumerate(results):
+ timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
+ save_path = os.path.join(output_dir, f'{timestamp}.png')
+ Image.fromarray(result).save(save_path)
+ all_results.extend(results)
- x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
- 0, 255).astype(np.uint8)
- results = [x_samples[i] for i in range(num_samples)]
if args.log_history:
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
@@ -135,9 +149,9 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
f.write(str(event_dict))
f.close()
Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
- for i, result in enumerate(results):
+ for i, result in enumerate(all_results):
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
- return [input_image] + results, event_id, 3, ''
+ return [input_image] + all_results, event_id, 3, '', seed
def load_and_reset(param_setting):
edm_steps = 50
@@ -216,54 +230,37 @@ def submit_feedback(event_id, fb_score, fb_text):
prompt = gr.Textbox(label="Prompt", value="")
with gr.Accordion("Stage1 options", open=False):
gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
- with gr.Accordion("LLaVA options", open=False):
- temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
- top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
- qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
- "The image is a realistic photography, not an art painting.")
- with gr.Accordion("Stage2 options", open=False):
- num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
+
+ with gr.Accordion("Stage2 options", open=True):
+ with gr.Row():
+ with gr.Column():
+ num_images = gr.Slider(label="Number Of Images To Generate", minimum=1, maximum=200
, value=1, step=1)
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
- edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
- s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
- s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
- s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
- s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
- s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
- a_prompt = gr.Textbox(label="Default Positive Prompt",
- value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
- 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
- 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
- 'hyper sharpness, perfect without deformations.')
- n_prompt = gr.Textbox(label="Default Negative Prompt",
- value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
- 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
- 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
- 'deformed, lowres, over-smooth')
- with gr.Row():
- with gr.Column():
- linear_CFG = gr.Checkbox(label="Linear CFG", value=False)
- spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
- maximum=9.0, value=1.0, step=0.5)
- with gr.Column():
- linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
- spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
- maximum=1., value=0., step=0.05)
- with gr.Row():
- with gr.Column():
- diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16",
- interactive=True)
- with gr.Column():
- ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
- interactive=True)
- with gr.Column():
- color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
- interactive=True)
- with gr.Column():
- model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
- interactive=True)
+ num_samples = gr.Slider(label="Batch Size", minimum=1, maximum=4 if not args.use_image_slider else 1
+ , value=1, step=1)
+ with gr.Column():
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
+ random_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ with gr.Row():
+ edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
+ s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
+ s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
+ s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
+ s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
+ with gr.Row():
+ a_prompt = gr.Textbox(label="Default Positive Prompt",
+ value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
+ 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
+ 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
+ 'hyper sharpness, perfect without deformations.')
+ n_prompt = gr.Textbox(label="Default Negative Prompt",
+ value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
+ 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
+ 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
+ 'deformed, lowres, over-smooth')
+
with gr.Column():
gr.Markdown("Stage2 Output ")
@@ -284,7 +281,34 @@ def submit_feedback(event_id, fb_score, fb_text):
value="Quality")
with gr.Column():
restart_button = gr.Button(value="Reset Param", scale=2)
- with gr.Accordion("Feedback", open=True):
+ with gr.Row():
+ with gr.Column():
+ linear_CFG = gr.Checkbox(label="Linear CFG", value=False)
+ spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
+ maximum=9.0, value=1.0, step=0.5)
+ with gr.Column():
+ linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
+ spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
+ maximum=1., value=0., step=0.05)
+ with gr.Row():
+ with gr.Column():
+ diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16",
+ interactive=True)
+ with gr.Column():
+ ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
+ interactive=True)
+ with gr.Column():
+ color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
+ interactive=True)
+ with gr.Column():
+ model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
+ interactive=True)
+ with gr.Accordion("LLaVA options", open=False):
+ temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
+ top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
+ qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
+ "The image is a realistic photography, not an art painting.")
+ with gr.Accordion("Feedback", open=False):
fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
interactive=True)
fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
@@ -298,10 +322,10 @@ def submit_feedback(event_id, fb_score, fb_text):
outputs=[denoise_image])
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,num_images,random_seed]
+ diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
restart_button.click(fn=load_and_reset, inputs=[param_setting],
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
-block.launch(server_name=server_ip, server_port=server_port)
+block.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml
index ff80312..36d42b5 100644
--- a/options/SUPIR_v0.yaml
+++ b/options/SUPIR_v0.yaml
@@ -149,8 +149,8 @@ model:
unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
jpeg artifacts, deformed, lowres, over-smooth'
-SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
-SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
-SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
+SDXL_CKPT: models/sd_xl_base_1.0_0.9vae.safetensors
+SUPIR_CKPT_F: models/v0F.ckpt
+SUPIR_CKPT_Q: models/v0Q.ckpt
SUPIR_CKPT: ~
diff --git a/requirements.txt b/requirements.txt
index 52bf25d..73488d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,8 +7,6 @@ numpy==1.24.2
requests==2.28.2
sentencepiece==0.1.98
tokenizers==0.13.3
-torch>=2.1.0
-torchvision>=0.16.0
uvicorn==0.21.1
wandb==0.14.0
httpx==0.24.0
@@ -36,4 +34,3 @@ tqdm==4.65.0
triton==2.1.0
urllib3==1.26.15
webdataset==0.2.48
-xformers>=0.0.20
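Note: the per-image saving added to stage2_process above reduces to the following self-contained sketch (the helper name save_with_timestamp is illustrative; the patch inlines this logic in the generation loop):

import datetime
import os

import numpy as np
from PIL import Image


def save_with_timestamp(result: np.ndarray, output_dir: str = 'outputs') -> str:
    """Save one HWC uint8 image the way the new generation loop does."""
    os.makedirs(output_dir, exist_ok=True)
    # Millisecond-precision file name, e.g. 2024_02_27_09_26_53_123.png
    timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
    save_path = os.path.join(output_dir, f'{timestamp}.png')
    Image.fromarray(result).save(save_path)
    return save_path

The millisecond timestamp keeps file names unique across the multi-image loop, where each iteration may also draw a fresh seed via np.random.randint(0, 2147483647).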
From 905475a5f17bd8e7776667b16789b07f829f0b17 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 10:14:17 +0200
Subject: [PATCH 02/28] Updated requirements
---
requirements.txt | 11 +++++------
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 73488d4..e2cb10b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
-fastapi==0.95.1
-gradio==4.16.0
-gradio_imageslider==0.0.17
-gradio_client==0.8.1
+gradio
+gradio_imageslider
+gradio_client
Markdown==3.4.1
numpy==1.24.2
requests==2.28.2
@@ -18,7 +17,7 @@ einops==0.7.0
einops-exts==0.0.4
timm==0.9.8
openai-clip==1.0.1
-fsspec==2023.4.0
+fsspec
kornia==0.6.9
matplotlib==3.7.1
ninja==1.11.1
@@ -31,6 +30,6 @@ pytorch-lightning==2.1.2
PyYAML==6.0
scipy==1.9.1
tqdm==4.65.0
-triton==2.1.0
urllib3==1.26.15
webdataset==0.2.48
+fastapi
From 2f6d9f853fba5342886eb1e3103356721a4bb6bf Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 10:55:51 +0200
Subject: [PATCH 03/28] Added diffusers to requirements
---
options/SUPIR_v0.yaml | 4 ++--
requirements.txt | 1 +
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml
index 36d42b5..a89e4d4 100644
--- a/options/SUPIR_v0.yaml
+++ b/options/SUPIR_v0.yaml
@@ -150,7 +150,7 @@ model:
jpeg artifacts, deformed, lowres, over-smooth'
SDXL_CKPT: models/sd_xl_base_1.0_0.9vae.safetensors
-SUPIR_CKPT_F: models/v0F.ckpt
-SUPIR_CKPT_Q: models/v0Q.ckpt
+SUPIR_CKPT_F: models/SUPIR-v0F.ckpt
+SUPIR_CKPT_Q: models/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~
diff --git a/requirements.txt b/requirements.txt
index e2cb10b..ce0d4d8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -33,3 +33,4 @@ tqdm==4.65.0
urllib3==1.26.15
webdataset==0.2.48
fastapi
+diffusers
From 85bde0403146d1b9f00e69639d2b4953b2744a1e Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 12:34:13 +0200
Subject: [PATCH 04/28] Add back half-precision, tile VAE and 8-bit LLaVA options
---
gradio_demo.py | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 8bef227..c5b2b6e 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -24,6 +24,11 @@
parser.add_argument("--no_llava", action='store_true', default=True)
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
+parser.add_argument("--loading_half_params", action='store_true', default=False)
+parser.add_argument("--use_tile_vae", action='store_true', default=False)
+parser.add_argument("--encoder_tile_size", type=int, default=512)
+parser.add_argument("--decoder_tile_size", type=int, default=64)
+parser.add_argument("--load_8bit_llava", action='store_true', default=False)
args = parser.parse_args()
server_ip = args.ip
server_port = args.port
@@ -39,17 +44,23 @@
raise ValueError('Currently support CUDA only.')
# load SUPIR
-model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q').to(SUPIR_device)
+model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
+if args.loading_half_params:
+ model = model.half()
+if args.use_tile_vae:
+ model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
+model = model.to(SUPIR_device)
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
model.current_model = 'v0-Q'
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
# load LLaVA
if use_llava:
- llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device)
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
else:
llava_agent = None
def stage1_process(input_image, gamma_correction):
+ torch.cuda.set_device(SUPIR_device)
LQ = HWC3(input_image)
LQ = fix_resize(LQ, 512)
# stage1
@@ -66,6 +77,7 @@ def stage1_process(input_image, gamma_correction):
def llave_process(input_image, temperature, top_p, qs=None):
if use_llava:
+ torch.cuda.set_device(LLaVA_device)
LQ = HWC3(input_image)
LQ = Image.fromarray(LQ.astype('uint8'))
captions = llava_agent.gen_image_caption([LQ], temperature=temperature, top_p=top_p, qs=qs)
@@ -76,6 +88,7 @@ def llave_process(input_image, temperature, top_p, qs=None):
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,random_seed, progress=gr.Progress()):
+ torch.cuda.set_device(SUPIR_device)
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
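Note: the initialization order restored here is deliberate — precision casting and tile-VAE setup happen before the model is moved to the GPU. A minimal sketch of the pattern (assumes the SUPIR imports and args defined above):

model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
if args.loading_half_params:
    model = model.half()  # cast weights to fp16 before they reach the GPU
if args.use_tile_vae:
    # The patch passes the literals 512/64 here, so the new
    # --encoder_tile_size/--decoder_tile_size arguments are parsed but unused.
    model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
model = model.to(SUPIR_device)  # move only after dtype/tiling are configured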
From 51f15354ba05ee20e8518ec877213bee93e81ae1 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 12:35:38 +0200
Subject: [PATCH 05/28] Reenable LLaVA by default
---
gradio_demo.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index c5b2b6e..87be992 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -21,7 +21,7 @@
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--share", type=str, default=False)
parser.add_argument("--port", type=int, default='7860')
-parser.add_argument("--no_llava", action='store_true', default=True)
+parser.add_argument("--no_llava", action='store_true', default=False
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
parser.add_argument("--loading_half_params", action='store_true', default=False)
@@ -53,6 +53,7 @@
model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
model.current_model = 'v0-Q'
ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
+
# load LLaVA
if use_llava:
llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
From 7ba69c4b3b70490291c68c66bd4b89993797cd47 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 12:36:16 +0200
Subject: [PATCH 06/28] Fix missing parenthesis in --no_llava argument
---
gradio_demo.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 87be992..0faaed1 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -21,7 +21,7 @@
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--share", type=str, default=False)
parser.add_argument("--port", type=int, default='7860')
-parser.add_argument("--no_llava", action='store_true', default=False
+parser.add_argument("--no_llava", action='store_true', default=False)
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
parser.add_argument("--loading_half_params", action='store_true', default=False)
From 7fb5066be970d3a7b34fbf8500c66ec32ee9a51c Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 13:12:40 +0200
Subject: [PATCH 07/28] Fix typo
---
.gitignore | 1 +
gradio_demo.py | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/.gitignore b/.gitignore
index 3c8528f..ecbfa5f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@ __pycache__
venv
.idea
.vs
+outputs
diff --git a/gradio_demo.py b/gradio_demo.py
index 0faaed1..67596e9 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -20,7 +20,7 @@
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--share", type=str, default=False)
-parser.add_argument("--port", type=int, default='7860')
+parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--no_llava", action='store_true', default=False)
parser.add_argument("--use_image_slider", action='store_true', default=False)
parser.add_argument("--log_history", action='store_true', default=False)
From ec7a03e530093b3af8bfe73891c1b2d41e637837 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 14:47:05 +0200
Subject: [PATCH 08/28] Added bitsandbytes to requirements
---
requirements.txt | 1 +
1 file changed, 1 insertion(+)
diff --git a/requirements.txt b/requirements.txt
index ce0d4d8..c5a421a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -34,3 +34,4 @@ urllib3==1.26.15
webdataset==0.2.48
fastapi
diffusers
+bitsandbytes
From 8aa303f146b1c1c8b0973d812b7ce02cdbb61ebd Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 15:21:55 +0200
Subject: [PATCH 09/28] Pin gradio versions
---
requirements.txt | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index c5a421a..691d8c3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,3 @@
-gradio
-gradio_imageslider
-gradio_client
Markdown==3.4.1
numpy==1.24.2
requests==2.28.2
@@ -35,3 +32,6 @@ webdataset==0.2.48
fastapi
diffusers
bitsandbytes
+gradio==4.16.0
+gradio_imageslider==0.0.17
+gradio_client==0.8.1
From 73763945797d13ca13fb7ee0681089390012c29b Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 15:22:30 +0200
Subject: [PATCH 10/28] Switch LLaVA 13B to 7B
---
CKPT_PTH.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CKPT_PTH.py b/CKPT_PTH.py
index c02baf5..8d64e27 100644
--- a/CKPT_PTH.py
+++ b/CKPT_PTH.py
@@ -1,4 +1,4 @@
LLAVA_CLIP_PATH = 'openai/clip-vit-large-patch14-336'
-LLAVA_MODEL_PATH = 'liuhaotian/llava-v1.5-13b'
+LLAVA_MODEL_PATH = 'liuhaotian/llava-v1.5-7b'
SDXL_CLIP1_PATH = 'openai/clip-vit-large-patch14'
SDXL_CLIP2_CKPT_PTH = 'models/open_clip_pytorch_model.bin'
From 273c987fe44f23df1e86c8594942d165185b69e5 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 19:13:04 +0200
Subject: [PATCH 11/28] Changed SDXL model
---
options/SUPIR_v0.yaml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml
index a89e4d4..675ecba 100644
--- a/options/SUPIR_v0.yaml
+++ b/options/SUPIR_v0.yaml
@@ -149,7 +149,7 @@ model:
unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
jpeg artifacts, deformed, lowres, over-smooth'
-SDXL_CKPT: models/sd_xl_base_1.0_0.9vae.safetensors
+SDXL_CKPT: models/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors
SUPIR_CKPT_F: models/SUPIR-v0F.ckpt
SUPIR_CKPT_Q: models/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~
From c47d6b7d2fd3bd0a36263fe1e5765aabe1fac277 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 22:34:02 +0200
Subject: [PATCH 12/28] Added batch processing
---
gradio_demo.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 71 insertions(+), 5 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 67596e9..da38d89 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -86,10 +86,57 @@ def llave_process(input_image, temperature, top_p, qs=None):
captions = ['LLaVA is not available. Please add text manually.']
return captions[0]
+
+def batch_upscale(batch_process_folder,outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images, random_seed, progress=gr.Progress()):
+ import os
+ import numpy as np
+ from PIL import Image
+
+ # Get the list of image files in the folder
+ image_files = [file for file in os.listdir(batch_process_folder) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
+ total_images = len(image_files)
+ main_prompt = prompt
+ # Iterate over all image files in the folder
+ for index, file_name in enumerate(image_files):
+ try:
+ progress((index + 1) / total_images, f"Processing {index + 1}/{total_images} image")
+ # Construct the full file path
+ file_path = os.path.join(batch_process_folder, file_name)
+ prompt = main_prompt
+ # Open the image file and convert it to a NumPy array
+ with Image.open(file_path) as img:
+ img_array = np.asarray(img)
+
+ # Construct the path for the prompt text file
+ base_name = os.path.splitext(file_name)[0]
+ prompt_file_path = os.path.join(batch_process_folder, f"{base_name}.txt")
+
+ # Read the prompt from the text file
+ if os.path.exists(prompt_file_path):
+ with open(prompt_file_path, "r", encoding="utf-8") as f:
+ prompt = f.read().strip()
+
+ # Call the stage2_process method for the image
+ stage2_process(img_array, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images, random_seed, dont_update_progress=True, outputs_folder=outputs_folder)
+
+ # Update progress
+ except Exception as e:
+ print(f"Error processing {file_name}: {e}")
+ continue
+ return "All Done"
+
+
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,random_seed, progress=gr.Progress()):
+
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,random_seed,dont_update_progress=False,outputs_folder="outputs", progress=gr.Progress()):
+
torch.cuda.set_device(SUPIR_device)
+
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
@@ -127,9 +174,15 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
if not os.path.exists(output_dir):
os.makedirs(output_dir)
+ if outputs_folder.strip() != "" and outputs_folder != "outputs":
+ output_dir = outputs_folder
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+
all_results = []
counter = 1
- progress(0 / num_images, desc="Generating images")
+ if not dont_update_progress:
+ progress(0 / num_images, desc="Generating images")
for _ in range(num_images):
if random_seed or num_images>1:
seed = np.random.randint(0, 2147483647)
@@ -146,7 +199,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
image_generation_time = time.time() - start_time
desc=f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
counter=counter+1
- progress(counter / num_images, desc=desc)
+ if not dont_update_progress:
+ progress(counter / num_images, desc=desc)
print(desc) # Print the progress
start_time = time.time() # Reset the start time for the next image
@@ -211,7 +265,7 @@ def submit_feedback(event_id, fb_score, fb_text):
title_md = """
# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
-⚠️SUPIR is still a research project under tested and is not yet a stable commercial product.
+1 Click Installer (auto download models as well) : https://www.patreon.com/posts/99176057
[[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
"""
@@ -277,7 +331,7 @@ def submit_feedback(event_id, fb_score, fb_text):
with gr.Column():
- gr.Markdown("Stage2 Output ")
+ gr.Markdown("Upscaled Images Output ")
if not args.use_image_slider:
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
else:
@@ -289,6 +343,14 @@ def submit_feedback(event_id, fb_score, fb_text):
llave_button = gr.Button(value="LlaVa Run")
with gr.Column():
diffusion_button = gr.Button(value="Stage2 Run")
+ with gr.Row():
+ with gr.Column():
+ batch_process_folder = gr.Textbox(label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt exists it will use the Prompt you have written. It can be empty as well.", placeholder="e.g. R:\SUPIR video\comparison_images")
+ outputs_folder = gr.Textbox(label="Batch Processing Output Folder Path - If left empty images are saved in default folder", placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
+ with gr.Row():
+ with gr.Column():
+ batch_upscale_button = gr.Button(value="Start Batch Upscaling")
+ outputlabel = gr.Label("Batch Processing Progress")
with gr.Row():
with gr.Column():
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
@@ -342,4 +404,8 @@ def submit_feedback(event_id, fb_score, fb_text):
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
+ stage2_ips_batch = [batch_process_folder,outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,num_images,random_seed]
+ batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True, queue=True)
block.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
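Note: the batch mode pairs each image with an optional sidecar prompt file — for foo.png, a foo.txt next to it overrides the UI prompt. A self-contained sketch of that lookup (resolve_prompt is an illustrative name; batch_upscale inlines this):

import os


def resolve_prompt(image_path: str, fallback_prompt: str) -> str:
    """Return the sidecar .txt prompt for an image if present, else the fallback."""
    base, _ = os.path.splitext(image_path)
    prompt_file = base + '.txt'
    if os.path.exists(prompt_file):
        with open(prompt_file, 'r', encoding='utf-8') as f:
            return f.read().strip()
    return fallback_prompt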
From 5a7a451eaf0dae6a71f97f0112cf5927fb950d3e Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Tue, 27 Feb 2024 23:45:58 +0200
Subject: [PATCH 13/28] Use step of 0.1 instead of 1 on upscale slider
---
gradio_demo.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index da38d89..3e71778 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -307,7 +307,7 @@ def submit_feedback(event_id, fb_score, fb_text):
num_samples = gr.Slider(label="Batch Size", minimum=1, maximum=4 if not args.use_image_slider else 1
, value=1, step=1)
with gr.Column():
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1
random_seed = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Row():
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
From 05749ed008159eab8e20bdc8beea5296d5437042 Mon Sep 17 00:00:00 2001
From: Bear Stonem
Date: Wed, 28 Feb 2024 11:02:29 -0800
Subject: [PATCH 14/28] add missing param on gr.Slider
---
gradio_demo.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 3e71778..0e129f2 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -307,7 +307,7 @@ def submit_feedback(event_id, fb_score, fb_text):
num_samples = gr.Slider(label="Batch Size", minimum=1, maximum=4 if not args.use_image_slider else 1
, value=1, step=1)
with gr.Column():
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
random_seed = gr.Checkbox(label="Randomize Seed", value=True)
with gr.Row():
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
From 0b31941e20d3b5e879265807c6928189add4618a Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 09:38:26 +0200
Subject: [PATCH 15/28] Removed IntelliJ files
---
.idea/.gitignore | 8 ----
.idea/.name | 1 -
.idea/SUPIR.iml | 10 -----
.idea/deployment.xml | 23 -----------
.idea/inspectionProfiles/Project_Default.xml | 38 -------------------
.../inspectionProfiles/profiles_settings.xml | 6 ---
.idea/misc.xml | 4 --
.idea/modules.xml | 8 ----
.idea/vcs.xml | 6 ---
.idea/webResources.xml | 14 -------
10 files changed, 118 deletions(-)
delete mode 100644 .idea/.gitignore
delete mode 100644 .idea/.name
delete mode 100644 .idea/SUPIR.iml
delete mode 100644 .idea/deployment.xml
delete mode 100644 .idea/inspectionProfiles/Project_Default.xml
delete mode 100644 .idea/inspectionProfiles/profiles_settings.xml
delete mode 100644 .idea/misc.xml
delete mode 100644 .idea/modules.xml
delete mode 100644 .idea/vcs.xml
delete mode 100644 .idea/webResources.xml
diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 13566b8..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# Default ignored files
-/shelf/
-/workspace.xml
-# Editor-based HTTP Client requests
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/.name b/.idea/.name
deleted file mode 100644
index 0b8d3da..0000000
--- a/.idea/.name
+++ /dev/null
@@ -1 +0,0 @@
-Diff4R
\ No newline at end of file
diff --git a/.idea/SUPIR.iml b/.idea/SUPIR.iml
deleted file mode 100644
index de95779..0000000
--- a/.idea/SUPIR.iml
+++ /dev/null
@@ -1,10 +0,0 @@
-<IntelliJ module XML (10 lines)>
\ No newline at end of file
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
deleted file mode 100644
index 768ffba..0000000
--- a/.idea/deployment.xml
+++ /dev/null
@@ -1,23 +0,0 @@
-<IntelliJ deployment configuration XML (23 lines)>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index 7410f1c..0000000
--- a/.idea/inspectionProfiles/Project_Default.xml
+++ /dev/null
@@ -1,38 +0,0 @@
-<IntelliJ inspection profile XML (38 lines)>
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
--- a/.idea/inspectionProfiles/profiles_settings.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-<IntelliJ inspection settings XML (6 lines)>
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 874415a..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-<IntelliJ misc settings XML (4 lines)>
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 8ae5627..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-<IntelliJ modules XML (8 lines)>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1dd..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-<IntelliJ VCS mapping XML (6 lines)>
\ No newline at end of file
diff --git a/.idea/webResources.xml b/.idea/webResources.xml
deleted file mode 100644
index 717d9d6..0000000
--- a/.idea/webResources.xml
+++ /dev/null
@@ -1,14 +0,0 @@
-<IntelliJ web resources XML (14 lines)>
\ No newline at end of file
From dae7909fea4d7e4384398380c4eecdc5b9032d76 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 10:49:46 +0200
Subject: [PATCH 16/28] Fix typos
---
gradio_demo.py | 35 +++++++++++++++++++----------------
1 file changed, 19 insertions(+), 16 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 0e129f2..6b94eeb 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -60,6 +60,7 @@
else:
llava_agent = None
+
def stage1_process(input_image, gamma_correction):
torch.cuda.set_device(SUPIR_device)
LQ = HWC3(input_image)
@@ -76,7 +77,8 @@ def stage1_process(input_image, gamma_correction):
LQ = LQ.round().clip(0, 255).astype(np.uint8)
return LQ
-def llave_process(input_image, temperature, top_p, qs=None):
+
+def llava_process(input_image, temperature, top_p, qs=None):
if use_llava:
torch.cuda.set_device(LLaVA_device)
LQ = HWC3(input_image)
@@ -87,9 +89,10 @@ def llave_process(input_image, temperature, top_p, qs=None):
return captions[0]
-def batch_upscale(batch_process_folder,outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images, random_seed, progress=gr.Progress()):
+def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps,
+ s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype,
+ gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,
+ num_images, random_seed, progress=gr.Progress()):
import os
import numpy as np
from PIL import Image
@@ -121,7 +124,8 @@ def batch_upscale(batch_process_folder,outputs_folder, prompt, a_prompt, n_promp
# Call the stage2_process method for the image
stage2_process(img_array, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images, random_seed, dont_update_progress=True, outputs_folder=outputs_folder)
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed, dont_update_progress=True, outputs_folder=outputs_folder)
# Update progress
except Exception as e:
@@ -132,8 +136,8 @@ def batch_upscale(batch_process_folder,outputs_folder, prompt, a_prompt, n_promp
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
-
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,random_seed,dont_update_progress=False,outputs_folder="outputs", progress=gr.Progress()):
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed, dont_update_progress=False, outputs_folder="outputs", progress=gr.Progress()):
torch.cuda.set_device(SUPIR_device)
@@ -156,8 +160,7 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.load_state_dict(ckpt_F, strict=False)
model.current_model = 'v0-F'
input_image = HWC3(input_image)
- input_image = upscale_image(input_image, upscale, unit_resolution=32,
- min_size=1024)
+ input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=1024)
LQ = np.array(input_image) / 255.0
LQ = np.power(LQ, gamma_correction)
@@ -197,8 +200,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
image_generation_time = time.time() - start_time
- desc=f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
- counter=counter+1
+ desc = f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
+ counter = counter+1
if not dont_update_progress:
progress(counter / num_images, desc=desc)
print(desc) # Print the progress
@@ -210,7 +213,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
Image.fromarray(result).save(save_path)
all_results.extend(results)
-
if args.log_history:
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
with open(f'./history/{event_id[:5]}/{event_id[5:]}/logs.txt', 'w') as f:
@@ -221,6 +223,7 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
return [input_image] + all_results, event_id, 3, '', seed
+
def load_and_reset(param_setting):
edm_steps = 50
s_stage2 = 1.0
@@ -262,10 +265,11 @@ def submit_feedback(event_id, fb_score, fb_text):
else:
return 'Submit failed, the server is not set to log history.'
+
title_md = """
# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
-1 Click Installer (auto download models as well) : https://www.patreon.com/posts/99176057
+⚠️SUPIR is still a research project under tested and is not yet a stable commercial product.
[[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
"""
@@ -329,7 +333,6 @@ def submit_feedback(event_id, fb_score, fb_text):
'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
'deformed, lowres, over-smooth')
-
with gr.Column():
gr.Markdown("Upscaled Images Output ")
if not args.use_image_slider:
@@ -340,7 +343,7 @@ def submit_feedback(event_id, fb_score, fb_text):
with gr.Column():
denoise_button = gr.Button(value="Stage1 Run")
with gr.Column():
- llave_button = gr.Button(value="LlaVa Run")
+ llava_button = gr.Button(value="LlaVa Run")
with gr.Column():
diffusion_button = gr.Button(value="Stage2 Run")
with gr.Row():
@@ -393,7 +396,7 @@ def submit_feedback(event_id, fb_score, fb_text):
gr.Markdown(claim_md)
event_id = gr.Textbox(label="Event ID", value="", visible=False)
- llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
+ llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
outputs=[denoise_image])
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
From 1327b6010a16793c7f40cc49ca9397d685dea928 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 11:00:45 +0200
Subject: [PATCH 17/28] Updated README and requirements
---
README.md | 2 +-
requirements.txt | 3 +++
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 7a72e84..fa22c4a 100644
--- a/README.md
+++ b/README.md
@@ -38,7 +38,7 @@ For users who can connect to huggingface, please setting `LLAVA_CLIP_PATH, SDXL_
* [SDXL CLIP Encoder-2](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
* [SDXL base 1.0_0.9vae](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors)
* [LLaVA CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336)
-* [LLaVA v1.5 13B](https://huggingface.co/liuhaotian/llava-v1.5-13b)
+* [LLaVA v1.5 7B](https://huggingface.co/liuhaotian/llava-v1.5-7b)
#### Models we provided:
diff --git a/requirements.txt b/requirements.txt
index 691d8c3..792557a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,8 @@ numpy==1.24.2
requests==2.28.2
sentencepiece==0.1.98
tokenizers==0.13.3
+torch
+torchvision
uvicorn==0.21.1
wandb==0.14.0
httpx==0.24.0
@@ -29,6 +31,7 @@ scipy==1.9.1
tqdm==4.65.0
urllib3==1.26.15
webdataset==0.2.48
+xformers
fastapi
diffusers
bitsandbytes
From 136b4156a2a0b54a553cb43224d1d8d41e231d27 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 11:22:33 +0200
Subject: [PATCH 18/28] Added script to download the models
---
download_models.py | 69 ++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 69 insertions(+)
create mode 100755 download_models.py
diff --git a/download_models.py b/download_models.py
new file mode 100755
index 0000000..4058995
--- /dev/null
+++ b/download_models.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python3
+import os
+import requests
+from tqdm import tqdm
+from huggingface_hub import snapshot_download
+
+
+def create_directory(path):
+ """Create directory if it does not exist."""
+ if not os.path.exists(path):
+ os.makedirs(path)
+ print(f'Directory created: {path}')
+ else:
+ print(f'Directory already exists: {path}')
+
+
+def download_file(url, folder_path, file_name=None):
+ """Download a file from a given URL to a specified folder with an optional file name."""
+ local_filename = file_name if file_name else url.split('/')[-1]
+ local_filepath = os.path.join(folder_path, local_filename)
+ print(f'Downloading {url} to: {local_filepath}')
+
+ # Stream download to handle large files
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ total_size_in_bytes = int(r.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(local_filepath, 'wb') as f:
+ for data in r.iter_content(block_size):
+ progress_bar.update(len(data))
+ f.write(data)
+ progress_bar.close()
+
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+ print('ERROR, something went wrong')
+ else:
+ print(f'Downloaded {local_filename} to {folder_path}')
+
+
+# Define the folders and their corresponding file URLs with optional file names
+folders_and_files = {
+ os.path.join('models'): [
+ ('https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/resolve/main/open_clip_pytorch_model.bin', None),
+ ('https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors', None),
+ ('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0F.ckpt', None),
+ ('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0Q.ckpt', None),
+ ]
+}
+
+
+if __name__ == '__main__':
+ for folder, files in folders_and_files.items():
+ create_directory(folder)
+ for file_url, file_name in files:
+ download_file(file_url, folder, file_name)
+
+ llava_model = os.getenv('LLAVA_MODEL', 'liuhaotian/llava-v1.5-7b')
+ llava_clip_model = 'openai/clip-vit-large-patch14-336'
+ sdxl_clip_model = 'openai/clip-vit-large-patch14'
+
+ print(f'Downloading LLaVA model: {llava_model}')
+ snapshot_download(llava_model)
+
+ print(f'Downloading LLaVA CLIP model: {llava_clip_model}')
+ snapshot_download(llava_clip_model)
+
+ print(f'Downloading SDXL CLIP model: {sdxl_clip_model}')
+ snapshot_download(sdxl_clip_model)
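Note: the LLaVA repository downloaded by the script can be overridden through the LLAVA_MODEL environment variable read via os.getenv above; for example, to keep the original 13B model (illustrative — any repo id accepted by snapshot_download works):

import os

# Set before download_models.py runs in the same process,
# or export LLAVA_MODEL in the shell instead.
os.environ['LLAVA_MODEL'] = 'liuhaotian/llava-v1.5-13b'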
From a6e9dceff42cec84d13daa8538bc387422810ada Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 11:36:21 +0200
Subject: [PATCH 19/28] Fixed formatting
---
gradio_demo.py | 71 ++++++++++++++++++++++++++++----------------------
1 file changed, 40 insertions(+), 31 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 6b94eeb..ce4f4f0 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -16,7 +16,6 @@
import datetime
import time
-
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default='127.0.0.1')
parser.add_argument("--share", type=str, default=False)
@@ -98,7 +97,8 @@ def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prom
from PIL import Image
# Get the list of image files in the folder
- image_files = [file for file in os.listdir(batch_process_folder) if file.lower().endswith((".png", ".jpg", ".jpeg"))]
+ image_files = [file for file in os.listdir(batch_process_folder) if
+ file.lower().endswith((".png", ".jpg", ".jpeg"))]
total_images = len(image_files)
main_prompt = prompt
# Iterate over all image files in the folder
@@ -138,7 +138,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
random_seed, dont_update_progress=False, outputs_folder="outputs", progress=gr.Progress()):
-
torch.cuda.set_device(SUPIR_device)
event_id = str(time.time_ns())
@@ -187,12 +186,13 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
if not dont_update_progress:
progress(0 / num_images, desc="Generating images")
for _ in range(num_images):
- if random_seed or num_images>1:
+ if random_seed or num_images > 1:
seed = np.random.randint(0, 2147483647)
start_time = time.time() # Track the start time
samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
- num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
+ num_samples=num_samples, p_p=a_prompt, n_p=n_prompt,
+ color_fix_type=color_fix_type,
use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
@@ -201,7 +201,7 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
results = [x_samples[i] for i in range(num_samples)]
image_generation_time = time.time() - start_time
desc = f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
- counter = counter+1
+ counter = counter + 1
if not dont_update_progress:
progress(counter / num_images, desc=desc)
print(desc) # Print the progress
@@ -274,7 +274,6 @@ def submit_feedback(event_id, fb_score, fb_text):
[[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
"""
-
claim_md = """
## **Terms of use**
@@ -285,7 +284,6 @@ def submit_feedback(event_id, fb_score, fb_text):
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
"""
-
block = gr.Blocks(title='SUPIR').queue()
with block:
with gr.Row():
@@ -304,24 +302,26 @@ def submit_feedback(event_id, fb_score, fb_text):
gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
with gr.Accordion("Stage2 options", open=True):
- with gr.Row():
- with gr.Column():
+ with gr.Row():
+ with gr.Column():
num_images = gr.Slider(label="Number Of Images To Generate", minimum=1, maximum=200
- , value=1, step=1)
- num_samples = gr.Slider(label="Batch Size", minimum=1, maximum=4 if not args.use_image_slider else 1
- , value=1, step=1)
- with gr.Column():
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
- random_seed = gr.Checkbox(label="Randomize Seed", value=True)
- with gr.Row():
+ , value=1, step=1)
+ num_samples = gr.Slider(label="Batch Size", minimum=1,
+ maximum=4 if not args.use_image_slider else 1
+ , value=1, step=1)
+ with gr.Column():
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
+ random_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ with gr.Row():
edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
- s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
+ s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0,
+ step=1.0)
seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
- with gr.Row():
+ with gr.Row():
a_prompt = gr.Textbox(label="Default Positive Prompt",
value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
@@ -348,8 +348,12 @@ def submit_feedback(event_id, fb_score, fb_text):
diffusion_button = gr.Button(value="Stage2 Run")
with gr.Row():
with gr.Column():
- batch_process_folder = gr.Textbox(label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt exists it will use the Prompt you have written. It can be empty as well.", placeholder="e.g. R:\SUPIR video\comparison_images")
- outputs_folder = gr.Textbox(label="Batch Processing Output Folder Path - If left empty images are saved in default folder", placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
+ batch_process_folder = gr.Textbox(
+ label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt exists it will use the Prompt you have written. It can be empty as well.",
+ placeholder="e.g. R:\SUPIR video\comparison_images")
+ outputs_folder = gr.Textbox(
+                    label="Batch Processing Output Folder Path - If left empty, images are saved in the default folder",
+ placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
with gr.Row():
with gr.Column():
batch_upscale_button = gr.Button(value="Start Batch Upscaling")
@@ -357,14 +361,14 @@ def submit_feedback(event_id, fb_score, fb_text):
with gr.Row():
with gr.Column():
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
- value="Quality")
+ value="Quality")
with gr.Column():
restart_button = gr.Button(value="Reset Param", scale=2)
with gr.Row():
with gr.Column():
linear_CFG = gr.Checkbox(label="Linear CFG", value=False)
spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
- maximum=9.0, value=1.0, step=0.5)
+ maximum=9.0, value=1.0, step=0.5)
with gr.Column():
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
@@ -372,13 +376,13 @@ def submit_feedback(event_id, fb_score, fb_text):
with gr.Row():
with gr.Column():
diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16",
- interactive=True)
+ interactive=True)
with gr.Column():
ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
interactive=True)
with gr.Column():
color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
- interactive=True)
+ interactive=True)
with gr.Column():
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
interactive=True)
@@ -401,14 +405,19 @@ def submit_feedback(event_id, fb_score, fb_text):
outputs=[denoise_image])
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,num_images,random_seed]
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
+ diffusion_button.click(fn=stage2_process, inputs=stage2_ips,
+ outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
restart_button.click(fn=load_and_reset, inputs=[param_setting],
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
- stage2_ips_batch = [batch_process_folder,outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,num_images,random_seed]
- batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True, queue=True)
+ stage2_ips_batch = [batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale,
+ edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
+ batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True,
+ queue=True)
block.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
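For reference, the wiring above hinges on Gradio passing the values of the `inputs` components to the handler positionally, so `stage2_ips` and `stage2_ips_batch` must stay in the exact order of the corresponding function signatures. A minimal, self-contained sketch of that pattern (illustrative names only, not the SUPIR code itself):

```python
import gradio as gr

def process(image, prompt, upscale, seed):
    # Stand-in for stage2_process, which takes ~24 positional arguments.
    return f"prompt={prompt!r}, upscale={upscale}, seed={seed}"

with gr.Blocks(title="wiring sketch") as demo:
    image = gr.Image(type="numpy")
    prompt = gr.Textbox(label="Prompt")
    upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
    seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1)
    out = gr.Textbox(label="Result")
    run = gr.Button(value="Run")

    # Values arrive at process() in exactly this list order.
    ips = [image, prompt, upscale, seed]
    run.click(fn=process, inputs=ips, outputs=[out])

if __name__ == "__main__":
    demo.launch()
```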
From 5a60ca89f3e59a0726e0cc564e49cc531b51d45b Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 12:29:37 +0200
Subject: [PATCH 20/28] Some refactoring
---
gradio_demo.py | 379 +++++++++++++++++++++++++------------------------
1 file changed, 195 insertions(+), 184 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index ce4f4f0..fd96d3f 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -16,49 +16,6 @@
import datetime
import time
-parser = argparse.ArgumentParser()
-parser.add_argument("--ip", type=str, default='127.0.0.1')
-parser.add_argument("--share", type=str, default=False)
-parser.add_argument("--port", type=int, default=7860)
-parser.add_argument("--no_llava", action='store_true', default=False)
-parser.add_argument("--use_image_slider", action='store_true', default=False)
-parser.add_argument("--log_history", action='store_true', default=False)
-parser.add_argument("--loading_half_params", action='store_true', default=False)
-parser.add_argument("--use_tile_vae", action='store_true', default=False)
-parser.add_argument("--encoder_tile_size", type=int, default=512)
-parser.add_argument("--decoder_tile_size", type=int, default=64)
-parser.add_argument("--load_8bit_llava", action='store_true', default=False)
-args = parser.parse_args()
-server_ip = args.ip
-server_port = args.port
-use_llava = not args.no_llava
-
-if torch.cuda.device_count() >= 2:
- SUPIR_device = 'cuda:0'
- LLaVA_device = 'cuda:1'
-elif torch.cuda.device_count() == 1:
- SUPIR_device = 'cuda:0'
- LLaVA_device = 'cuda:0'
-else:
- raise ValueError('Currently support CUDA only.')
-
-# load SUPIR
-model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
-if args.loading_half_params:
- model = model.half()
-if args.use_tile_vae:
- model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
-model = model.to(SUPIR_device)
-model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
-model.current_model = 'v0-Q'
-ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
-
-# load LLaVA
-if use_llava:
- llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
-else:
- llava_agent = None
-
def stage1_process(input_image, gamma_correction):
torch.cuda.set_device(SUPIR_device)
@@ -266,158 +223,212 @@ def submit_feedback(event_id, fb_score, fb_text):
return 'Submit failed, the server is not set to log history.'
-title_md = """
-# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
-
-⚠️SUPIR is still a research project under tested and is not yet a stable commercial product.
+def launch_ui(launch_kwargs):
+ title_md = """
+ # **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
-[[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
-"""
+    ⚠️SUPIR is still a research project under testing and is not yet a stable commercial product.
-claim_md = """
-## **Terms of use**
+ [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
+ """
-By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
+ claim_md = """
+ ## **Terms of use**
-## **License**
+    By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit feedback to us if you get an inappropriate answer! We will collect those reports to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
-The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
-"""
+ ## **License**
-block = gr.Blocks(title='SUPIR').queue()
-with block:
- with gr.Row():
- gr.Markdown(title_md)
- with gr.Row():
- with gr.Column():
- with gr.Row(equal_height=True):
- with gr.Column():
- gr.Markdown("Input ")
- input_image = gr.Image(type="numpy", elem_id="image-input", height=400, width=400)
- with gr.Column():
- gr.Markdown("Stage1 Output ")
- denoise_image = gr.Image(type="numpy", elem_id="image-s1", height=400, width=400)
- prompt = gr.Textbox(label="Prompt", value="")
- with gr.Accordion("Stage1 options", open=False):
- gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
+ """
- with gr.Accordion("Stage2 options", open=True):
+ interface = gr.Blocks(title='SUPIR').queue()
+ with interface:
+ with gr.Row():
+ gr.Markdown(title_md)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ gr.Markdown("Input ")
+ input_image = gr.Image(type="numpy", elem_id="image-input", height=400, width=400)
+ with gr.Column():
+ gr.Markdown("Stage1 Output ")
+ denoise_image = gr.Image(type="numpy", elem_id="image-s1", height=400, width=400)
+ prompt = gr.Textbox(label="Prompt", value="")
+ with gr.Accordion("Stage1 options", open=False):
+ gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
+
+ with gr.Accordion("Stage2 options", open=True):
+ with gr.Row():
+ with gr.Column():
+ num_images = gr.Slider(label="Number Of Images To Generate", minimum=1, maximum=200
+ , value=1, step=1)
+ num_samples = gr.Slider(label="Batch Size", minimum=1,
+ maximum=4 if not args.use_image_slider else 1
+ , value=1, step=1)
+ with gr.Column():
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
+ random_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ with gr.Row():
+ edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
+ s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
+ s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
+ s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0,
+ step=1.0)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
+ s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
+ with gr.Row():
+ a_prompt = gr.Textbox(label="Default Positive Prompt",
+ value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
+ 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
+ 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
+ 'hyper sharpness, perfect without deformations.')
+ n_prompt = gr.Textbox(label="Default Negative Prompt",
+ value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
+ 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
+ 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
+ 'deformed, lowres, over-smooth')
+
+ with gr.Column():
+ gr.Markdown("Upscaled Images Output ")
+ if not args.use_image_slider:
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
+ else:
+ result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
+ with gr.Row():
+ with gr.Column():
+ denoise_button = gr.Button(value="Stage1 Run")
+ with gr.Column():
+                llava_button = gr.Button(value="LLaVA Run")
+ with gr.Column():
+ diffusion_button = gr.Button(value="Stage2 Run")
with gr.Row():
with gr.Column():
- num_images = gr.Slider(label="Number Of Images To Generate", minimum=1, maximum=200
- , value=1, step=1)
- num_samples = gr.Slider(label="Batch Size", minimum=1,
- maximum=4 if not args.use_image_slider else 1
- , value=1, step=1)
+ batch_process_folder = gr.Textbox(
+                    label="Batch Processing Input Folder Path - If image_file_name.txt exists, it will be read and used as the prompt (optional). Uses the same settings as a single upscale (Stage 2 Run). If no caption txt exists, the Prompt you have written will be used. It can be empty as well.",
+ placeholder="e.g. R:\SUPIR video\comparison_images")
+ outputs_folder = gr.Textbox(
+                    label="Batch Processing Output Folder Path - If left empty, images are saved in the default folder",
+ placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
+ with gr.Row():
with gr.Column():
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
- random_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ batch_upscale_button = gr.Button(value="Start Batch Upscaling")
+ outputlabel = gr.Label("Batch Processing Progress")
with gr.Row():
- edm_steps = gr.Slider(label="Steps", minimum=20, maximum=200, value=50, step=1)
- s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0, value=7.5, step=0.1)
- s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
- s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0,
- step=1.0)
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
- s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
- s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
+ with gr.Column():
+ param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
+ value="Quality")
+ with gr.Column():
+ restart_button = gr.Button(value="Reset Param", scale=2)
with gr.Row():
- a_prompt = gr.Textbox(label="Default Positive Prompt",
- value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
- 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
- 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
- 'hyper sharpness, perfect without deformations.')
- n_prompt = gr.Textbox(label="Default Negative Prompt",
- value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
- 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
- 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
- 'deformed, lowres, over-smooth')
-
- with gr.Column():
- gr.Markdown("Upscaled Images Output ")
- if not args.use_image_slider:
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
- else:
- result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
- with gr.Row():
- with gr.Column():
- denoise_button = gr.Button(value="Stage1 Run")
- with gr.Column():
- llava_button = gr.Button(value="LlaVa Run")
- with gr.Column():
- diffusion_button = gr.Button(value="Stage2 Run")
- with gr.Row():
- with gr.Column():
- batch_process_folder = gr.Textbox(
-                    label="Batch Processing Input Folder Path - If image_file_name.txt exists, it will be read and used as the prompt (optional). Uses the same settings as a single upscale (Stage 2 Run). If no caption txt exists, the Prompt you have written will be used. It can be empty as well.",
- placeholder="e.g. R:\SUPIR video\comparison_images")
- outputs_folder = gr.Textbox(
-                    label="Batch Processing Output Folder Path - If left empty, images are saved in the default folder",
- placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
- with gr.Row():
- with gr.Column():
- batch_upscale_button = gr.Button(value="Start Batch Upscaling")
- outputlabel = gr.Label("Batch Processing Progress")
- with gr.Row():
- with gr.Column():
- param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
- value="Quality")
- with gr.Column():
- restart_button = gr.Button(value="Reset Param", scale=2)
- with gr.Row():
- with gr.Column():
- linear_CFG = gr.Checkbox(label="Linear CFG", value=False)
- spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
- maximum=9.0, value=1.0, step=0.5)
- with gr.Column():
- linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
- spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
- maximum=1., value=0., step=0.05)
- with gr.Row():
- with gr.Column():
- diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16",
- interactive=True)
- with gr.Column():
- ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
- interactive=True)
- with gr.Column():
- color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
+ with gr.Column():
+ linear_CFG = gr.Checkbox(label="Linear CFG", value=False)
+ spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
+ maximum=9.0, value=1.0, step=0.5)
+ with gr.Column():
+ linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
+ spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
+ maximum=1., value=0., step=0.05)
+ with gr.Row():
+ with gr.Column():
+ diff_dtype = gr.Radio(['fp32', 'fp16', 'bf16'], label="Diffusion Data Type", value="fp16",
interactive=True)
- with gr.Column():
- model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
+ with gr.Column():
+ ae_dtype = gr.Radio(['fp32', 'bf16'], label="Auto-Encoder Data Type", value="bf16",
interactive=True)
- with gr.Accordion("LLaVA options", open=False):
- temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
- top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
- qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
- "The image is a realistic photography, not an art painting.")
- with gr.Accordion("Feedback", open=False):
- fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
- interactive=True)
- fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
- submit_button = gr.Button(value="Submit Feedback")
- with gr.Row():
- gr.Markdown(claim_md)
- event_id = gr.Textbox(label="Event ID", value="", visible=False)
-
- llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
- denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
- outputs=[denoise_image])
- stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed]
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips,
- outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
- restart_button.click(fn=load_and_reset, inputs=[param_setting],
- outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
- color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
- submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
- stage2_ips_batch = [batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale,
- edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed]
- batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True,
- queue=True)
-block.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
+ with gr.Column():
+ color_fix_type = gr.Radio(["None", "AdaIn", "Wavelet"], label="Color-Fix Type", value="Wavelet",
+ interactive=True)
+ with gr.Column():
+ model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
+ interactive=True)
+ with gr.Accordion("LLaVA options", open=False):
+ temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
+ top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
+ qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
+                                  "The image is a realistic photograph, not an art painting.")
+ with gr.Accordion("Feedback", open=False):
+ fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
+ interactive=True)
+ fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
+ submit_button = gr.Button(value="Submit Feedback")
+ with gr.Row():
+ gr.Markdown(claim_md)
+ event_id = gr.Textbox(label="Event ID", value="", visible=False)
+
+ llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
+ denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
+ outputs=[denoise_image])
+ stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
+ diffusion_button.click(fn=stage2_process, inputs=stage2_ips,
+ outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
+ restart_button.click(fn=load_and_reset, inputs=[param_setting],
+ outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
+ color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
+ submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
+ stage2_ips_batch = [batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale,
+ edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
+ batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True,
+ queue=True)
+ interface.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ip", type=str, default='127.0.0.1')
+    parser.add_argument("--share", action='store_true', default=False)
+ parser.add_argument("--port", type=int, default=7860)
+ parser.add_argument("--no_llava", action='store_true', default=False)
+ parser.add_argument("--use_image_slider", action='store_true', default=False)
+ parser.add_argument("--log_history", action='store_true', default=False)
+ parser.add_argument("--loading_half_params", action='store_true', default=False)
+ parser.add_argument("--use_tile_vae", action='store_true', default=False)
+ parser.add_argument("--encoder_tile_size", type=int, default=512)
+ parser.add_argument("--decoder_tile_size", type=int, default=64)
+ parser.add_argument("--load_8bit_llava", action='store_true', default=False)
+ parser.add_argument("--outputs_folder", type=str)
+ args = parser.parse_args()
+ use_llava = not args.no_llava
+
+ if torch.cuda.device_count() >= 2:
+ SUPIR_device = 'cuda:0'
+ LLaVA_device = 'cuda:1'
+ elif torch.cuda.device_count() == 1:
+ SUPIR_device = 'cuda:0'
+ LLaVA_device = 'cuda:0'
+ else:
+ raise ValueError('Only CUDA is currently supported.')
+
+ # load SUPIR
+ model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign='Q')
+ if args.loading_half_params:
+ model = model.half()
+ if args.use_tile_vae:
+ model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
+ model = model.to(SUPIR_device)
+ model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
+ model.current_model = 'v0-Q'
+ ckpt_Q, ckpt_F = load_QF_ckpt('options/SUPIR_v0.yaml')
+
+ # load LLaVA
+ if use_llava:
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
+ else:
+ llava_agent = None
+
+ launch_kwargs = {
+ "server_name": args.ip,
+ "server_port": args.port,
+ "share": args.share,
+ "inbrowser": True
+ }
+
+ launch_ui(launch_kwargs)
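Note that as committed here, `launch_ui` still references the old module-level `server_ip`/`server_port` names (deleted above) and ignores its `launch_kwargs` parameter; [PATCH 21/28] below completes the refactor by unpacking the dict. A minimal sketch of the completed pattern, with placeholder UI contents:

```python
import argparse
import gradio as gr

def launch_ui(launch_kwargs):
    with gr.Blocks(title='SUPIR') as interface:
        gr.Markdown("demo placeholder")
    # Unpacking keeps every server setting defined in one place in __main__.
    interface.launch(**launch_kwargs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ip", type=str, default='127.0.0.1')
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", action='store_true', default=False)
    args = parser.parse_args()

    launch_ui({
        "server_name": args.ip,
        "server_port": args.port,
        "share": args.share,
        "inbrowser": True,
    })
```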
From ded8b8c7202eb2f74083ecf6c15e10b6171b23a0 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 12:59:30 +0200
Subject: [PATCH 21/28] Support providing outputs_folder as command line
argument
---
.editorconfig | 20 ++++++++++++++++++++
gradio_demo.py | 28 ++++++++++++++++------------
2 files changed, 36 insertions(+), 12 deletions(-)
create mode 100755 .editorconfig
diff --git a/.editorconfig b/.editorconfig
new file mode 100755
index 0000000..65d71f7
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,20 @@
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+insert_final_newline = true
+indent_style = space
+indent_size = 4
+trim_trailing_whitespace = true
+
+[*.md]
+trim_trailing_whitespace = false
+
+[*.yml]
+indent_size = 2
+indent_style = space
+
+[*.yaml]
+indent_size = 2
+indent_style = space
diff --git a/gradio_demo.py b/gradio_demo.py
index fd96d3f..4b28f74 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -129,7 +129,11 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.ae_dtype = convert_dtype(ae_dtype)
model.model.dtype = convert_dtype(diff_dtype)
- output_dir = os.path.join("outputs")
+ if args.outputs_folder:
+ output_dir = args.outputs_folder
+ else:
+ output_dir = os.path.join("outputs")
+
if not os.path.exists(output_dir):
os.makedirs(output_dir)
@@ -308,10 +312,10 @@ def launch_ui(launch_kwargs):
with gr.Column():
batch_process_folder = gr.Textbox(
label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt it will use the Prompt you written. It can be empty as well.",
- placeholder="e.g. R:\SUPIR video\comparison_images")
+ placeholder="e.g. /workspace/SUPIR_video/comparison_images")
outputs_folder = gr.Textbox(
label="Batch Processing Output Folder Path - If left empty images are saved in default folder",
- placeholder="e.g. R:\SUPIR video\comparison_images\outputs")
+ placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs"),
with gr.Row():
with gr.Column():
batch_upscale_button = gr.Button(value="Start Batch Upscaling")
@@ -359,26 +363,26 @@ def launch_ui(launch_kwargs):
event_id = gr.Textbox(label="Event ID", value="", visible=False)
llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
- denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
- outputs=[denoise_image])
+ denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], outputs=[denoise_image])
stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
random_seed]
diffusion_button.click(fn=stage2_process, inputs=stage2_ips,
- outputs=[result_gallery, event_id, fb_score, fb_text, seed], show_progress=True, queue=True)
+ outputs=[result_gallery, event_id, fb_score, fb_text, seed],
+ show_progress=True,
+ queue=True)
restart_button.click(fn=load_and_reset, inputs=[param_setting],
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
stage2_ips_batch = [batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale,
- edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed]
+ edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype,
+ ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG,
+ spt_linear_s_stage2, model_select, num_images, random_seed]
batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True,
queue=True)
- interface.launch(server_name=server_ip, server_port=server_port, share=args.share, inbrowser=True)
+ interface.launch(**launch_kwargs)
if __name__ == "__main__":
@@ -394,7 +398,7 @@ def launch_ui(launch_kwargs):
parser.add_argument("--encoder_tile_size", type=int, default=512)
parser.add_argument("--decoder_tile_size", type=int, default=64)
parser.add_argument("--load_8bit_llava", action='store_true', default=False)
- parser.add_argument("--outputs_folder", type=str)
+ parser.add_argument("--outputs_folder", type=str, default='outputs')
args = parser.parse_args()
use_llava = not args.no_llava
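The directory handling added in this patch can be expressed more tightly with `os.makedirs(..., exist_ok=True)`. A sketch (not part of the patch), noting that argparse now supplies `default='outputs'`, so the truthiness check only matters when the flag is explicitly passed an empty string:

```python
import os

def resolve_output_dir(cli_value):
    # Fall back to "outputs" when the CLI value is None or empty.
    output_dir = cli_value or "outputs"
    # exist_ok=True replaces the separate os.path.exists() check.
    os.makedirs(output_dir, exist_ok=True)
    return output_dir

assert resolve_output_dir("") == "outputs"
assert resolve_output_dir("/tmp/supir_out") == "/tmp/supir_out"
```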
From 9656f0a471d2e032000aac285d6fc28238144656 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 14:25:23 +0200
Subject: [PATCH 22/28] Fixes and remove duplicate code
---
gradio_demo.py | 33 ++++++++++++++-------------------
1 file changed, 14 insertions(+), 19 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 4b28f74..9a8f985 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -129,7 +129,9 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.ae_dtype = convert_dtype(ae_dtype)
model.model.dtype = convert_dtype(diff_dtype)
- if args.outputs_folder:
+ if outputs_folder.strip() != "":
+ output_dir = outputs_folder
+ elif args.outputs_folder:
output_dir = args.outputs_folder
else:
output_dir = os.path.join("outputs")
@@ -137,11 +139,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
if not os.path.exists(output_dir):
os.makedirs(output_dir)
- if outputs_folder.strip() != "" and outputs_folder != "outputs":
- output_dir = outputs_folder
- if not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
all_results = []
counter = 1
if not dont_update_progress:
@@ -313,13 +310,13 @@ def launch_ui(launch_kwargs):
batch_process_folder = gr.Textbox(
label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt it will use the Prompt you written. It can be empty as well.",
placeholder="e.g. /workspace/SUPIR_video/comparison_images")
- outputs_folder = gr.Textbox(
+ batch_outputs_folder = gr.Textbox(
label="Batch Processing Output Folder Path - If left empty images are saved in default folder",
- placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs"),
+ placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs")
with gr.Row():
with gr.Column():
batch_upscale_button = gr.Button(value="Start Batch Upscaling")
- outputlabel = gr.Label("Batch Processing Progress")
+ batch_output_label = gr.Label("Batch Processing Progress")
with gr.Row():
with gr.Column():
param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
@@ -364,11 +361,12 @@ def launch_ui(launch_kwargs):
llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], outputs=[denoise_image])
- stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed]
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips,
+ stage_2_common_inputs = [prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
+ stage2_inputs = [input_image] + stage_2_common_inputs
+ diffusion_button.click(fn=stage2_process, inputs=stage2_inputs,
outputs=[result_gallery, event_id, fb_score, fb_text, seed],
show_progress=True,
queue=True)
@@ -376,11 +374,8 @@ def launch_ui(launch_kwargs):
outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
- stage2_ips_batch = [batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale,
- edm_steps, s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype,
- ae_dtype, gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG,
- spt_linear_s_stage2, model_select, num_images, random_seed]
- batch_upscale_button.click(fn=batch_upscale, inputs=stage2_ips_batch, outputs=outputlabel, show_progress=True,
+ stage2_inputs_batch = [batch_process_folder, batch_outputs_folder] + stage_2_common_inputs
+ batch_upscale_button.click(fn=batch_upscale, inputs=stage2_inputs_batch, outputs=batch_output_label, show_progress=True,
queue=True)
interface.launch(**launch_kwargs)
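The de-duplication above relies on plain list concatenation: both click handlers share one tail of common components and differ only in their leading inputs. A sketch with placeholder values standing in for the Gradio components:

```python
# Illustrative stand-ins for the shared components; order is what matters.
common = ["prompt", "a_prompt", "n_prompt", "num_samples", "upscale"]

stage2_inputs = ["input_image"] + common
stage2_inputs_batch = ["batch_process_folder", "batch_outputs_folder"] + common

# Both lists end with the identical shared tail.
assert stage2_inputs[1:] == stage2_inputs_batch[2:] == common
```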
From 6c84eb15e81595a7491277d005745eac96930b15 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 14:54:11 +0200
Subject: [PATCH 23/28] Additional implementation for outputs_folder command
line argument
---
gradio_demo.py | 17 ++++++++++-------
1 file changed, 10 insertions(+), 7 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 9a8f985..6a2669c 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -61,7 +61,7 @@ def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prom
# Iterate over all image files in the folder
for index, file_name in enumerate(image_files):
try:
- progress((index + 1) / total_images, f"Processing {index + 1}/{total_images} image")
+ progress((index + 1) / total_images, f"Processing {index + 1}/{total_images} images")
# Construct the full file path
file_path = os.path.join(batch_process_folder, file_name)
prompt = main_prompt
@@ -88,7 +88,7 @@ def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prom
except Exception as e:
print(f"Error processing {file_name}: {e}")
continue
- return "All Done"
+ return "Batch Processing Complete"
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
@@ -131,8 +131,8 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
if outputs_folder.strip() != "":
output_dir = outputs_folder
- elif args.outputs_folder:
- output_dir = args.outputs_folder
+ elif outputs_folder_arg:
+ output_dir = outputs_folder_arg
else:
output_dir = os.path.join("outputs")
@@ -163,7 +163,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
if not dont_update_progress:
progress(counter / num_images, desc=desc)
print(desc) # Print the progress
- start_time = time.time() # Reset the start time for the next image
for i, result in enumerate(results):
timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
@@ -309,10 +308,13 @@ def launch_ui(launch_kwargs):
with gr.Column():
batch_process_folder = gr.Textbox(
label="Batch Processing Input Folder Path - If image_file_name.txt exists it will be read and used as prompt (optional). Uses same settings of single upscale (Stage 2 Run). If no caption txt it will use the Prompt you written. It can be empty as well.",
- placeholder="e.g. /workspace/SUPIR_video/comparison_images")
+ placeholder="e.g. /workspace/SUPIR_video/comparison_images"
+ )
batch_outputs_folder = gr.Textbox(
label="Batch Processing Output Folder Path - If left empty images are saved in default folder",
- placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs")
+ placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs",
+ value=outputs_folder_arg
+ )
with gr.Row():
with gr.Column():
batch_upscale_button = gr.Button(value="Start Batch Upscaling")
@@ -396,6 +398,7 @@ def launch_ui(launch_kwargs):
parser.add_argument("--outputs_folder", type=str, default='outputs')
args = parser.parse_args()
use_llava = not args.no_llava
+ outputs_folder_arg = args.outputs_folder
if torch.cuda.device_count() >= 2:
SUPIR_device = 'cuda:0'
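With this patch the output directory resolves through a three-level precedence: a non-blank UI textbox wins, then the `--outputs_folder` CLI value, then the hardcoded default. Since the batch textbox is now pre-filled with `value=outputs_folder_arg`, the first branch usually applies. A standalone sketch of the rule:

```python
def pick_output_dir(ui_value, cli_value, default="outputs"):
    # A non-blank UI value takes priority over the CLI argument.
    if ui_value and ui_value.strip() != "":
        return ui_value
    if cli_value:
        return cli_value
    return default

assert pick_output_dir("/ui/out", "/cli/out") == "/ui/out"
assert pick_output_dir("   ", "/cli/out") == "/cli/out"
assert pick_output_dir("", "") == "outputs"
```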
From 454c3f1650060e4b3e8fea57ea64018ddd613585 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 15:23:08 +0200
Subject: [PATCH 24/28] Added Replicate demo and RunPod template to README
---
README.md | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index fa22c4a..bb2cec5 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
## (CVPR2024) Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild
-> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [Online Demo (Coming soon)]
+> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)]
> Fanghua, Yu, [Jinjin Gu](https://www.jasongt.com/), Zheyuan Li, Jinfan Hu, Xiangtao Kong, [Xintao Wang](https://xinntao.github.io/), [Jingwen He](https://scholar.google.com.hk/citations?user=GUxrycUAAAAJ), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ)
> Shenzhen Institute of Advanced Technology; Shanghai AI Laboratory; University of Sydney; The Hong Kong Polytechnic University; ARC Lab, Tencent PCG; The Chinese University of Hong Kong
@@ -36,7 +36,7 @@ For users who can connect to huggingface, please setting `LLAVA_CLIP_PATH, SDXL_
#### Dependent Models
* [SDXL CLIP Encoder-1](https://huggingface.co/openai/clip-vit-large-patch14)
* [SDXL CLIP Encoder-2](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
-* [SDXL base 1.0_0.9vae](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors)
+* [Juggernaut-XL_v9_RunDiffusionPhoto_v2](https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors)
* [LLaVA CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336)
* [LLaVA v1.5 7B](https://huggingface.co/liuhaotian/llava-v1.5-7b)
@@ -117,7 +117,10 @@ CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_im
-### Online Demo (Coming Soon)
+### Online Resources & Demos
+
+1. [Replicate Demo](https://replicate.com/cjwbw/supir)
+2. [RunPod Template](https://runpod.io/console/gpu-cloud?template=aa31uo64wv&ref=2xxro4sy)
---
From ff9ad551e1b8319edb0410acc1da95cfcf25d242 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Thu, 29 Feb 2024 15:25:47 +0200
Subject: [PATCH 25/28] Added YouTube tutorial
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index bb2cec5..8c37031 100644
--- a/README.md
+++ b/README.md
@@ -121,7 +121,7 @@ CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_im
1. [Replicate Demo](https://replicate.com/cjwbw/supir)
2. [RunPod Template](https://runpod.io/console/gpu-cloud?template=aa31uo64wv&ref=2xxro4sy)
-
+3. [YouTube Tutorial](https://www.youtube.com/watch?v=PqREA6-bC3w)
---
From ea332a28a7608a6e2b9c95b92ccc0f1367aa152b Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Fri, 1 Mar 2024 12:56:05 +0200
Subject: [PATCH 26/28] Revert SDXL_CKPT to base SDXL model since Juggernaut is
not very good with eyes
---
README.md | 2 +-
download_models.py | 2 +-
options/SUPIR_v0.yaml | 13 ++++++-------
3 files changed, 8 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
index 8c37031..169ec9b 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ For users who can connect to huggingface, please setting `LLAVA_CLIP_PATH, SDXL_
#### Dependent Models
* [SDXL CLIP Encoder-1](https://huggingface.co/openai/clip-vit-large-patch14)
* [SDXL CLIP Encoder-2](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
-* [Juggernaut-XL_v9_RunDiffusionPhoto_v2](https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors)
+* [sd_xl_base_1.0_0.9vae](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors)
* [LLaVA CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336)
* [LLaVA v1.5 7B](https://huggingface.co/liuhaotian/llava-v1.5-7b)
diff --git a/download_models.py b/download_models.py
index 4058995..e7cced6 100755
--- a/download_models.py
+++ b/download_models.py
@@ -42,7 +42,7 @@ def download_file(url, folder_path, file_name=None):
folders_and_files = {
os.path.join('models'): [
('https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/resolve/main/open_clip_pytorch_model.bin', None),
- ('https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/resolve/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors', None),
+ ('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors', None),
('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0F.ckpt', None),
('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0Q.ckpt', None),
]
diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml
index 675ecba..5252885 100644
--- a/options/SUPIR_v0.yaml
+++ b/options/SUPIR_v0.yaml
@@ -141,16 +141,15 @@ model:
scale_min: 4.0
p_p:
- 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera,
- hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing,
- skin pore detailing, hyper sharpness, perfect without deformations.'
+ 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera,
+ hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing,
+ skin pore detailing, hyper sharpness, perfect without deformations.'
n_p:
- 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render,
- unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
+ 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render,
+ unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
jpeg artifacts, deformed, lowres, over-smooth'
-SDXL_CKPT: models/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors
+SDXL_CKPT: models/sd_xl_base_1.0_0.9vae.safetensors
SUPIR_CKPT_F: models/SUPIR-v0F.ckpt
SUPIR_CKPT_Q: models/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~
-
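For reference, the checkpoint keys at the bottom of options/SUPIR_v0.yaml can be read back as below (a sketch assuming PyYAML and the repo root as working directory). The YAML scalar `~` parses to Python `None`, so `SUPIR_CKPT: ~` acts as an explicit "not set" marker:

```python
import yaml

with open("options/SUPIR_v0.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["SDXL_CKPT"])     # models/sd_xl_base_1.0_0.9vae.safetensors
print(cfg["SUPIR_CKPT_Q"])  # models/SUPIR-v0Q.ckpt
print(cfg["SUPIR_CKPT"])    # None (the `~` above)
```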
From 9a5689bfe7fc42b2b0c17ebfdb3c311f13716e65 Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Sat, 2 Mar 2024 21:32:11 +0200
Subject: [PATCH 27/28] Remove unused code
---
gradio_demo.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index 6a2669c..e425824 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -1,10 +1,9 @@
import os
-from pickle import TRUE
import gradio as gr
from gradio_imageslider import ImageSlider
import argparse
-from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype, Tensor2PIL
+from SUPIR.util import HWC3, upscale_image, fix_resize, convert_dtype
import numpy as np
import torch
from SUPIR.util import create_SUPIR_model, load_QF_ckpt
From 37715afdd46cecf5245f1fc5afc515c850f21dcd Mon Sep 17 00:00:00 2001
From: Ashley Kleynhans
Date: Mon, 4 Mar 2024 11:44:39 +0200
Subject: [PATCH 28/28] Retain original filename in batch processing
---
gradio_demo.py | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/gradio_demo.py b/gradio_demo.py
index a54f79c..4a2f239 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -81,7 +81,7 @@ def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prom
stage2_process(img_array, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed, dont_update_progress=True, outputs_folder=outputs_folder)
+ random_seed, dont_update_progress=True, outputs_folder=outputs_folder, file_name=base_name)
# Update progress
except Exception as e:
@@ -93,7 +93,8 @@ def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prom
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
- random_seed, dont_update_progress=False, outputs_folder="outputs", progress=gr.Progress()):
+ random_seed, dont_update_progress=False, outputs_folder="outputs", file_name=None,
+ progress=gr.Progress()):
torch.cuda.set_device(SUPIR_device)
event_id = str(time.time_ns())
@@ -145,7 +146,7 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
counter = 1
if not dont_update_progress:
progress(0 / num_images, desc="Generating images")
- for _ in range(num_images):
+ for img_num in range(num_images):
if random_seed or num_images > 1:
seed = np.random.randint(0, 2147483647)
start_time = time.time() # Track the start time
@@ -167,9 +168,17 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
print(desc) # Print the progress
for i, result in enumerate(results):
- timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
- save_path = os.path.join(output_dir, f'{timestamp}.png')
- Image.fromarray(result).save(save_path)
+ if file_name:
+ if num_images == 1:
+ output_filename = f'{file_name}_upscaled.png'
+ else:
+ img_index = img_num + 1
+ output_filename = f'{file_name}_upscaled_{img_index}.png'
+ else:
+ timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
+ output_filename = f'{timestamp}.png'
+
+ Image.fromarray(result).save(os.path.join(output_dir, output_filename))
all_results.extend(results)
if args.log_history:
@@ -183,7 +192,6 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
return [input_image] + all_results, event_id, 3, '', seed
-
def load_and_reset(param_setting):
edm_steps = 50
s_stage2 = 1.0
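The naming rule this final patch introduces, as a standalone sketch: batch runs keep the source file's base name, suffixed with a 1-based index when several images are generated per input, while interactive runs fall back to a millisecond timestamp:

```python
import datetime

def output_filename(file_name, img_num, num_images):
    if file_name:
        if num_images == 1:
            return f'{file_name}_upscaled.png'
        return f'{file_name}_upscaled_{img_num + 1}.png'
    # No base name: timestamp trimmed from microseconds to milliseconds.
    timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
    return f'{timestamp}.png'

print(output_filename('portrait', 0, 1))  # portrait_upscaled.png
print(output_filename('portrait', 2, 4))  # portrait_upscaled_3.png
print(output_filename(None, 0, 1))        # e.g. 2024_03_04_11_44_39_123.png
```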