Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keep_tokens_separator as alternative for keep_tokens #975

Merged
merged 3 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class BaseSubsetParams:
shuffle_caption: bool = False
caption_separator: str = ',',
keep_tokens: int = 0
keep_tokens_separator: str = None,
color_aug: bool = False
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
Expand Down Expand Up @@ -160,6 +161,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"random_crop": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
"token_warmup_min": int,
"token_warmup_step": Any(float,int),
"caption_prefix": str,
Expand Down Expand Up @@ -461,6 +463,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
num_repeats: {subset.num_repeats}
shuffle_caption: {subset.shuffle_caption}
keep_tokens: {subset.keep_tokens}
keep_tokens_separator: {subset.keep_tokens_separator}
caption_dropout_rate: {subset.caption_dropout_rate}
caption_dropout_every_n_epoches: {subset.caption_dropout_every_n_epochs}
caption_tag_dropout_rate: {subset.caption_tag_dropout_rate}
Expand Down
41 changes: 32 additions & 9 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
shuffle_caption: bool,
caption_separator: str,
keep_tokens: int,
keep_tokens_separator: str,
color_aug: bool,
flip_aug: bool,
face_crop_aug_range: Optional[Tuple[float, float]],
Expand All @@ -368,6 +369,7 @@ def __init__(
self.shuffle_caption = shuffle_caption
self.caption_separator = caption_separator
self.keep_tokens = keep_tokens
self.keep_tokens_separator = keep_tokens_separator
self.color_aug = color_aug
self.flip_aug = flip_aug
self.face_crop_aug_range = face_crop_aug_range
Expand Down Expand Up @@ -395,6 +397,7 @@ def __init__(
shuffle_caption,
caption_separator: str,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -415,6 +418,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -449,6 +453,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -469,6 +474,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -500,6 +506,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand All @@ -520,6 +527,7 @@ def __init__(
shuffle_caption,
caption_separator,
keep_tokens,
keep_tokens_separator,
color_aug,
flip_aug,
face_crop_aug_range,
Expand Down Expand Up @@ -654,15 +662,29 @@ def process_caption(self, subset: BaseSubset, caption):
caption = ""
else:
if subset.shuffle_caption or subset.token_warmup_step > 0 or subset.caption_tag_dropout_rate > 0:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
fixed_tokens = []
flex_tokens = []
if hasattr(subset, 'keep_tokens_separator') and subset.keep_tokens_separator in caption:
fixed_part, flex_part = caption.split(subset.keep_tokens_separator, 1)
fixed_tokens = [t.strip() for t in fixed_part.split(subset.caption_separator) if t.strip()]
flex_tokens = [t.strip() for t in flex_part.split(subset.caption_separator) if t.strip()]
else:
tokens = [t.strip() for t in caption.strip().split(subset.caption_separator)]
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[:subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens:]


if subset.token_warmup_step < 1: # 初回に上書きする
subset.token_warmup_step = math.floor(subset.token_warmup_step * self.max_train_steps)
if subset.token_warmup_step and self.current_step < subset.token_warmup_step:
tokens_len = (
math.floor((self.current_step) * ((len(tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
math.floor((self.current_step) * ((len(flex_tokens) - subset.token_warmup_min) / (subset.token_warmup_step)))
+ subset.token_warmup_min
)
tokens = tokens[:tokens_len]
flex_tokens = flex_tokens[:tokens_len]


def dropout_tags(tokens):
if subset.caption_tag_dropout_rate <= 0:
Expand All @@ -673,12 +695,6 @@ def dropout_tags(tokens):
l.append(token)
return l

fixed_tokens = []
flex_tokens = tokens[:]
if subset.keep_tokens > 0:
fixed_tokens = flex_tokens[: subset.keep_tokens]
flex_tokens = tokens[subset.keep_tokens :]

if subset.shuffle_caption:
random.shuffle(flex_tokens)

Expand Down Expand Up @@ -1723,6 +1739,7 @@ def __init__(
subset.num_repeats,
subset.shuffle_caption,
subset.keep_tokens,
subset.keep_tokens_separator,
subset.color_aug,
subset.flip_aug,
subset.face_crop_aug_range,
Expand Down Expand Up @@ -3133,6 +3150,12 @@ def add_dataset_arguments(
default=0,
help="keep heading N tokens when shuffling caption tokens (token means comma separated strings) / captionのシャッフル時に、先頭からこの個数のトークンをシャッフルしないで残す(トークンはカンマ区切りの各部分を意味する)",
)
parser.add_argument(
"--keep_tokens_separator",
type=str,
default="",
help="A custom separator to divide the caption into fixed and flexible parts. Tokens before this separator will not be shuffled. If not specified, '--keep_tokens' will be used to determine the fixed number of tokens.",
)
parser.add_argument(
"--caption_prefix",
type=str,
Expand Down