PyTorch code for the ECCV 2022 paper "StoryDALL-E: Adapting Pretrained Text-to-Image Transformers for Story Continuation".
[Paper] [Model Card] [Spaces Demo] [Replicate Demo]
Download the PororoSV dataset and associated files from here (updated) and save it as ./data/pororo/.
Download the FlintstonesSV dataset and associated files from here and save it as ./data/flintstones
Download the DiDeMoSV dataset and associated files from here and save it as ./data/didemo
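Before training, it can help to confirm that the datasets landed in the locations listed above. The following is a minimal sketch, not part of the repository: the function name `missing_datasets` and the dictionary keys are hypothetical labels chosen for illustration, and only the three directory paths come from the instructions above.

```python
from pathlib import Path
from typing import List

# Directory paths as given in the download instructions; the keys are
# illustrative labels, not names defined by the repository.
DATASET_DIRS = {
    "pororo": Path("./data/pororo"),
    "flintstones": Path("./data/flintstones"),
    "didemo": Path("./data/didemo"),
}


def missing_datasets(root: Path = Path(".")) -> List[str]:
    """Return the labels whose directory is absent or empty under root."""
    missing = []
    for name, rel in DATASET_DIRS.items():
        path = root / rel
        # A directory that exists but holds no files counts as missing.
        if not path.is_dir() or not any(path.iterdir()):
            missing.append(name)
    return missing


if __name__ == "__main__":
    for name in missing_datasets():
        print(f"missing or empty: {DATASET_DIRS[name]}")
```

Run it from the repository root. It only checks that each directory exists and is non-empty, not that the files inside are the expected ones.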
This repository contains separate folders for training StoryDALL-E based on the minDALL-E and DALL-E Mega models, i.e. the ./story-dalle/ and ./mega-story-dalle/ folders respectively.
- To finetune the minDALL-E model for story continuation, first migrate to the corresponding folder:
  cd story-dalle
- Set the environment variables in train_story.sh to point to the right locations in your system. Specifically, change $DATA_DIR, $OUTPUT_ROOT and $LOG_DIR if different from the default locations.
- Download the pretrained checkpoint from here and save it in ./1.3B
- Run the following command:
  bash train_story.sh <dataset_name>
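The minDALL-E finetuning steps above can also be driven from Python, with a check for the two prerequisites first. This is a hedged sketch rather than anything shipped with the repository: `build_train_command`, `preflight` and `launch` are hypothetical helper names, and it assumes that ./1.3B is resolved relative to the story-dalle folder and that train_story.sh has already been edited with the right paths.

```python
import subprocess
from pathlib import Path
from typing import List


def build_train_command(dataset_name: str) -> List[str]:
    """Return the training command from the steps above as an argument list."""
    return ["bash", "train_story.sh", dataset_name]


def preflight(folder: Path) -> List[str]:
    """Collect problems that would make the run fail right away."""
    problems = []
    if not (folder / "train_story.sh").is_file():
        problems.append("train_story.sh not found")
    # Assumption: the checkpoint folder sits inside the training folder.
    if not (folder / "1.3B").is_dir():
        problems.append("pretrained checkpoint folder 1.3B not found")
    return problems


def launch(dataset_name: str, folder: Path = Path("story-dalle")) -> None:
    """Run the training script from inside the given folder."""
    problems = preflight(folder)
    if problems:
        raise FileNotFoundError("; ".join(problems))
    # cwd plays the role of the `cd story-dalle` step.
    subprocess.run(build_train_command(dataset_name), cwd=folder, check=True)
```

Calling `launch("<dataset_name>")` from the repository root is then equivalent to the `cd` and `bash` steps, with an early error if the script or checkpoint folder is missing.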
- To finetune the DALL-E Mega model for story continuation, first migrate to the corresponding folder:
  cd mega-story-dalle
- Set the environment variables in train_story.sh to point to the right locations in your system. Specifically, change $DATA_DIR, $OUTPUT_ROOT and $LOG_DIR if different from the default locations.
- Pretrained checkpoints for the generative model and the VQGAN detokenizer are automatically downloaded upon initialization. Download the pretrained weights for the VQGAN tokenizer from here and place them in the same folder as the VQGAN detokenizer.
- Run the following command:
  bash train_story.sh <dataset_name>
Pretrained checkpoints for minDALL-E based StoryDALL-E can be downloaded from here: PororoSV
For a demo of inference using cog, check out this repo.
- To infer from the minDALL-E model for story continuation, first migrate to the corresponding folder:
  cd story-dalle
- Set the environment variables in infer_story.sh to point to the right locations in your system. Specifically, change $DATA_DIR, $OUTPUT_ROOT and $MODEL_CKPT if different from the default locations.
- Run the following command:
  bash infer_story.sh <dataset_name>
For double-precision inference, the StoryDALL-E model requires nearly 40 GB of memory. This can be reduced to 20 GB by performing mixed-precision inference in the autoregressive decoder (included in the codebase; see line 1095 in story-dalle/dalle/models/__init__.py). Note that the VQGAN model needs to operate at full precision to retain the quality of the generated images.
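The split described above, reduced precision for the decoder and full precision for the VQGAN, can be illustrated with PyTorch's `torch.autocast`. This is only a sketch under stated assumptions and not the code at the referenced line: the two `nn.Linear` layers are hypothetical stand-ins for the real modules, and the repository may implement mixed precision differently.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins; the real StoryDALL-E modules are far larger.
decoder = nn.Linear(16, 16)  # plays the role of the autoregressive decoder
vqgan = nn.Linear(16, 3)     # plays the role of the VQGAN detokenizer

device_type = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is the usual choice on GPU; CPU autocast works with bfloat16.
low_precision = torch.float16 if device_type == "cuda" else torch.bfloat16
decoder.to(device_type)
vqgan.to(device_type)

x = torch.randn(2, 16, device=device_type)

with torch.no_grad():
    # Only the decoder stand-in runs inside the reduced-precision context.
    with torch.autocast(device_type=device_type, dtype=low_precision):
        hidden = decoder(x)
    # Outside autocast: cast back so the VQGAN stand-in runs in float32.
    image = vqgan(hidden.float())

decoder_dtype = hidden.dtype
image_dtype = image.dtype
print(decoder_dtype, image_dtype)
```

The explicit `.float()` cast is the important step: without it, the reduced-precision activations would be fed to a module that is meant to stay at full precision.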
Thanks to the fantastic folks at Kakao Brain and HuggingFace for their work on open-sourced versions of minDALL-E and DALL-E Mega. Much of this codebase has been adapted from here and here.