support Erasing Concepts from Diffusion Models
okotaku committed Oct 8, 2023
1 parent 7a63ce7 commit 45e7344
Showing 8 changed files with 24 additions and 9 deletions.
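
The recurring change across the editors' infer methods is to stop hard-coding torch.float16 and instead derive the pipeline dtype from the current device, since half-precision inference is generally not supported on CPU. A minimal stand-alone sketch of the same pattern (checkpoint id and variable names are illustrative, not taken from the commit):

import torch
from diffusers import DiffusionPipeline

# Illustrative sketch of the device-aware dtype selection applied in this commit:
# fp16 only when running on an accelerator, fp32 as the CPU fallback.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float16 if device != torch.device('cpu') else torch.float32

pipeline = DiffusionPipeline.from_pretrained(
    'hf-internal-testing/tiny-stable-diffusion-xl-pipe',  # placeholder checkpoint
    torch_dtype=dtype,
)
pipeline.to(device)
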
1 change: 1 addition & 0 deletions diffengine/models/editors/esd/esd_xl.py
@@ -46,6 +46,7 @@ def __init__(self,
super().__init__(
*args,
finetune_text_encoder=finetune_text_encoder,
+pre_compute_text_embeddings=pre_compute_text_embeddings,
data_preprocessor=data_preprocessor,
**kwargs)

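The ESD XL editor now forwards pre_compute_text_embeddings to its parent class, so prompt embeddings can be computed once up front instead of on every step. Roughly what precomputing SDXL embeddings looks like with plain diffusers (a hedged sketch; diffengine's own pre-compute path may differ, and encode_prompt here is the diffusers SDXL pipeline helper, not a diffengine API):

import torch
from diffusers import StableDiffusionXLPipeline

# Sketch: compute prompt and null (unconditional) embeddings once, then pass
# them around as tensors, mirroring the keys used in the test further below.
pipe = StableDiffusionXLPipeline.from_pretrained(
    'hf-internal-testing/tiny-stable-diffusion-xl-pipe')
prompt_embeds, _, pooled_prompt_embeds, _ = pipe.encode_prompt(
    'dog', do_classifier_free_guidance=False)
null_prompt_embeds, _, null_pooled_prompt_embeds, _ = pipe.encode_prompt(
    '', do_classifier_free_guidance=False)
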
3 changes: 2 additions & 1 deletion diffengine/models/editors/ip_adapter/ip_adapter_xl.py
@@ -153,7 +153,8 @@ def infer(self,
tokenizer=self.tokenizer_one,
tokenizer_2=self.tokenizer_two,
unet=self.unet,
-torch_dtype=torch.float16,
+torch_dtype=(torch.float16 if self.device != torch.device('cpu')
+else torch.float32),
)
pipeline.to(self.device)
pipeline.set_progress_bar_config(disable=True)
@@ -144,7 +144,9 @@ def infer(self,
tokenizer=self.tokenizer,
unet=self.unet,
safety_checker=None,
-torch_dtype=torch.float16)
+torch_dtype=(torch.float16 if self.device != torch.device('cpu')
+else torch.float32),
+)
pipeline.set_progress_bar_config(disable=True)
images = []
for p in prompt:
@@ -140,7 +140,9 @@ def infer(self,
unet=self.unet,
controlnet=self.controlnet,
safety_checker=None,
-torch_dtype=torch.float16)
+torch_dtype=(torch.float16 if self.device != torch.device('cpu')
+else torch.float32),
+)
pipeline.set_progress_bar_config(disable=True)
images = []
for p, img in zip(prompt, condition_image):
@@ -190,7 +190,9 @@ def infer(self,
vae=self.vae,
unet=self.unet,
safety_checker=None,
-torch_dtype=torch.float16,
+torch_dtype=(torch.float16
+if self.device != torch.device('cpu') else
+torch.float32),
)
else:
pipeline = DiffusionPipeline.from_pretrained(
@@ -201,7 +203,9 @@ def infer(self,
tokenizer=self.tokenizer_one,
tokenizer_2=self.tokenizer_two,
unet=self.unet,
-torch_dtype=torch.float16,
+torch_dtype=(torch.float16
+if self.device != torch.device('cpu') else
+torch.float32),
)
pipeline.to(self.device)
pipeline.set_progress_bar_config(disable=True)
@@ -142,7 +142,8 @@ def infer(self,
tokenizer_2=self.tokenizer_two,
unet=self.unet,
controlnet=self.controlnet,
-torch_dtype=torch.float16,
+torch_dtype=(torch.float16 if self.device != torch.device('cpu')
+else torch.float32),
)
pipeline.to(self.device)
pipeline.set_progress_bar_config(disable=True)
@@ -131,7 +131,8 @@ def infer(self,
tokenizer_2=self.tokenizer_two,
unet=self.unet,
adapter=self.adapter,
-torch_dtype=torch.float16,
+torch_dtype=(torch.float16 if self.device != torch.device('cpu')
+else torch.float32),
)
pipeline.to(self.device)
pipeline.set_progress_bar_config(disable=True)
7 changes: 5 additions & 2 deletions tests/test_models/test_editors/test_esd/test_esd.py
@@ -51,16 +51,19 @@ def test_train_step_with_pre_compute_embs(self):
# test load with loss module
StableDiffuser = ESDXL(
'hf-internal-testing/tiny-stable-diffusion-xl-pipe',
+train_method='xattn',
+height=64,
+width=64,
loss=L2Loss(),
data_preprocessor=ESDXLDataPreprocessor())

# test train step
data = dict(
inputs=dict(
text=['dog'],
-prompt_embeds=[torch.zeros((77, 64))],
+prompt_embeds=[torch.zeros((2, 64))],
pooled_prompt_embeds=[torch.zeros((32))],
-null_prompt_embeds=[torch.zeros((77, 64))],
+null_prompt_embeds=[torch.zeros((2, 64))],
null_pooled_prompt_embeds=[torch.zeros((32))]))
optimizer = SGD(StableDiffuser.parameters(), lr=0.1)
optim_wrapper = OptimWrapper(optimizer)
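
The rest of the test is truncated in this view. As an illustration only, a train step against a data dict like the one above is typically exercised as follows in an mmengine-based test (a hedged sketch that reuses the names defined in the hunk above; it is not the verbatim truncated code):

# Hypothetical continuation for illustration; the real test's assertions are
# not visible in this diff. Relies on StableDiffuser, data, and optim_wrapper
# defined in the test hunk above.
log_vars = StableDiffuser.train_step(data, optim_wrapper)
assert 'loss' in log_vars
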
