Skip to content

Commit

Permalink
Merge pull request #1782 from DevArqSangoi/dev
Browse files Browse the repository at this point in the history
Update lpw_stable_diffusion.py
  • Loading branch information
bmaltais committed Dec 19, 2023
2 parents b3bf86d + 6dc1928 commit 235b6da
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
2 changes: 1 addition & 1 deletion library/class_basic_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
)
self.lr_scheduler_args = gr.Textbox(
label="LR scheduler extra arguments",
placeholder='(Optional) eg: "lr_end=5e-5"',
placeholder='(Optional) eg: "milestones=[1,10,30,50]" "gamma=0.1"',
)
self.optimizer_args = gr.Textbox(
label="Optimizer extra arguments",
Expand Down
21 changes: 18 additions & 3 deletions library/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import PIL.Image
import torch
from packaging import version
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

import diffusers
from diffusers import SchedulerMixin, StableDiffusionPipeline
Expand Down Expand Up @@ -516,12 +516,13 @@ def __init__(
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: SchedulerMixin,
# clip_skip: int,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
image_encoder: CLIPVisionModelWithProjection = None, # Incluindo o image_encoder
requires_safety_checker: bool = True,
clip_skip: int = 1,
):
self._clip_skip_internal = clip_skip
super().__init__(
vae=vae,
text_encoder=text_encoder,
Expand All @@ -530,11 +531,25 @@ def __init__(
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
image_encoder=image_encoder,
requires_safety_checker=requires_safety_checker,
)
self.clip_skip = clip_skip
self.__init__additional__()

@property
def clip_skip(self):
return self._clip_skip_internal

@clip_skip.setter
def clip_skip(self, value):
self._clip_skip_internal = value

def __setattr__(self, name: str, value):
if name == "clip_skip":
object.__setattr__(self, "_clip_skip_internal", value)
else:
super().__setattr__(name, value)

# else:
# def __init__(
# self,
Expand Down

0 comments on commit 235b6da

Please sign in to comment.