# Enhancing Diffusion Models with Text-Encoder Reinforcement Learning

Official PyTorch code for the paper *Enhancing Diffusion Models with Text-Encoder Reinforcement Learning*.


## Requirements & Installation

- Clone the repo and install the required packages with:

```bash
# git clone this repository
git clone https://github.com/chaofengc/TexForce.git
cd TexForce

# create new anaconda env
conda create -n texforce python=3.8
conda activate texforce

# install python dependencies
pip3 install -r requirements.txt
```
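
After installation, an optional sanity check can confirm that the key packages import and that CUDA is visible. This is a minimal sketch; the exact versions you see will depend on `requirements.txt`:

```python
import torch
import diffusers

# print versions and confirm GPU availability
print("torch:", torch.__version__)
print("diffusers:", diffusers.__version__)
print("CUDA available:", torch.cuda.is_available())
```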

## Results on SDXL-Turbo

We also applied our method to the recent sdxl-turbo model. To save training time, the model is trained with ImageReward feedback through direct back-propagation. Test it with the following code:

```bash
# Note: sdxl-turbo requires the latest diffusers installed from source:
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

```python
from diffusers import AutoPipelineForText2Image
import torch

# load the base sdxl-turbo pipeline in half precision
pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
# apply the TexForce LoRA weights
pipe.load_lora_weights('chaofengc/sdxl-turbo_texforce')

pt = ['a photo of a cat.']
# sdxl-turbo is a one-step model; classifier-free guidance is disabled
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```

Here are some example results:

| Prompt | sdxl-turbo | sdxl-turbo + TexForce |
| --- | --- | --- |
| A photo of a cat. | *(image)* | *(image)* |
| An astronaut riding a horse. | *(image)* | *(image)* |
| water bottle. | *(image)* | *(image)* |

## Results on SD-Turbo

We applied our method to the recent sd-turbo model. To save training time, the model is trained with Q-Instruct feedback through direct back-propagation. Test it with the following code:

```python
# Note: sd-turbo requires diffusers>=0.24.0 for the AutoPipelineForText2Image class
from diffusers import AutoPipelineForText2Image
from peft import PeftModel
import torch

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
# inject the TexForce LoRA adapter into the text encoder (modifies pipe.text_encoder in place)
PeftModel.from_pretrained(pipe.text_encoder, 'chaofengc/sd-turbo_texforce')

pt = ['a photo of a cat.']
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```

Here are some example results:

| Prompt | sd-turbo | sd-turbo + TexForce |
| --- | --- | --- |
| A photo of a cat. | *(image)* | *(image)* |
| A photo of a dog. | *(image)* | *(image)* |
| A photo of a boy, colorful. | *(image)* | *(image)* |

## Results on SD-1.4, SD-1.5, SD-2.1

Due to code compatibility, you need to install the following diffusers version first:

```bash
pip uninstall diffusers
pip install diffusers==0.16.0
```

You can simply load the pretrained LoRA weights with the following code block to improve the performance of the original Stable Diffusion model:

```python
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from peft import PeftModel
import torch

def load_model_weights(pipe, weight_path, model_type):
    # 'text+lora': inject the TexForce LoRA adapter into the text encoder via PEFT
    # 'unet+lora': load LoRA attention processors (e.g., from ReFL) into the UNet
    if model_type == 'text+lora':
        text_encoder = pipe.text_encoder
        PeftModel.from_pretrained(text_encoder, weight_path)
    elif model_type == 'unet+lora':
        pipe.unet.load_attn_procs(weight_path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to(device)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

load_model_weights(pipe, './lora_weights/sd14_refl/', 'unet+lora')
load_model_weights(pipe, './lora_weights/sd14_texforce/', 'text+lora')

prompt = ['a painting of a dog.']
img = pipe(prompt).images[0]
```
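
For reproducible comparisons between the base model and the LoRA-enhanced pipeline, you can fix the random seed with a `torch.Generator`. This is a minimal sketch continuing from the snippet above; the seed value and output filename are arbitrary:

```python
# fix the seed so base and TexForce outputs are directly comparable
generator = torch.Generator(device=device).manual_seed(42)
img = pipe(prompt, generator=generator).images[0]
img.save('dog_texforce.png')
```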

Here are some example results:

| Prompt | SDv1.4 | ReFL | TexForce | ReFL+TexForce |
| --- | --- | --- | --- | --- |
| astronaut drifting afloat in space, in the darkness away from anyone else, alone, black background dotted with stars, realistic | *(image)* | *(image)* | *(image)* | *(image)* |
| portrait of a cute cyberpunk cat, realistic, professional | *(image)* | *(image)* | *(image)* | *(image)* |
| a coffee mug made of cardboard | *(image)* | *(image)* | *(image)* | *(image)* |

## Training

We rewrote the training code based on trl with the latest diffusers library.

> [!NOTE]
> The latest diffusers supports simple loading of LoRA weights with `pipeline.load_lora_weights` after training.
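
For example, a trained checkpoint can be attached to a fresh pipeline in one line. This is a minimal sketch, assuming your training run saved LoRA weights in a diffusers-compatible format; the checkpoint path below is a placeholder:

```python
from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
# placeholder: directory produced by your own training run
pipe.load_lora_weights("path/to/trained_lora")
```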

You can train the model with the following commands:

**Example script for single-prompt training**

```bash
accelerate launch --num_processes 2 src/train_ddpo.py \
    --mixed_precision="fp16" \
    --sample_num_steps 50 --train_timestep_fraction 0.5 \
    --num_epochs 40 \
    --sample_batch_size 4 --sample_num_batches_per_epoch 64 \
    --train_batch_size 4 --train_gradient_accumulation_steps 1 \
    --prompt="single" --single_prompt_type="hand" --reward_list="handdetreward" \
    --per_prompt_stat_tracking=True \
    --tracker_project_name="texforce_hand" \
    --log_with="tensorboard"
```
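
Since the script logs with `--log_with="tensorboard"`, training curves can be inspected with TensorBoard. A minimal sketch, assuming the logs land under the tracker project directory configured by your run (the path below is a placeholder):

```bash
tensorboard --logdir path/to/logs/texforce_hand
```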

The supported prompts and reward functions are listed below:

- prompts: `hand`, `face`, `color`, `count`, `comp`, `location`
- rewards: `handdetreward`, `topiq_nr-face`, `imagereward`

**Example script for complex multi-prompt training**

```bash
accelerate launch --num_processes 2 src/train_ddpo.py \
    --mixed_precision="fp16" \
    --sample_num_steps 50 --train_timestep_fraction 0.5 \
    --num_epochs 50 \
    --sample_batch_size 4 --sample_num_batches_per_epoch 128 \
    --train_batch_size 4 --train_gradient_accumulation_steps 4 \
    --prompt="imagereward" --reward_list="imagereward" \
    --per_prompt_stat_tracking=True \
    --tracker_project_name="texforce_imgreward" \
    --log_with="tensorboard"
```

The supported prompts and reward functions are:

- prompts: `imagereward`, `hps`
- rewards: `imagereward`, `hpsreward`, `laion_aes` (see the scoring sketch below)
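
Most of these rewards wrap external scorers. For instance, the `imagereward` option corresponds to the ImageReward model, which scores how well an image matches its prompt. Below is a minimal sketch of standalone scoring with the `ImageReward` package (this illustrates the scorer itself, not this repo's internal reward interface; the image path is a placeholder):

```python
import ImageReward as RM

# load the pretrained ImageReward model (downloads weights on first use)
model = RM.load("ImageReward-v1.0")

# higher score indicates better text-image alignment
score = model.score("a photo of a cat.", ["path/to/image.png"])
print(score)
```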

## Citation

If you find this code useful for your research, please cite our paper:

```bibtex
@inproceedings{chen2024texforce,
  title={Enhancing Diffusion Models with Text-Encoder Reinforcement Learning},
  author={Chaofeng Chen and Annan Wang and Haoning Wu and Liang Liao and Wenxiu Sun and Qiong Yan and Weisi Lin},
  year={2024},
  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
}
```

## License

This work is licensed under NTU S-Lab License 1.0 and a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.


## Acknowledgement

This project is largely based on trl. The hand detection code is taken from Unified-Gesture-and-Fingertip-Detection. Many thanks for their great work 🤗!