In-Context Learning Unlocked for Diffusion Models
Zhendong Wang, Yifan Jiang, Yadong Lu, Yelong Shen, Pengcheng He, Weizhu Chen, Zhangyang Wang and Mingyuan Zhou
Abstract: We present Prompt Diffusion, a framework for enabling in-context learning in diffusion-based generative models. Given a pair of task-specific example images, such as depth from/to image and scribble from/to image, and text guidance, our model automatically understands the underlying task and performs the same task on a new query image following the text guidance. To achieve this, we propose a vision-language prompt that can model a wide range of vision-language tasks and a diffusion model that takes it as input. The diffusion model is trained jointly on six different tasks using these prompts. The resulting Prompt Diffusion model is the first diffusion-based vision-language foundation model capable of in-context learning. It demonstrates high-quality in-context generation for the trained tasks and generalizes effectively to new, unseen vision tasks using their respective prompts. Our model also shows compelling text-guided image editing results. Our framework aims to facilitate research into in-context learning for computer vision, with code publicly available here.
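We thank iczaw for their contribution: Prompt Diffusion is now supported through the diffusers package. Follow the code below for a quick try. Note that PromptDiffusionControlNetModel and PromptDiffusionPipeline are imported from the standalone modules promptdiffusioncontrolnet and pipeline_prompt_diffusion, so those files must be importable from your working directory or Python path.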
import torch
from diffusers import DDIMScheduler, UniPCMultistepScheduler
from diffusers.utils import load_image
from promptdiffusioncontrolnet import PromptDiffusionControlNetModel
from pipeline_prompt_diffusion import PromptDiffusionPipeline
from PIL import ImageOps
image_a = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house_line.png?raw=true"))
image_b = load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/house.png?raw=true")
query = ImageOps.invert(load_image("https://github.com/Zhendong-Wang/Prompt-Diffusion/blob/main/images_to_try/new_01.png?raw=true"))
# load the Prompt Diffusion ControlNet and the Prompt Diffusion pipeline in half precision
controlnet = PromptDiffusionControlNetModel.from_pretrained("zhendongw/prompt-diffusion-diffusers", subfolder="controlnet", torch_dtype=torch.float16)
# pass torch_dtype to from_pretrained so every pipeline component is loaded in fp16, then move to the GPU
pipe = PromptDiffusionPipeline.from_pretrained("zhendongw/prompt-diffusion-diffusers", controlnet=controlnet, torch_dtype=torch.float16).to("cuda")
# speed up the diffusion process with a faster scheduler and memory optimizations
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
# alternatively, use the UniPC scheduler:
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# uncomment the following line if xformers is installed
# pipe.enable_xformers_memory_efficient_attention()
# uncomment to offload submodules to the CPU and save GPU memory (then drop the .to("cuda") call above)
# pipe.enable_model_cpu_offload()
# generate image
generator = torch.manual_seed(2023)
image = pipe("a tortoise", num_inference_steps=50, generator=generator, image_pair=[image_a,image_b], image=query).images[0]
image.save('./test.png')
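The call above packs a vision-language prompt into three pieces: the text guidance (passed positionally), the example pair (passed as image_pair), and the query image (passed as image). As a minimal sketch, the hypothetical VisionLanguagePrompt class below, which is not part of this repository or of diffusers, groups these pieces and returns the positional and keyword arguments used in that call.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple


@dataclass
class VisionLanguagePrompt:
    """Hypothetical container for one in-context prompt (illustration only)."""
    example_source: Any  # e.g. a line drawing or depth map
    example_target: Any  # the image the example source maps to
    text: str            # text guidance, e.g. "a tortoise"
    query: Any           # new input on which to perform the same task

    def pipeline_args(self) -> Tuple[Tuple[str], Dict[str, Any]]:
        """Return (args, kwargs) mirroring pipe("...", image_pair=[a, b], image=query)."""
        args = (self.text,)
        kwargs = {
            # order matters: (source, target) defines the direction of the task
            "image_pair": [self.example_source, self.example_target],
            "image": self.query,
        }
        return args, kwargs
```

With the objects from the quick-start code, this would be used as `args, kwargs = VisionLanguagePrompt(image_a, image_b, "a tortoise", query).pipeline_args()` followed by `pipe(*args, num_inference_steps=50, generator=generator, **kwargs)`. Keeping source and target as separate named fields makes the task direction explicit; swapping them would presumably ask the model for the inverse task.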
We use the public dataset released with InstructPix2Pix as our base dataset; it consists of around 310k image-caption pairs. We further apply the ControlNet annotators to extract image conditions such as HED, depth, and segmentation maps. The code for collecting these image conditions is provided in annotate_data.py.
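Combining the three condition types with the two directions (condition to image, image to condition) gives six tasks, which we assume correspond to the six training tasks mentioned in the abstract. The sketch below shows how one training prompt could be assembled under further assumptions: each annotated sample is a dict holding an image, its condition maps, and a caption; the example pair comes from one sample while the query and target come from another; and the caption of the second sample serves as the text guidance. The names CONDITIONS, TASKS, and build_training_prompt, and the dict keys, are hypothetical and not taken from annotate_data.py or the training code.

```python
import random
from typing import Dict, List, Tuple

# Three condition types x two directions = six tasks (assumed mapping).
CONDITIONS = ["hed", "depth", "seg"]
TASKS: List[Tuple[str, str]] = (
    [(c, "image") for c in CONDITIONS]    # condition -> image
    + [("image", c) for c in CONDITIONS]  # image -> condition
)


def build_training_prompt(sample_a: Dict[str, str], sample_b: Dict[str, str],
                          rng: random.Random) -> Dict[str, object]:
    """Pick one task and build an (example pair, query, target, text) prompt.

    The example pair comes from sample_a; the query and the target the
    model should produce come from sample_b, under the same task.
    """
    src, dst = rng.choice(TASKS)
    return {
        "task": f"{src}->{dst}",
        "image_pair": [sample_a[src], sample_a[dst]],
        "query": sample_b[src],
        "target": sample_b[dst],
        "text": sample_b["caption"],
    }
```

Sampling the task per prompt, rather than per dataset, is one simple way to train all six tasks jointly from a single annotated dataset.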
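Training Prompt Diffusion takes two commands: the first creates the initial checkpoint ./models/control_sd15_ini.ckpt from a Stable Diffusion checkpoint, and the second launches training: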
python tool_add_control.py 'path to your stable diffusion checkpoint, e.g., /.../v1-5-pruned-emaonly.ckpt' ./models/control_sd15_ini.ckpt
python train.py --name 'experiment name' --gpus=8 --num_nodes=1 \
--logdir 'your logdir path' \
--data_config './models/dataset.yaml' --base './models/cldm_v15.yaml' \
--sd_locked
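The first command is presumably adapted from ControlNet's tool_add_control.py, which builds the initial checkpoint by copying Stable Diffusion weights into the new control branch: each control_* parameter is initialized from the matching model.diffusion_* parameter when one exists, and any parameter without a pretrained counterpart keeps its fresh initialization. The sketch below illustrates only that key-matching rule on plain dictionaries (no tensors, no config loading); the function name init_control_weights is hypothetical, and the script in this repository may differ in details.

```python
from typing import Dict

CONTROL_PREFIX = "control_"


def init_control_weights(scratch: Dict[str, float],
                         pretrained: Dict[str, float]) -> Dict[str, float]:
    """Initialize every parameter in `scratch` from `pretrained` where possible.

    Keys such as "control_model.input_blocks.0.0.weight" are copied from
    "model.diffusion_model.input_blocks.0.0.weight"; other keys are copied
    from the identically named pretrained key. Keys with no pretrained
    counterpart (new layers) keep their scratch initialization.
    """
    target = {}
    for key, value in scratch.items():
        if key.startswith(CONTROL_PREFIX):
            # map the control-branch key onto the corresponding UNet key
            source_key = "model.diffusion_" + key[len(CONTROL_PREFIX):]
        else:
            source_key = key
        if source_key in pretrained:
            target[key] = pretrained[source_key]
        else:
            target[key] = value  # newly added weight, e.g. a zero convolution
    return target
```

Starting the control branch as a copy of the pretrained UNet encoder, rather than from random weights, is the ControlNet design choice that lets training begin from a model that already behaves like Stable Diffusion.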
We also provide the job script scripts/train_v1-5.sh for an easy run.
We release the model checkpoints we trained on our Hugging Face page; quick download access is here.
We provide a Jupyter notebook, run_prompt_diffusion.ipynb, for trying the inference code of Prompt Diffusion, along with a few sample images in the folder images_to_try. We are preparing a Gradio demo and will release it soon.
@article{wang2023promptdiffusion,
title = {In-Context Learning Unlocked for Diffusion Models},
author = {Wang, Zhendong and Jiang, Yifan and Lu, Yadong and Shen, Yelong and He, Pengcheng and Chen, Weizhu and Wang, Zhangyang and Zhou, Mingyuan},
journal = {arXiv preprint arXiv:2305.01115},
year = {2023},
url = {https://arxiv.org/abs/2305.01115}
}
We thank Brooks et al. for sharing the dataset for fine-tuning Stable Diffusion. We also thank Lvmin Zhang and Maneesh Agrawala for the excellent ControlNet code base.