Skip to content

Commit

Permalink
initial proposal
Browse files Browse the repository at this point in the history
  • Loading branch information
asomoza committed Aug 29, 2024
1 parent 2a3fbc2 commit 7f5ece9
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
59 changes: 56 additions & 3 deletions src/diffusers/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,17 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s

class SDXLCFGCutoffCallback(PipelineCallback):
"""
Callback function for Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio` or
`cutoff_step_index`), this callback will disable the CFG.
Callback function for the base Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio`
or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""

tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]
tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
Expand All @@ -129,6 +133,55 @@ def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[s
callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids

return callback_kwargs


class SDXLControlnetCFGCutoffCallback(PipelineCallback):
"""
Callback function for the Controlnet Stable Diffusion XL Pipelines. After certain number of steps (set by `cutoff_step_ratio`
or `cutoff_step_index`), this callback will disable the CFG.
Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
"""

tensor_inputs = [
"prompt_embeds",
"add_text_embeds",
"add_time_ids",
"image",
]

def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
cutoff_step_ratio = self.config.cutoff_step_ratio
cutoff_step_index = self.config.cutoff_step_index

# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
cutoff_step = (
cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
)

if step_index == cutoff_step:
prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
prompt_embeds = prompt_embeds[-1:] # "-1" denotes the embeddings for conditional text tokens.

add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
add_text_embeds = add_text_embeds[-1:] # "-1" denotes the embeddings for conditional pooled text tokens

add_time_ids = callback_kwargs[self.tensor_inputs[2]]
add_time_ids = add_time_ids[-1:] # "-1" denotes the embeddings for conditional added time vector

# For Controlnet
image = callback_kwargs[self.tensor_inputs[3]]
image = image[-1:]

pipeline._guidance_scale = 0.0

callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
callback_kwargs[self.tensor_inputs[2]] = add_time_ids
callback_kwargs[self.tensor_inputs[3]] = image

return callback_kwargs


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ class StableDiffusionXLControlNetPipeline(
"add_time_ids",
"negative_pooled_prompt_embeds",
"negative_add_time_ids",
"image",
]

def __init__(
Expand Down Expand Up @@ -1532,6 +1533,7 @@ def __call__(
)
add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
image = callback_outputs.pop("image", image)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
Expand Down

0 comments on commit 7f5ece9

Please sign in to comment.