diff --git a/.editorconfig b/.editorconfig
new file mode 100755
index 0000000..65d71f7
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,20 @@
+root = true
+
+[*]
+charset = utf-8
+end_of_line = lf
+insert_final_newline = true
+indent_style = space
+indent_size = 4
+trim_trailing_whitespace = true
+
+[*.md]
+trim_trailing_whitespace = false
+
+[*.yml]
+indent_size = 2
+indent_style = space
+
+[*.yaml]
+indent_size = 2
+indent_style = space
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..ecbfa5f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,6 @@
+models
+__pycache__
+venv
+.idea
+.vs
+outputs
diff --git a/.idea/.gitignore b/.idea/.gitignore
deleted file mode 100644
index 13566b8..0000000
--- a/.idea/.gitignore
+++ /dev/null
@@ -1,8 +0,0 @@
-# Default ignored files
-/shelf/
-/workspace.xml
-# Editor-based HTTP Client requests
-/httpRequests/
-# Datasource local storage ignored files
-/dataSources/
-/dataSources.local.xml
diff --git a/.idea/.name b/.idea/.name
deleted file mode 100644
index 0b8d3da..0000000
--- a/.idea/.name
+++ /dev/null
@@ -1 +0,0 @@
-Diff4R
\ No newline at end of file
diff --git a/.idea/SUPIR.iml b/.idea/SUPIR.iml
deleted file mode 100644
index de95779..0000000
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
deleted file mode 100644
index 768ffba..0000000
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
deleted file mode 100644
index 7410f1c..0000000
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
deleted file mode 100644
index 105ce2d..0000000
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 874415a..0000000
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index 8ae5627..0000000
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 35eb1dd..0000000
diff --git a/.idea/webResources.xml b/.idea/webResources.xml
deleted file mode 100644
index 717d9d6..0000000
diff --git a/CKPT_PTH.py b/CKPT_PTH.py
index 4ff9ebf..8d64e27 100644
--- a/CKPT_PTH.py
+++ b/CKPT_PTH.py
@@ -1,4 +1,6 @@
-LLAVA_CLIP_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/clip-vit-large-patch14-336'
-LLAVA_MODEL_PATH = '/opt/data/private/AIGC_pretrain/LLaVA1.5/llava-v1.5-13b'
-SDXL_CLIP1_PATH = '/opt/data/private/AIGC_pretrain/clip-vit-large-patch14'
-SDXL_CLIP2_CKPT_PTH = '/opt/data/private/AIGC_pretrain/CLIP-ViT-bigG-14-laion2B-39B-b160k/open_clip_pytorch_model.bin'
\ No newline at end of file
+LLAVA_CLIP_PATH = 'openai/clip-vit-large-patch14-336'
+LLAVA_MODEL_PATH = 'liuhaotian/llava-v1.5-7b'
+SDXL_CLIP1_PATH = 'openai/clip-vit-large-patch14'
+SDXL_CLIP2_CKPT_PTH = 'models/open_clip_pytorch_model.bin'
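+# The Hugging Face Hub IDs above are downloaded and cached automatically on first
+# use; only the open_clip checkpoint is expected on local disk (see download_models.py).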
diff --git a/README.md b/README.md
index 76402fe..52c32bb 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
## (CVPR2024) Scaling Up to Excellence: Practicing Model Scaling for Photo-Realistic Image Restoration In the Wild
-> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [Online Demo (Coming soon)]
+> [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)]
> Fanghua Yu, [Jinjin Gu](https://www.jasongt.com/), Zheyuan Li, Jinfan Hu, Xiangtao Kong, [Xintao Wang](https://xinntao.github.io/), [Jingwen He](https://scholar.google.com.hk/citations?user=GUxrycUAAAAJ), [Yu Qiao](https://scholar.google.com.hk/citations?user=gFtI-8QAAAAJ), [Chao Dong](https://scholar.google.com.hk/citations?user=OSDCB0UAAAAJ)
> Shenzhen Institute of Advanced Technology; Shanghai AI Laboratory; University of Sydney; The Hong Kong Polytechnic University; ARC Lab, Tencent PCG; The Chinese University of Hong Kong
@@ -38,7 +38,7 @@ For users who can connect to huggingface, please set `LLAVA_CLIP_PATH, SDXL_
* [SDXL CLIP Encoder-2](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
* [SDXL base 1.0_0.9vae](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_base_1.0_0.9vae.safetensors)
* [LLaVA CLIP](https://huggingface.co/openai/clip-vit-large-patch14-336)
-* [LLaVA v1.5 13B](https://huggingface.co/liuhaotian/llava-v1.5-13b)
+* [LLaVA v1.5 7B](https://huggingface.co/liuhaotian/llava-v1.5-7b)
* (optional) [Juggernaut-XL_v9_RunDiffusionPhoto_v2](https://huggingface.co/RunDiffusion/Juggernaut-XL-v9/blob/main/Juggernaut-XL_v9_RunDiffusionPhoto_v2.safetensors)
* Replacement of `SDXL base 1.0_0.9vae` for Photo Realistic
* (optional) [Juggernaut_RunDiffusionPhoto2_Lightning_4Steps](https://huggingface.co/RunDiffusion/Juggernaut-XL-Lightning/blob/main/Juggernaut_RunDiffusionPhoto2_Lightning_4Steps.safetensors)
@@ -124,8 +124,11 @@ CUDA_VISIBLE_DEVICES=0,1 python gradio_demo.py --ip 0.0.0.0 --port 6688 --use_im
-### Online Demo (Coming Soon)
+### Online Resources & Demos
+1. [Replicate Demo](https://replicate.com/cjwbw/supir)
+2. [RunPod Template](https://runpod.io/console/gpu-cloud?template=aa31uo64wv&ref=2xxro4sy)
+3. [YouTube Tutorial](https://www.youtube.com/watch?v=PqREA6-bC3w)
---
diff --git a/download_models.py b/download_models.py
new file mode 100755
index 0000000..e7cced6
--- /dev/null
+++ b/download_models.py
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
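+"""Download the SUPIR and SDXL checkpoints into ./models, then fetch the LLaVA
+and CLIP models into the Hugging Face Hub cache."""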
+import os
+import requests
+from tqdm import tqdm
+from huggingface_hub import snapshot_download
+
+
+def create_directory(path):
+ """Create directory if it does not exist."""
+ if not os.path.exists(path):
+ os.makedirs(path)
+ print(f'Directory created: {path}')
+ else:
+ print(f'Directory already exists: {path}')
+
+
+def download_file(url, folder_path, file_name=None):
+ """Download a file from a given URL to a specified folder with an optional file name."""
+ local_filename = file_name if file_name else url.split('/')[-1]
+ local_filepath = os.path.join(folder_path, local_filename)
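+    # Skip the download if the file is already present; the checkpoints are several GB each
+    if os.path.exists(local_filepath):
+        print(f'File already exists, skipping: {local_filepath}')
+        return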
+ print(f'Downloading {url} to: {local_filepath}')
+
+ # Stream download to handle large files
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ total_size_in_bytes = int(r.headers.get('content-length', 0))
+ block_size = 1024 # 1 Kibibyte
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+ with open(local_filepath, 'wb') as f:
+ for data in r.iter_content(block_size):
+ progress_bar.update(len(data))
+ f.write(data)
+ progress_bar.close()
+
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+        print(f'ERROR: download of {local_filename} was incomplete')
+ else:
+ print(f'Downloaded {local_filename} to {folder_path}')
+
+
+# Define the folders and their corresponding file URLs with optional file names
+folders_and_files = {
+    'models': [
+ ('https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/resolve/main/open_clip_pytorch_model.bin', None),
+ ('https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0_0.9vae.safetensors', None),
+ ('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0F.ckpt', None),
+ ('https://huggingface.co/ashleykleynhans/SUPIR/resolve/main/SUPIR-v0Q.ckpt', None),
+ ]
+}
+
+
+if __name__ == '__main__':
+ for folder, files in folders_and_files.items():
+ create_directory(folder)
+ for file_url, file_name in files:
+ download_file(file_url, folder, file_name)
+
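+    # The LLaVA captioner can be overridden via the LLAVA_MODEL environment variable;
+    # the two CLIP encoders are fixed dependencies of LLaVA and SDXL respectively.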
+ llava_model = os.getenv('LLAVA_MODEL', 'liuhaotian/llava-v1.5-7b')
+ llava_clip_model = 'openai/clip-vit-large-patch14-336'
+ sdxl_clip_model = 'openai/clip-vit-large-patch14'
+
+ print(f'Downloading LLaVA model: {llava_model}')
+ snapshot_download(llava_model)
+
+ print(f'Downloading LLaVA CLIP model: {llava_clip_model}')
+ snapshot_download(llava_clip_model)
+
+ print(f'Downloading SDXL CLIP model: {sdxl_clip_model}')
+ snapshot_download(sdxl_clip_model)
diff --git a/gradio_demo.py b/gradio_demo.py
index e399084..4993b68 100644
--- a/gradio_demo.py
+++ b/gradio_demo.py
@@ -12,50 +12,9 @@
from CKPT_PTH import LLAVA_MODEL_PATH
import einops
import copy
+import datetime
import time
-parser = argparse.ArgumentParser()
-parser.add_argument("--opt", type=str, default='options/SUPIR_v0.yaml')
-parser.add_argument("--ip", type=str, default='127.0.0.1')
-parser.add_argument("--port", type=int, default='6688')
-parser.add_argument("--no_llava", action='store_true', default=False)
-parser.add_argument("--use_image_slider", action='store_true', default=False)
-parser.add_argument("--log_history", action='store_true', default=False)
-parser.add_argument("--loading_half_params", action='store_true', default=False)
-parser.add_argument("--use_tile_vae", action='store_true', default=False)
-parser.add_argument("--encoder_tile_size", type=int, default=512)
-parser.add_argument("--decoder_tile_size", type=int, default=64)
-parser.add_argument("--load_8bit_llava", action='store_true', default=False)
-args = parser.parse_args()
-server_ip = args.ip
-server_port = args.port
-use_llava = not args.no_llava
-
-if torch.cuda.device_count() >= 2:
- SUPIR_device = 'cuda:0'
- LLaVA_device = 'cuda:1'
-elif torch.cuda.device_count() == 1:
- SUPIR_device = 'cuda:0'
- LLaVA_device = 'cuda:0'
-else:
- raise ValueError('Currently support CUDA only.')
-
-# load SUPIR
-model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
-if args.loading_half_params:
- model = model.half()
-if args.use_tile_vae:
- model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
-model = model.to(SUPIR_device)
-model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
-model.current_model = 'v0-Q'
-ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
-
-# load LLaVA
-if use_llava:
- llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
-else:
- llava_agent = None
def stage1_process(input_image, gamma_correction):
torch.cuda.set_device(SUPIR_device)
@@ -73,9 +32,10 @@ def stage1_process(input_image, gamma_correction):
LQ = LQ.round().clip(0, 255).astype(np.uint8)
return LQ
-def llave_process(input_image, temperature, top_p, qs=None):
- torch.cuda.set_device(LLaVA_device)
+
+def llava_process(input_image, temperature, top_p, qs=None):
if use_llava:
+ torch.cuda.set_device(LLaVA_device)
LQ = HWC3(input_image)
LQ = Image.fromarray(LQ.astype('uint8'))
captions = llava_agent.gen_image_caption([LQ], temperature=temperature, top_p=top_p, qs=qs)
@@ -83,10 +43,60 @@ def llave_process(input_image, temperature, top_p, qs=None):
captions = ['LLaVA is not available. Please add text manually.']
return captions[0]
+
+def batch_upscale(batch_process_folder, outputs_folder, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps,
+ s_stage1, s_stage2, s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype,
+ gamma_correction, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select,
+ num_images, random_seed, progress=gr.Progress()):
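+    """Run stage2_process on every PNG/JPG/JPEG in batch_process_folder, reading an
+    optional per-image prompt from a .txt file with the same base name."""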
+
+ # Get the list of image files in the folder
+ image_files = [file for file in os.listdir(batch_process_folder) if
+ file.lower().endswith((".png", ".jpg", ".jpeg"))]
+ total_images = len(image_files)
+ main_prompt = prompt
+ # Iterate over all image files in the folder
+ for index, file_name in enumerate(image_files):
+ try:
+ progress((index + 1) / total_images, f"Processing {index + 1}/{total_images} images")
+ # Construct the full file path
+ file_path = os.path.join(batch_process_folder, file_name)
+ prompt = main_prompt
+ # Open the image file and convert it to a NumPy array
+ with Image.open(file_path) as img:
+ img_array = np.asarray(img)
+
+ # Construct the path for the prompt text file
+ base_name = os.path.splitext(file_name)[0]
+ prompt_file_path = os.path.join(batch_process_folder, f"{base_name}.txt")
+
+ # Read the prompt from the text file
+ if os.path.exists(prompt_file_path):
+ with open(prompt_file_path, "r", encoding="utf-8") as f:
+ prompt = f.read().strip()
+
+ # Call the stage2_process method for the image
+ stage2_process(img_array, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed, dont_update_progress=True, outputs_folder=outputs_folder, file_name=base_name)
+
+ except Exception as e:
+ print(f"Error processing {file_name}: {e}")
+ continue
+ return "Batch Processing Complete"
+
+
def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select):
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+                   random_seed, dont_update_progress=False, outputs_folder="", file_name=None,
+ progress=gr.Progress()):
torch.cuda.set_device(SUPIR_device)
+
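+    # A nanosecond timestamp identifies this request; --log_history uses it to group
+    # the input, outputs and settings under ./history/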
event_id = str(time.time_ns())
event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
@@ -106,8 +116,7 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.load_state_dict(ckpt_F, strict=False)
model.current_model = 'v0-F'
input_image = HWC3(input_image)
- input_image = upscale_image(input_image, upscale, unit_resolution=32,
- min_size=1024)
+ input_image = upscale_image(input_image, upscale, unit_resolution=32, min_size=1024)
LQ = np.array(input_image) / 255.0
LQ = np.power(LQ, gamma_correction)
@@ -123,15 +132,54 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
model.ae_dtype = convert_dtype(ae_dtype)
model.model.dtype = convert_dtype(diff_dtype)
- samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
- s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
- num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
- use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
- cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
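+    # Output directory precedence: the per-call folder (batch mode), then --outputs_folder, then ./outputs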
+    if outputs_folder.strip() != "":
+        output_dir = outputs_folder
+    elif outputs_folder_arg:
+        output_dir = outputs_folder_arg
+    else:
+        output_dir = "outputs"
+
+    os.makedirs(output_dir, exist_ok=True)
+
+ all_results = []
+ counter = 1
+ if not dont_update_progress:
+ progress(0 / num_images, desc="Generating images")
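+    # Generate num_images results, drawing a fresh random seed for each one unless a
+    # single image with a fixed seed was requested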
+ for img_num in range(num_images):
+ if random_seed or num_images > 1:
+ seed = np.random.randint(0, 2147483647)
+ start_time = time.time() # Track the start time
+ samples = model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
+ s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
+ num_samples=num_samples, p_p=a_prompt, n_p=n_prompt,
+ color_fix_type=color_fix_type,
+ use_linear_CFG=linear_CFG, use_linear_control_scale=linear_s_stage2,
+ cfg_scale_start=spt_linear_CFG, control_scale_start=spt_linear_s_stage2)
- x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
- 0, 255).astype(np.uint8)
- results = [x_samples[i] for i in range(num_samples)]
+ x_samples = (einops.rearrange(samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().round().clip(
+ 0, 255).astype(np.uint8)
+ results = [x_samples[i] for i in range(num_samples)]
+ image_generation_time = time.time() - start_time
+        desc = f"Generated image {counter}/{num_images} in {image_generation_time:.2f} seconds"
+        if not dont_update_progress:
+            progress(counter / num_images, desc=desc)
+        print(desc)  # Print the progress
+        counter += 1
+
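+        # Batch runs are named after the source file; interactive runs fall back to a timestamp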
+        for i, result in enumerate(results):
+            if file_name:
+                if num_images == 1 and num_samples == 1:
+                    output_filename = f'{file_name}_upscaled.png'
+                else:
+                    img_index = img_num * num_samples + i + 1
+                    output_filename = f'{file_name}_upscaled_{img_index}.png'
+ else:
+ timestamp = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')[:-3]
+ output_filename = f'{timestamp}.png'
+
+ Image.fromarray(result).save(os.path.join(output_dir, output_filename))
+ all_results.extend(results)
if args.log_history:
os.makedirs(f'./history/{event_id[:5]}/{event_id[5:]}', exist_ok=True)
@@ -139,9 +187,9 @@ def stage2_process(input_image, prompt, a_prompt, n_prompt, num_samples, upscale
f.write(str(event_dict))
f.close()
Image.fromarray(input_image).save(f'./history/{event_id[:5]}/{event_id[5:]}/LQ.png')
- for i, result in enumerate(results):
+ for i, result in enumerate(all_results):
Image.fromarray(result).save(f'./history/{event_id[:5]}/{event_id[5:]}/HQ_{i}.png')
- return [input_image] + results, event_id, 3, ''
+ return [input_image] + all_results, event_id, 3, '', seed
def load_and_reset(param_setting):
@@ -186,74 +234,114 @@ def submit_feedback(event_id, fb_score, fb_text):
return 'Submit failed, the server is not set to log history.'
-title_md = """
-# **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
-
-⚠️SUPIR is still a research project under tested and is not yet a stable commercial product.
-
-[[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
-"""
-
-
-claim_md = """
-## **Terms of use**
-
-By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit a feedback to us if you get any inappropriate answer! We will collect those to keep improving our models. For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
-
-## **License**
-
-The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
-"""
-
-
-block = gr.Blocks(title='SUPIR').queue()
-with block:
- with gr.Row():
- gr.Markdown(title_md)
- with gr.Row():
- with gr.Column():
- with gr.Row(equal_height=True):
- with gr.Column():
- gr.Markdown("Input ")
- input_image = gr.Image(type="numpy", elem_id="image-input", height=400, width=400)
- with gr.Column():
- gr.Markdown("Stage1 Output ")
- denoise_image = gr.Image(type="numpy", elem_id="image-s1", height=400, width=400)
- prompt = gr.Textbox(label="Prompt", value="")
- with gr.Accordion("Stage1 options", open=False):
- gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
- with gr.Accordion("LLaVA options", open=False):
- temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
- top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
- qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
- "The image is a realistic photography, not an art painting.")
- with gr.Accordion("Stage2 options", open=False):
- num_samples = gr.Slider(label="Num Samples", minimum=1, maximum=4 if not args.use_image_slider else 1
- , value=1, step=1)
- upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=1)
- edm_steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=default_setting.edm_steps, step=1)
- s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0,
- value=default_setting.s_cfg_Quality, step=0.1)
- s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
- s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0, step=1.0)
- seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
- s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
- s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
- a_prompt = gr.Textbox(label="Default Positive Prompt",
- value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
- 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
- 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
- 'hyper sharpness, perfect without deformations.')
- n_prompt = gr.Textbox(label="Default Negative Prompt",
- value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
- 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
- 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
- 'deformed, lowres, over-smooth')
+def launch_ui(launch_kwargs):
+ title_md = """
+ # **SUPIR: Practicing Model Scaling for Photo-Realistic Image Restoration**
+
+    ⚠️ SUPIR is still a research project under testing and is not yet a stable commercial product.
+
+ [[Paper](https://arxiv.org/abs/2401.13627)] [[Project Page](http://supir.xpixel.group/)] [[How to play](https://github.com/Fanghua-Yu/SUPIR/blob/master/assets/DemoGuide.png)]
+ """
+
+ claim_md = """
+ ## **Terms of use**
+
+    By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. Please submit feedback to us if you receive any inappropriate output; we will collect it to keep improving our models. For an optimal experience, please use a desktop computer, as mobile devices may compromise the demo's quality.
+
+ ## **License**
+
+ The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/Fanghua-Yu/SUPIR) of SUPIR.
+ """
+
+ interface = gr.Blocks(title='SUPIR').queue()
+ with interface:
+ with gr.Row():
+ gr.Markdown(title_md)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ gr.Markdown("Input ")
+ input_image = gr.Image(type="numpy", elem_id="image-input", height=400, width=400)
+ with gr.Column():
+ gr.Markdown("Stage1 Output ")
+ denoise_image = gr.Image(type="numpy", elem_id="image-s1", height=400, width=400)
+ prompt = gr.Textbox(label="Prompt", value="")
+ with gr.Accordion("Stage1 options", open=False):
+ gamma_correction = gr.Slider(label="Gamma Correction", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
+
+ with gr.Accordion("Stage2 options", open=True):
+ with gr.Row():
+ with gr.Column():
+                        num_images = gr.Slider(label="Number Of Images To Generate", minimum=1, maximum=200,
+                                               value=1, step=1)
+                        num_samples = gr.Slider(label="Batch Size", minimum=1,
+                                                maximum=4 if not args.use_image_slider else 1,
+                                                value=1, step=1)
+ with gr.Column():
+ upscale = gr.Slider(label="Upscale", minimum=1, maximum=8, value=1, step=0.1)
+ random_seed = gr.Checkbox(label="Randomize Seed", value=True)
+ with gr.Row():
+ edm_steps = gr.Slider(label="Steps", minimum=1, maximum=200, value=default_setting.edm_steps, step=1)
+ s_cfg = gr.Slider(label="Text Guidance Scale", minimum=1.0, maximum=15.0,
+ value=default_setting.s_cfg_Quality, step=0.1)
+ s_stage2 = gr.Slider(label="Stage2 Guidance Strength", minimum=0., maximum=1., value=1., step=0.05)
+ s_stage1 = gr.Slider(label="Stage1 Guidance Strength", minimum=-1.0, maximum=6.0, value=-1.0,
+ step=1.0)
+ seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, step=1, randomize=True)
+ s_churn = gr.Slider(label="S-Churn", minimum=0, maximum=40, value=5, step=1)
+ s_noise = gr.Slider(label="S-Noise", minimum=1.0, maximum=1.1, value=1.003, step=0.001)
+ with gr.Row():
+ a_prompt = gr.Textbox(label="Default Positive Prompt",
+ value='Cinematic, High Contrast, highly detailed, taken using a Canon EOS R '
+ 'camera, hyper detailed photo - realistic maximum detail, 32k, Color '
+ 'Grading, ultra HD, extreme meticulous detailing, skin pore detailing, '
+ 'hyper sharpness, perfect without deformations.')
+ n_prompt = gr.Textbox(label="Default Negative Prompt",
+ value='painting, oil painting, illustration, drawing, art, sketch, oil painting, '
+ 'cartoon, CG Style, 3D render, unreal engine, blurring, dirty, messy, '
+ 'worst quality, low quality, frames, watermark, signature, jpeg artifacts, '
+ 'deformed, lowres, over-smooth')
+
+ with gr.Column():
+ gr.Markdown("Upscaled Images Output ")
+ if not args.use_image_slider:
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
+ else:
+ result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
+ with gr.Row():
+ with gr.Column():
+ denoise_button = gr.Button(value="Stage1 Run")
+ with gr.Column():
+                llava_button = gr.Button(value="LLaVA Run")
+ with gr.Column():
+ diffusion_button = gr.Button(value="Stage2 Run")
+ with gr.Row():
+ with gr.Column():
+                batch_process_folder = gr.Textbox(
+                    label="Batch Processing Input Folder Path - if an image_file_name.txt file exists, it is read and used as the prompt for that image (optional). Batch runs use the same settings as a single Stage2 Run. If no caption .txt exists, the Prompt entered above is used, and it may be empty.",
+                    placeholder="e.g. /workspace/SUPIR_video/comparison_images"
+                )
+                batch_outputs_folder = gr.Textbox(
+                    label="Batch Processing Output Folder Path - if left empty, images are saved in the default outputs folder",
+                    placeholder="e.g. /workspace/SUPIR_video/comparison_images/outputs",
+                    value=outputs_folder_arg
+                )
+ with gr.Row():
+ with gr.Column():
+ batch_upscale_button = gr.Button(value="Start Batch Upscaling")
+ batch_output_label = gr.Label("Batch Processing Progress")
+ with gr.Row():
+ with gr.Column():
+ param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
+ value="Quality")
+ with gr.Column():
+ restart_button = gr.Button(value="Reset Param", scale=2)
with gr.Row():
with gr.Column():
linear_CFG = gr.Checkbox(label="Linear CFG", value=True)
spt_linear_CFG = gr.Slider(label="CFG Start", minimum=1.0,
- maximum=9.0, value=default_setting.spt_linear_CFG_Quality, step=0.5)
+ maximum=9.0, value=default_setting.spt_linear_CFG_Quality, step=0.5)
with gr.Column():
linear_s_stage2 = gr.Checkbox(label="Linear Stage2 Guidance", value=False)
spt_linear_s_stage2 = gr.Slider(label="Guidance Start", minimum=0.,
@@ -271,44 +359,91 @@ def submit_feedback(event_id, fb_score, fb_text):
with gr.Column():
model_select = gr.Radio(["v0-Q", "v0-F"], label="Model Selection", value="v0-Q",
interactive=True)
+ with gr.Accordion("LLaVA options", open=False):
+ temperature = gr.Slider(label="Temperature", minimum=0., maximum=1.0, value=0.2, step=0.1)
+ top_p = gr.Slider(label="Top P", minimum=0., maximum=1.0, value=0.7, step=0.1)
+ qs = gr.Textbox(label="Question", value="Describe this image and its style in a very detailed manner. "
+                                  "The image is a realistic photograph, not an art painting.")
+ with gr.Accordion("Feedback", open=False):
+ fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
+ interactive=True)
+ fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
+ submit_button = gr.Button(value="Submit Feedback")
+ with gr.Row():
+ gr.Markdown(claim_md)
+ event_id = gr.Textbox(label="Event ID", value="", visible=False)
- with gr.Column():
- gr.Markdown("Stage2 Output ")
- if not args.use_image_slider:
- result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery1")
- else:
- result_gallery = ImageSlider(label='Output', show_label=False, elem_id="gallery1")
- with gr.Row():
- with gr.Column():
- denoise_button = gr.Button(value="Stage1 Run")
- with gr.Column():
- llave_button = gr.Button(value="LlaVa Run")
- with gr.Column():
- diffusion_button = gr.Button(value="Stage2 Run")
- with gr.Row():
- with gr.Column():
- param_setting = gr.Dropdown(["Quality", "Fidelity"], interactive=True, label="Param Setting",
- value="Quality")
- with gr.Column():
- restart_button = gr.Button(value="Reset Param", scale=2)
- with gr.Accordion("Feedback", open=True):
- fb_score = gr.Slider(label="Feedback Score", minimum=1, maximum=5, value=3, step=1,
- interactive=True)
- fb_text = gr.Textbox(label="Feedback Text", value="", placeholder='Please enter your feedback here.')
- submit_button = gr.Button(value="Submit Feedback")
- with gr.Row():
- gr.Markdown(claim_md)
- event_id = gr.Textbox(label="Event ID", value="", visible=False)
-
- llave_button.click(fn=llave_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
- denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction],
- outputs=[denoise_image])
- stage2_ips = [input_image, prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
- s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
- linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select]
- diffusion_button.click(fn=stage2_process, inputs=stage2_ips, outputs=[result_gallery, event_id, fb_score, fb_text])
- restart_button.click(fn=load_and_reset, inputs=[param_setting],
- outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
- color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
- submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
-block.launch(server_name=server_ip, server_port=server_port)
+ llava_button.click(fn=llava_process, inputs=[denoise_image, temperature, top_p, qs], outputs=[prompt])
+ denoise_button.click(fn=stage1_process, inputs=[input_image, gamma_correction], outputs=[denoise_image])
+ stage_2_common_inputs = [prompt, a_prompt, n_prompt, num_samples, upscale, edm_steps, s_stage1, s_stage2,
+ s_cfg, seed, s_churn, s_noise, color_fix_type, diff_dtype, ae_dtype, gamma_correction,
+ linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2, model_select, num_images,
+ random_seed]
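+        # The same settings list drives the single-image and batch handlers; the batch
+        # handler prepends the input/output folder paths instead of an input image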
+ stage2_inputs = [input_image] + stage_2_common_inputs
+ diffusion_button.click(fn=stage2_process, inputs=stage2_inputs,
+ outputs=[result_gallery, event_id, fb_score, fb_text, seed],
+ show_progress=True,
+ queue=True)
+ restart_button.click(fn=load_and_reset, inputs=[param_setting],
+ outputs=[edm_steps, s_cfg, s_stage2, s_stage1, s_churn, s_noise, a_prompt, n_prompt,
+ color_fix_type, linear_CFG, linear_s_stage2, spt_linear_CFG, spt_linear_s_stage2])
+ submit_button.click(fn=submit_feedback, inputs=[event_id, fb_score, fb_text], outputs=[fb_text])
+ stage2_inputs_batch = [batch_process_folder, batch_outputs_folder] + stage_2_common_inputs
+ batch_upscale_button.click(fn=batch_upscale, inputs=stage2_inputs_batch, outputs=batch_output_label, show_progress=True,
+ queue=True)
+ interface.launch(**launch_kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--opt", type=str, default='options/SUPIR_v0.yaml')
+ parser.add_argument("--ip", type=str, default='127.0.0.1')
+    parser.add_argument("--share", action='store_true', default=False)
+ parser.add_argument("--port", type=int, default=7860)
+ parser.add_argument("--no_llava", action='store_true', default=False)
+ parser.add_argument("--use_image_slider", action='store_true', default=False)
+ parser.add_argument("--log_history", action='store_true', default=False)
+ parser.add_argument("--loading_half_params", action='store_true', default=False)
+ parser.add_argument("--use_tile_vae", action='store_true', default=False)
+ parser.add_argument("--encoder_tile_size", type=int, default=512)
+ parser.add_argument("--decoder_tile_size", type=int, default=64)
+ parser.add_argument("--load_8bit_llava", action='store_true', default=False)
+ parser.add_argument("--outputs_folder", type=str, default='outputs')
+ args = parser.parse_args()
+ use_llava = not args.no_llava
+ outputs_folder_arg = args.outputs_folder
+
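+    # Run SUPIR and LLaVA on separate GPUs when two are available; otherwise share cuda:0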
+ if torch.cuda.device_count() >= 2:
+ SUPIR_device = 'cuda:0'
+ LLaVA_device = 'cuda:1'
+ elif torch.cuda.device_count() == 1:
+ SUPIR_device = 'cuda:0'
+ LLaVA_device = 'cuda:0'
+ else:
+ raise ValueError('Only CUDA is currently supported.')
+
+ # load SUPIR
+ model, default_setting = create_SUPIR_model(args.opt, SUPIR_sign='Q', load_default_setting=True)
+ if args.loading_half_params:
+ model = model.half()
+ if args.use_tile_vae:
+ model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
+ model = model.to(SUPIR_device)
+ model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
+ model.current_model = 'v0-Q'
+ ckpt_Q, ckpt_F = load_QF_ckpt(args.opt)
+
+ # load LLaVA
+ if use_llava:
+ llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
+ else:
+ llava_agent = None
+
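+    # inbrowser=True opens a browser tab automatically once the server is up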
+ launch_kwargs = {
+ "server_name": args.ip,
+ "server_port": args.port,
+ "share": args.share,
+ "inbrowser": True
+ }
+
+ launch_ui(launch_kwargs)
diff --git a/options/SUPIR_v0.yaml b/options/SUPIR_v0.yaml
index 99d76e5..dddbe6c 100644
--- a/options/SUPIR_v0.yaml
+++ b/options/SUPIR_v0.yaml
@@ -141,17 +141,17 @@ model:
scale_min: 4.0
p_p:
- 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera,
- hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing,
- skin pore detailing, hyper sharpness, perfect without deformations.'
+ 'Cinematic, High Contrast, highly detailed, taken using a Canon EOS R camera,
+ hyper detailed photo - realistic maximum detail, 32k, Color Grading, ultra HD, extreme meticulous detailing,
+ skin pore detailing, hyper sharpness, perfect without deformations.'
n_p:
- 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render,
- unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
+ 'painting, oil painting, illustration, drawing, art, sketch, oil painting, cartoon, CG Style, 3D render,
+ unreal engine, blurring, dirty, messy, worst quality, low quality, frames, watermark, signature,
jpeg artifacts, deformed, lowres, over-smooth'
-SDXL_CKPT: /opt/data/private/AIGC_pretrain/SDXL_cache/sd_xl_base_1.0_0.9vae.safetensors
-SUPIR_CKPT_F: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0F.ckpt
-SUPIR_CKPT_Q: /opt/data/private/AIGC_pretrain/SUPIR_cache/SUPIR-v0Q.ckpt
+SDXL_CKPT: models/sd_xl_base_1.0_0.9vae.safetensors
+SUPIR_CKPT_F: models/SUPIR-v0F.ckpt
+SUPIR_CKPT_Q: models/SUPIR-v0Q.ckpt
SUPIR_CKPT: ~
default_setting:
diff --git a/requirements.txt b/requirements.txt
index 11e74eb..84b00ce 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,14 +1,10 @@
-fastapi==0.95.1
-gradio==4.16.0
-gradio_imageslider==0.0.17
-gradio_client==0.8.1
Markdown==3.4.1
numpy==1.24.2
requests==2.28.2
sentencepiece==0.1.98
tokenizers==0.13.3
-torch>=2.1.0
-torchvision>=0.16.0
+torch
+torchvision
uvicorn==0.21.1
wandb==0.14.0
httpx==0.24.0
@@ -20,7 +16,7 @@ einops==0.7.0
einops-exts==0.0.4
timm==0.9.8
openai-clip==1.0.1
-fsspec==2023.4.0
+fsspec
kornia==0.6.9
matplotlib==3.7.1
ninja==1.11.1
@@ -33,10 +29,14 @@ pytorch-lightning==2.1.2
PyYAML==6.0
scipy==1.9.1
tqdm==4.65.0
-triton==2.1.0
urllib3==1.26.15
webdataset==0.2.48
xformers>=0.0.20
facexlib==0.3.0
k-diffusion==0.1.1.post1
diffusers==0.16.1
+fastapi
+bitsandbytes
+gradio==4.16.0
+gradio_imageslider==0.0.17
+gradio_client==0.8.1