Merge pull request #27 from invoke-ai/ryan/dreambooth-prep
Preparation for DreamBooth
RyanJDick authored Aug 30, 2023
2 parents 90bf97f + cb564d9 commit 9ed39d7
Showing 22 changed files with 353 additions and 150 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -6,10 +6,12 @@ A library for training custom Stable Diffusion models (fine-tuning, LoRA training...

## Training Modes

There are currently two supported training modes:
- Finetune with LoRA
  - Stable Diffusion v1/v2: `invoke-finetune-lora-sd`
  - Stable Diffusion XL: `invoke-finetune-lora-sdxl`
- DreamBooth with LoRA
  - Stable Diffusion v1/v2: `invoke-dreambooth-lora-sd`
  - Stable Diffusion XL: `invoke-dreambooth-lora-sdxl`
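Each DreamBooth command takes a `--cfg-file` argument pointing at a YAML training config, e.g. `invoke-dreambooth-lora-sd --cfg-file path/to/config.yaml` (hypothetical path; the flag is defined in the scripts added below).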

More training modes will be added soon.

3 changes: 2 additions & 1 deletion configs/finetune_lora_sd_pokemon_1x8gb_example.yaml
@@ -21,7 +21,8 @@ optimizer:

dataset:
  dataset_name: lambdalabs/pokemon-blip-captions
-  resolution: 512
+  image_transforms:
+    resolution: 512

# General
model: runwayml/stable-diffusion-v1-5
3 changes: 2 additions & 1 deletion configs/finetune_lora_sdxl_pokemon_1x24gb_example.yaml
@@ -21,7 +21,8 @@ optimizer:

dataset:
  dataset_name: lambdalabs/pokemon-blip-captions
-  resolution: 512
+  image_transforms:
+    resolution: 512

# General
model: stabilityai/stable-diffusion-xl-base-1.0
3 changes: 2 additions & 1 deletion configs/finetune_lora_sdxl_pokemon_1x8gb_example.yaml
@@ -22,7 +22,8 @@ optimizer:

dataset:
  dataset_name: lambdalabs/pokemon-blip-captions
-  resolution: 512
+  image_transforms:
+    resolution: 512

# General
model: stabilityai/stable-diffusion-xl-base-1.0
4 changes: 2 additions & 2 deletions docs/dataset_formats.md
@@ -21,7 +21,7 @@ dataset:
# ...
```

-See [lora_training_config.py](/src/invoke_training/training/lora/lora_training_config.py) for full documentation of the `DatasetConfig`.
+See [data_config.py](/src/invoke_training/training/config/data_config.py) for full documentation of the `ImageCaptionDatasetConfig`.

### ImageFolder Datasets
If you want to create custom datasets, then you will most likely want to use the [ImageFolder](https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder) dataset format.
@@ -57,4 +57,4 @@ dataset:
# ...
```

-See [lora_training_config.py](/src/invoke_training/training/lora/lora_training_config.py) for full documentation of the `DatasetConfig`.
+See [data_config.py](/src/invoke_training/training/config/data_config.py) for full documentation of the `ImageCaptionDatasetConfig`.
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -46,6 +46,8 @@ dependencies = [
[project.scripts]
"invoke-finetune-lora-sd" = "invoke_training.scripts.invoke_finetune_lora_sd:main"
"invoke-finetune-lora-sdxl" = "invoke_training.scripts.invoke_finetune_lora_sdxl:main"
"invoke-dreambooth-lora-sd" = "invoke_training.scripts.invoke_dreambooth_lora_sd:main"
"invoke-dreambooth-lora-sdxl" = "invoke_training.scripts.invoke_dreambooth_lora_sdxl:main"

[project.urls]
"Homepage" = "https://github.com/invoke-ai/invoke-training"
36 changes: 36 additions & 0 deletions src/invoke_training/scripts/invoke_dreambooth_lora_sd.py
@@ -0,0 +1,36 @@
import argparse
from pathlib import Path

import yaml

from invoke_training.training.config.finetune_lora_config import DreamBoothLoRAConfig
from invoke_training.training.dreambooth_lora.dreambooth_lora_sd import run_training


def parse_args():
    parser = argparse.ArgumentParser(
        description="DreamBooth training with LoRA for Stable Diffusion v1 and v2 base models."
    )
    parser.add_argument(
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. See `DreamBoothLoRAConfig` for the supported fields.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    train_config = DreamBoothLoRAConfig(**cfg)

    run_training(train_config)


if __name__ == "__main__":
    main()
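Note that `yaml.safe_load` returns a plain dict, so `DreamBoothLoRAConfig(**cfg)` hands the YAML's top-level keys straight to pydantic for validation; the config file's structure must therefore mirror the fields of `DreamBoothLoRAConfig` (defined in `finetune_lora_config.py` below), such as `output`, `optimizer`, `instance_dataset`, and `instance_prompt`.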
36 changes: 36 additions & 0 deletions src/invoke_training/scripts/invoke_dreambooth_lora_sdxl.py
@@ -0,0 +1,36 @@
import argparse
from pathlib import Path

import yaml

from invoke_training.training.config.finetune_lora_config import (
    DreamBoothLoRASDXLConfig,
)
from invoke_training.training.dreambooth_lora.dreambooth_lora_sdxl import run_training


def parse_args():
    parser = argparse.ArgumentParser(description="DreamBooth training with LoRA for Stable Diffusion XL base models.")
    parser.add_argument(
        "--cfg-file",
        type=Path,
        required=True,
        help="Path to the YAML training config file. See `DreamBoothLoRASDXLConfig` for the supported fields.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Load YAML config file.
    with open(args.cfg_file, "r") as f:
        cfg = yaml.safe_load(f)

    train_config = DreamBoothLoRASDXLConfig(**cfg)

    run_training(train_config)


if __name__ == "__main__":
    main()
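Apart from the config class and the SDXL-specific `run_training` import, this script mirrors the SD version above.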
57 changes: 57 additions & 0 deletions src/invoke_training/training/config/data_config.py
@@ -0,0 +1,57 @@
import typing

from pydantic import BaseModel


class ImageTransformConfig(BaseModel):
    # The resolution for input images. All of the images in the dataset will be resized to this (square) resolution.
    resolution: int = 512

    # If True, input images will be center-cropped to resolution.
    # If False, input images will be randomly cropped to resolution.
    center_crop: bool = False

    # Whether random flip augmentations should be applied to input images.
    random_flip: bool = False


class ImageCaptionDatasetConfig(BaseModel):
    # The name of a Hugging Face dataset.
    # One of dataset_name and dataset_dir should be set (dataset_name takes precedence).
    # See also: dataset_config_name.
    dataset_name: typing.Optional[str] = None

    # The directory to load a dataset from. The dataset is expected to be in
    # Hugging Face imagefolder format (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).
    # One of 'dataset_name' and 'dataset_dir' should be set ('dataset_name' takes precedence).
    dataset_dir: typing.Optional[str] = None

    # The Hugging Face dataset config name. Leave as None if there's only one config.
    # This parameter is only used if dataset_name is set.
    dataset_config_name: typing.Optional[str] = None

    # The Hugging Face cache directory to use for dataset downloads.
    # If None, the default value will be used (usually '~/.cache/huggingface/datasets').
    hf_cache_dir: typing.Optional[str] = None

    # The name of the dataset column that contains image paths.
    image_column: str = "image"

    # The name of the dataset column that contains captions.
    caption_column: str = "text"

    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
    dataloader_num_workers: int = 0

    image_transforms: ImageTransformConfig


class ImageDirDatasetConfig(BaseModel):
    # The directory to load images from.
    dataset_dir: str

    # The image file extensions to include in the dataset.
    # If None, then the following file extensions will be loaded: [".png", ".jpg", ".jpeg"].
    image_file_extensions: typing.Optional[list[str]] = None

    image_transforms: ImageTransformConfig
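As an aside (not part of this diff), a minimal sketch of how these configs compose; the dataset name matches the example configs above, while the directory path is hypothetical:

```python
from invoke_training.training.config.data_config import (
    ImageCaptionDatasetConfig,
    ImageDirDatasetConfig,
    ImageTransformConfig,
)

# `resolution` now nests under `image_transforms` (matching the YAML changes above).
transforms = ImageTransformConfig(resolution=512, center_crop=False, random_flip=True)

# Caption-based dataset, as used by the finetune scripts.
caption_dataset = ImageCaptionDatasetConfig(
    dataset_name="lambdalabs/pokemon-blip-captions",
    image_transforms=transforms,
)

# Directory-based dataset, as used for DreamBooth instance images.
instance_dataset = ImageDirDatasetConfig(
    dataset_dir="/data/my_subject",  # hypothetical directory
    image_transforms=transforms,
)
```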
92 changes: 41 additions & 51 deletions src/invoke_training/training/config/finetune_lora_config.py
@@ -2,6 +2,10 @@

from pydantic import BaseModel

+from invoke_training.training.config.data_config import (
+    ImageCaptionDatasetConfig,
+    ImageDirDatasetConfig,
+)
from invoke_training.training.config.optimizer_config import OptimizerConfig


@@ -24,57 +28,8 @@ class TrainingOutputConfig(BaseModel):
    save_model_as: typing.Literal["ckpt", "pt", "safetensors"] = "safetensors"


-class DatasetConfig(BaseModel):
-    # The name of a Hugging Face dataset.
-    # One of dataset_name and dataset_dir should be set (dataset_name takes precedence).
-    # See also: dataset_config_name.
-    dataset_name: typing.Optional[str] = None
-
-    # The directory to load a dataset from. The dataset is expected to be in
-    # Hugging Face imagefolder format (https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder).
-    # One of 'dataset_name' and 'dataset_dir' should be set ('dataset_name' takes precedence).
-    dataset_dir: typing.Optional[str] = None
-
-    # The Hugging Face dataset config name. Leave as None if there's only one config.
-    # This parameter is only used if dataset_name is set.
-    dataset_config_name: typing.Optional[str] = None
-
-    # The Hugging Face cache directory to use for dataset downloads.
-    # If None, the default value will be used (usually '~/.cache/huggingface/datasets').
-    hf_cache_dir: typing.Optional[str] = None
-
-    # The name of the dataset column that contains image paths.
-    image_column: str = "image"
-
-    # The name of the dataset column that contains captions.
-    caption_column: str = "text"
-
-    # The resolution for input images. All of the images in the dataset will be resized to this (square) resolution.
-    resolution: int = 512
-
-    # If True, input images will be center-cropped to resolution.
-    # If False, input images will be randomly cropped to resolution.
-    center_crop: bool = False
-
-    # Whether random flip augmentations should be applied to input images.
-    random_flip: bool = False
-
-    # Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
-    dataloader_num_workers: int = 0
-
-
-class FinetuneLoRAConfig(BaseModel):
-    """The configuration for a LoRA training run."""
-
-    output: TrainingOutputConfig
-
-    optimizer: OptimizerConfig
-
-    dataset: DatasetConfig
-
-    ##################
-    # General Configs
-    ##################
+class LoRATrainingConfig(BaseModel):
+    """The base configuration for any LoRA training run."""

    # The name of the Hugging Face Hub model to train against.
    model: str = "runwayml/stable-diffusion-v1-5"
@@ -166,8 +121,43 @@ class FinetuneLoRAConfig(BaseModel):
    train_batch_size: int = 4


+class FinetuneLoRAConfig(LoRATrainingConfig):
+    output: TrainingOutputConfig
+    optimizer: OptimizerConfig
+    dataset: ImageCaptionDatasetConfig
+
+
class FinetuneLoRASDXLConfig(FinetuneLoRAConfig):
    # The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
    # model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
    # with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
    vae_model: typing.Optional[str] = None
+
+
+class DreamBoothLoRAConfig(LoRATrainingConfig):
+    output: TrainingOutputConfig
+    optimizer: OptimizerConfig
+
+    # The instance dataset to train on.
+    instance_dataset: ImageDirDatasetConfig
+
+    # The caption to use for all examples in the instance_dataset. Typically has the following form:
+    # "a [instance identifier] [class noun]".
+    instance_prompt: str
+
+    # If True, a regularization dataset of prior-preservation images will be generated.
+    use_prior_preservation: bool = False
+
+    # The prompt to use to generate the class regularization dataset. This same prompt will also be used for
+    # conditioning during training. Typically has the following form: "a [class noun]".
+    class_prompt: str
+
+    # The number of class regularization images to generate.
+    num_class_images: int = 0
+
+
+class DreamBoothLoRASDXLConfig(DreamBoothLoRAConfig):
+    # The name of the Hugging Face Hub VAE model to train against. This will override the VAE bundled with the base
+    # model (specified by the `model` parameter). This config option is provided for SDXL models, because SDXL shipped
+    # with a VAE that produces NaNs in fp16 mode, so it is common to replace this VAE with a fixed version.
+    vae_model: typing.Optional[str] = None
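To make the new schema concrete, here is a sketch (not from this commit) of the DreamBooth-specific portion of a config, expressed as the dict that `DreamBoothLoRAConfig(**cfg)` would receive from the loaded YAML; all values are illustrative, and the required `output` and `optimizer` sections are omitted (see `TrainingOutputConfig` and `OptimizerConfig` for their fields):

```python
dreambooth_cfg = {
    "model": "runwayml/stable-diffusion-v1-5",
    "instance_dataset": {
        "dataset_dir": "/data/my_subject",  # hypothetical instance-image directory
        "image_transforms": {"resolution": 512},
    },
    "instance_prompt": "a sks dog",  # "a [instance identifier] [class noun]"
    "use_prior_preservation": True,
    "class_prompt": "a dog",  # "a [class noun]"
    "num_class_images": 200,
}
```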
Empty file.
5 changes: 5 additions & 0 deletions src/invoke_training/training/dreambooth_lora/dreambooth_lora_sd.py
@@ -0,0 +1,5 @@
from invoke_training.training.config.finetune_lora_config import DreamBoothLoRAConfig


def run_training(config: DreamBoothLoRAConfig):
    raise NotImplementedError
7 changes: 7 additions & 0 deletions src/invoke_training/training/dreambooth_lora/dreambooth_lora_sdxl.py
@@ -0,0 +1,7 @@
from invoke_training.training.config.finetune_lora_config import (
    DreamBoothLoRASDXLConfig,
)


def run_training(config: DreamBoothLoRASDXLConfig):
    raise NotImplementedError
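Both `run_training` entry points are stubs that raise `NotImplementedError`: consistent with the commit message, this change lands the CLI scripts, configs, and package layout in preparation for the DreamBooth training loops themselves.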
