diff --git a/library/class_basic_training.py b/library/class_basic_training.py index e0bb2f1ef..5aaa52d81 100644 --- a/library/class_basic_training.py +++ b/library/class_basic_training.py @@ -123,7 +123,7 @@ def __init__( ) self.lr_scheduler_args = gr.Textbox( label="LR scheduler extra arguments", - placeholder='(Optional) eg: "lr_end=5e-5"', + placeholder='(Optional) eg: "milestones=[1,10,30,50]" "gamma=0.1"', ) self.optimizer_args = gr.Textbox( label="Optimizer extra arguments", diff --git a/library/lpw_stable_diffusion.py b/library/lpw_stable_diffusion.py index 9dce91a76..4f9408352 100644 --- a/library/lpw_stable_diffusion.py +++ b/library/lpw_stable_diffusion.py @@ -9,7 +9,7 @@ import PIL.Image import torch from packaging import version -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection import diffusers from diffusers import SchedulerMixin, StableDiffusionPipeline @@ -516,12 +516,13 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: SchedulerMixin, - # clip_skip: int, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, + image_encoder: CLIPVisionModelWithProjection = None, # Incluindo o image_encoder requires_safety_checker: bool = True, clip_skip: int = 1, ): + self._clip_skip_internal = clip_skip super().__init__( vae=vae, text_encoder=text_encoder, @@ -530,11 +531,25 @@ def __init__( scheduler=scheduler, safety_checker=safety_checker, feature_extractor=feature_extractor, + image_encoder=image_encoder, requires_safety_checker=requires_safety_checker, ) - self.clip_skip = clip_skip self.__init__additional__() + @property + def clip_skip(self): + return self._clip_skip_internal + + @clip_skip.setter + def clip_skip(self, value): + self._clip_skip_internal = value + + def __setattr__(self, name: str, value): + if name == "clip_skip": + object.__setattr__(self, "_clip_skip_internal", value) + else: + super().__setattr__(name, value) + # else: # def __init__( # self,