diff --git a/diffengine/models/editors/esd/esd_xl.py b/diffengine/models/editors/esd/esd_xl.py index f4a6eab..ea92ea5 100644 --- a/diffengine/models/editors/esd/esd_xl.py +++ b/diffengine/models/editors/esd/esd_xl.py @@ -46,6 +46,7 @@ def __init__(self, super().__init__( *args, finetune_text_encoder=finetune_text_encoder, + pre_compute_text_embeddings=pre_compute_text_embeddings, data_preprocessor=data_preprocessor, **kwargs) diff --git a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py index 5654297..63f1953 100644 --- a/diffengine/models/editors/ip_adapter/ip_adapter_xl.py +++ b/diffengine/models/editors/ip_adapter/ip_adapter_xl.py @@ -153,7 +153,8 @@ def infer(self, tokenizer=self.tokenizer_one, tokenizer_2=self.tokenizer_two, unet=self.unet, - torch_dtype=torch.float16, + torch_dtype=(torch.float16 if self.device != torch.device('cpu') + else torch.float32), ) pipeline.to(self.device) pipeline.set_progress_bar_config(disable=True) diff --git a/diffengine/models/editors/stable_diffusion/stable_diffusion.py b/diffengine/models/editors/stable_diffusion/stable_diffusion.py index 3137d81..bf27cc2 100644 --- a/diffengine/models/editors/stable_diffusion/stable_diffusion.py +++ b/diffengine/models/editors/stable_diffusion/stable_diffusion.py @@ -144,7 +144,9 @@ def infer(self, tokenizer=self.tokenizer, unet=self.unet, safety_checker=None, - torch_dtype=torch.float16) + torch_dtype=(torch.float16 if self.device != torch.device('cpu') + else torch.float32), + ) pipeline.set_progress_bar_config(disable=True) images = [] for p in prompt: diff --git a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py index 2c9fcae..d4d1609 100644 --- a/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py +++ b/diffengine/models/editors/stable_diffusion_controlnet/stable_diffusion_controlnet.py @@ -140,7 +140,9 @@ def infer(self, unet=self.unet, controlnet=self.controlnet, safety_checker=None, - torch_dtype=torch.float16) + torch_dtype=(torch.float16 if self.device != torch.device('cpu') + else torch.float32), + ) pipeline.set_progress_bar_config(disable=True) images = [] for p, img in zip(prompt, condition_image): diff --git a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index 5500d5f..d3c76a2 100644 --- a/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/diffengine/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -190,7 +190,9 @@ def infer(self, vae=self.vae, unet=self.unet, safety_checker=None, - torch_dtype=torch.float16, + torch_dtype=(torch.float16 + if self.device != torch.device('cpu') else + torch.float32), ) else: pipeline = DiffusionPipeline.from_pretrained( @@ -201,7 +203,9 @@ def infer(self, tokenizer=self.tokenizer_one, tokenizer_2=self.tokenizer_two, unet=self.unet, - torch_dtype=torch.float16, + torch_dtype=(torch.float16 + if self.device != torch.device('cpu') else + torch.float32), ) pipeline.to(self.device) pipeline.set_progress_bar_config(disable=True) diff --git a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py index 4d0f054..2135960 100644 --- a/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py +++ b/diffengine/models/editors/stable_diffusion_xl_controlnet/stable_diffusion_xl_controlnet.py @@ -142,7 +142,8 @@ def infer(self, tokenizer_2=self.tokenizer_two, unet=self.unet, controlnet=self.controlnet, - torch_dtype=torch.float16, + torch_dtype=(torch.float16 if self.device != torch.device('cpu') + else torch.float32), ) pipeline.to(self.device) pipeline.set_progress_bar_config(disable=True) diff --git a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py index 77fa70a..3af057e 100644 --- a/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py +++ b/diffengine/models/editors/t2i_adapter/stable_diffusion_xl_t2i_adapter.py @@ -131,7 +131,8 @@ def infer(self, tokenizer_2=self.tokenizer_two, unet=self.unet, adapter=self.adapter, - torch_dtype=torch.float16, + torch_dtype=(torch.float16 if self.device != torch.device('cpu') + else torch.float32), ) pipeline.to(self.device) pipeline.set_progress_bar_config(disable=True) diff --git a/tests/test_models/test_editors/test_esd/test_esd.py b/tests/test_models/test_editors/test_esd/test_esd.py index 5b6c281..8553686 100644 --- a/tests/test_models/test_editors/test_esd/test_esd.py +++ b/tests/test_models/test_editors/test_esd/test_esd.py @@ -51,6 +51,9 @@ def test_train_step_with_pre_compute_embs(self): # test load with loss module StableDiffuser = ESDXL( 'hf-internal-testing/tiny-stable-diffusion-xl-pipe', + train_method='xattn', + height=64, + width=64, loss=L2Loss(), data_preprocessor=ESDXLDataPreprocessor()) @@ -58,9 +61,9 @@ def test_train_step_with_pre_compute_embs(self): data = dict( inputs=dict( text=['dog'], - prompt_embeds=[torch.zeros((77, 64))], + prompt_embeds=[torch.zeros((2, 64))], pooled_prompt_embeds=[torch.zeros((32))], - null_prompt_embeds=[torch.zeros((77, 64))], + null_prompt_embeds=[torch.zeros((2, 64))], null_pooled_prompt_embeds=[torch.zeros((32))])) optimizer = SGD(StableDiffuser.parameters(), lr=0.1) optim_wrapper = OptimWrapper(optimizer)