Skip to content

Commit

Permalink
sync the max_sequence_length parameter of sd3 pipeline with official …
Browse files Browse the repository at this point in the history
…implementation
  • Loading branch information
root committed Nov 13, 2024
1 parent 1dbd26f commit 472f06d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
Expand Down Expand Up @@ -334,7 +334,7 @@ def encode_prompt(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
lora_scale: Optional[float] = None,
):
r"""
Expand Down Expand Up @@ -693,7 +693,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
max_sequence_length: int = 77,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -777,7 +777,7 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
max_sequence_length (`int` defaults to 77): Maximum sequence length to use with the `prompt`.
Examples:
Expand Down Expand Up @@ -849,6 +849,7 @@ def __call__(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
print(prompt_embeds.shape)

if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
Expand Down Expand Up @@ -349,7 +349,7 @@ def encode_prompt(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
lora_scale: Optional[float] = None,
):
r"""
Expand Down Expand Up @@ -731,7 +731,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
max_sequence_length: int = 77,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -815,7 +815,7 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
max_sequence_length (`int` defaults to 77): Maximum sequence length to use with the `prompt`.
Examples:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_images_per_prompt: int = 1,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
Expand Down Expand Up @@ -355,7 +355,7 @@ def encode_prompt(
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
clip_skip: Optional[int] = None,
max_sequence_length: int = 256,
max_sequence_length: int = 77,
lora_scale: Optional[float] = None,
):
r"""
Expand Down Expand Up @@ -823,7 +823,7 @@ def __call__(
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
max_sequence_length: int = 77,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -932,7 +932,7 @@ def __call__(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
max_sequence_length (`int` defaults to 77): Maximum sequence length to use with the `prompt`.
Examples:
Expand Down

0 comments on commit 472f06d

Please sign in to comment.