add mplug-owl-2 training code
haoning.wu committed Dec 13, 2023
1 parent 274e415 commit ab78904
Showing 6 changed files with 190 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -329,7 +329,8 @@ See [Model Zoo](model_zoo). Both **huggingface** and **modelscope** weights are

## Training

At present, we only provide the training scripts with LLaVA-v1.5 (7B/13B). Please see [Training Docs](scripts/llava_v1.5) for more details.
- [Training Docs for LLaVA-v1.5](scripts/llava_v1.5)
- [Training Docs for mPLUG-Owl-2](scripts/mplug_owl_2)

## License

2 changes: 1 addition & 1 deletion scripts/llava_v1.5/README.md
@@ -1,4 +1,4 @@
## Training with Q-Instruct
## Training@LLaVA-v1.5

This document provides instructions on how to train with the **Q-Instruct** dataset on LLaVA-v1.5 (7B/13B), under the two proposed strategies (***mix*** and ***after***), shown as follows.

117 changes: 117 additions & 0 deletions scripts/mplug_owl_2/README.md
@@ -0,0 +1,117 @@
## Training@mPLUG-Owl-2

This document provides instructions on how to train with the **Q-Instruct** dataset on mPLUG-Owl-2 (LLaMA2-7B), under the two proposed strategies (***mix*** and ***after***), shown as follows.

![](strategies.png)

*Due to copyright issues, the pre-trained weights of mPLUG-Owl-2 are not available at present. Therefore, the open-source mix strategy is currently slightly different: we fine-tune the SFT checkpoint of mPLUG-Owl-2 on a mix of **high-level datasets** and **Q-Instruct**.*

### Step 0: Pre-requisites

Install [mPLUG-Owl](https://github.com/X-PLUG/mPLUG-Owl/) under the main repository, with flash attention for more efficient training.

```shell
git clone https://github.com/X-PLUG/mPLUG-Owl.git
cd mPLUG-Owl/mPLUG_Owl2
pip install -e ".[train]"
pip install flash_attn --no-build-isolation
cd ../..
```
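
A quick sanity check that the install succeeded (a sketch; it assumes the editable install exposes the package as `mplug_owl2`, the module path used by the training scripts below):

```shell
python -c "import mplug_owl2, flash_attn; print('mplug_owl2 and flash_attn import OK')"
```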

After that, you can conduct *low-level visual instruction tuning* as follows, under either the ***mix*** or the ***after*** strategy.

### Step 1: Download Training Datasets


#### Download Q-Instruct

*Note: if you have already downloaded these files for LLaVA training, you may simply copy them here.*

For the **Q-Instruct** dataset, download the labels and images directly via the following script:

```shell
cd mPLUG-Owl/mPLUG_Owl2/playground/data
wget https://huggingface.co/datasets/teowu/Q-Instruct/resolve/main/cleaned_labels.json
wget https://huggingface.co/datasets/teowu/Q-Instruct/resolve/main/q-instruct-images.tar
tar -xf q-instruct-images.tar
rm -f q-instruct-images.tar
cd ../../../..
```

Make sure you have the following file structure under `mPLUG-Owl/mPLUG_Owl2/playground/data`.

```
├── spaq_koniq
├── livefb_liveitw_aigc
├── cleaned_labels.json
```
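
A minimal check, run from the repository root and assuming the paths above, that the images and labels sit where the training scripts expect them:

```shell
cd mPLUG-Owl/mPLUG_Owl2/playground/data
ls -d spaq_koniq livefb_liveitw_aigc   # image folders extracted from q-instruct-images.tar
ls -lh cleaned_labels.json             # Q-Instruct annotations
cd ../../../..
```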

#### Download Public High-level Instruction Tuning Datasets

*Note: if you have already downloaded these files for LLaVA training, you may simply copy them here.*

If you choose the ***mix*** strategy, the high-level datasets also need to be downloaded via the following steps:


1. Download the annotation of the final instruction tuning data mixture, [llava_v1_5_mix665k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json):

```shell
wget -P mPLUG-Owl/mPLUG_Owl2/playground/data https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/resolve/main/llava_v1_5_mix665k.json
```

2. Download the images from the constituent datasets:

- COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
- GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
- OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing), **we save all files as `.jpg`**
- TextVQA: [train_val_images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
- VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)

After downloading all of them, organize the high-level data as follows in `mPLUG-Owl/mPLUG_Owl2/playground/data` (a command-line sketch for these downloads is given after the merge step below),

```
├── coco
│ └── train2017
├── gqa
│ └── images
├── ocr_vqa
│ └── images
├── textvqa
│ └── train_images
└── vg
├── VG_100K
└── VG_100K_2
```

3. Merge the **Q-Instruct** labels with the labels from the high-level datasets.

```shell
jq -s 'add' mPLUG-Owl/mPLUG_Owl2/playground/data/cleaned_labels.json mPLUG-Owl/mPLUG_Owl2/playground/data/llava_v1_5_mix665k.json > mPLUG-Owl/mPLUG_Owl2/playground/data/mix_cleaned_labels.json
```
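
The image downloads in step 2 can also be scripted; a sketch (run from the repository root) is below. It assumes each archive unpacks into the directory names shown in the layout above, as in the standard LLaVA-v1.5 data setup; OCR-VQA still has to be fetched with its own download script and saved as `.jpg` files.

```shell
cd mPLUG-Owl/mPLUG_Owl2/playground/data
mkdir -p coco gqa textvqa vg ocr_vqa/images

# COCO train2017, GQA, TextVQA, and the two Visual Genome parts
wget http://images.cocodataset.org/zips/train2017.zip && unzip -q train2017.zip -d coco
wget https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip && unzip -q images.zip -d gqa
wget https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip && unzip -q train_val_images.zip -d textvqa
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O vg_part1.zip && unzip -q vg_part1.zip -d vg
wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O vg_part2.zip && unzip -q vg_part2.zip -d vg

rm -f train2017.zip images.zip train_val_images.zip vg_part1.zip vg_part2.zip
cd ../../../..
```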


### Step 2: Start Training

Please make sure you have enough computational resources before training.

- [Must Do!] Replace every `<image>` token in the JSON with `<|image|>`; otherwise the images will not be loaded during training.

```shell
sed -i 's/<image>/<|image|>/g' mPLUG-Owl/mPLUG_Owl2/playground/data/mix_cleaned_labels.json
```
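
If you train with the ***after*** strategy, the raw Q-Instruct labels presumably need the same substitution, since `after_q_instruct.sh` reads `cleaned_labels.json` directly:

```shell
sed -i 's/<image>/<|image|>/g' mPLUG-Owl/mPLUG_Owl2/playground/data/cleaned_labels.json
```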

#### Strategy (a): Mix with High-level Datasets

- Training *(requires 8x A100 80G, ~11h)*

```shell
sh scripts/mplug_owl_2/mix_q_instruct.sh
```

#### Strategy (b): After High-level Datasets

- Training *(requires 8x A100 80G, ~1.5h)*

```shell
sh scripts/mplug_owl_2/after_q_instruct.sh
```
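
Both scripts report metrics to TensorBoard (`--report_to tensorboard`). Assuming the default Hugging Face Trainer logging layout (event files under each run's output directory), something like the following lets you monitor the loss curves:

```shell
tensorboard --logdir ./checkpoints --port 6006
```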
35 changes: 35 additions & 0 deletions scripts/mplug_owl_2/after_q_instruct.sh
@@ -0,0 +1,35 @@
#!/bin/bash
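# Strategy (b) "after": low-level tuning on Q-Instruct only, starting from the released mPLUG-Owl-2 SFT checkpoint.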
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=mPLUG-Owl/mPLUG_Owl2/playground/data/cleaned_labels.json
deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
--data_path $DATA_FILE \
--image_folder mPLUG-Owl/mPLUG_Owl2/playground/data/ \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/mplug-owl2-finetune-after \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 800 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--tune_visual_abstractor True \
--freeze_vision_model False \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
35 changes: 35 additions & 0 deletions scripts/mplug_owl_2/mix_q_instruct.sh
@@ -0,0 +1,35 @@
#!/bin/bash
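# Strategy (a) "mix": tuning on Q-Instruct mixed with the LLaVA-v1.5 high-level data, starting from the released mPLUG-Owl-2 SFT checkpoint.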
LOAD='MAGAer13/mplug-owl2-llama2-7b'

DATA_FILE=mPLUG-Owl/mPLUG_Owl2/playground/data/mix_cleaned_labels.json
deepspeed --master_port 25801 mplug_owl2/train/train_mem.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path $LOAD \
--version v1 \
--data_path $DATA_FILE \
--image_folder mPLUG-Owl/mPLUG_Owl2/playground/data/ \
--image_aspect_ratio pad \
--group_by_modality_length True \
--bf16 True \
--output_dir ./checkpoints/mplug-owl2-finetune-mix \
--num_train_epochs 1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 800 \
--save_total_limit 1 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 2048 \
--gradient_checkpointing True \
--tune_visual_abstractor True \
--freeze_vision_model False \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
Binary file added scripts/mplug_owl_2/strategies.png