Skip to content

Commit

Permalink
VRAM optimization: keep track of where the model is, so that our embeds end up in the same place.
Browse files Browse the repository at this point in the history

Signed-off-by: bghira <[email protected]>
  • Loading branch information
bghira committed Jul 8, 2023
1 parent 6e7d339 commit 03aae82
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
8 changes: 8 additions & 0 deletions discord_tron_client/classes/image_manipulation/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,13 +332,21 @@ def patch_scheduler_betas(self, scheduler):
return scheduler

def to_accelerator(self, pipeline):
    """Move *pipeline* to the configured accelerator device (``self.device``).

    Skips the transfer (with a warning) when the pipeline's unet is already
    on the GPU, avoiding a redundant, VRAM-churning device move.

    Raises:
        Exception: re-raised if the underlying ``pipeline.to`` call fails.
    """
    # Probe one unet parameter to learn where the model currently lives.
    is_on_gpu = next(pipeline.unet.parameters()).is_cuda
    if is_on_gpu:
        # Bug fix: the original message said "CPU", copy-pasted from to_cpu;
        # this path means the pipeline is already on the accelerator.
        logging.warning('Requested to move pipeline to accelerator, when it is already there.')
        return
    try:
        pipeline.to(self.device)
    except Exception as e:
        logging.error(f"Could not move pipeline to accelerator: {e}")
        # Bare raise preserves the original traceback.
        raise

def to_cpu(self, pipeline):
is_on_gpu = next(pipeline.unet.parameters()).is_cuda
if not is_on_gpu:
logging.warning(f'Requested to move pipeline to CPU, when it is already there.')
return
try:
pipeline.to("cpu")
except Exception as e:
Expand Down
12 changes: 8 additions & 4 deletions discord_tron_client/classes/image_manipulation/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,13 @@ def _get_generator(self, user_config: dict):
logging.info(f"Seed: {self.seed}")
return generator

def _get_prompt_manager(self, pipe):
logging.debug(f"Initialized the Compel")
return PromptManipulation(pipeline=pipe, device=self.pipeline_manager.device)
def _get_prompt_manager(self, pipe, device = "cpu"):
    """Build a PromptManipulation helper colocated with *pipe*'s unet.

    When the unet is already on the GPU but ``device`` asks for the CPU,
    the requested placement is overridden to ``"cuda"`` so the prompt
    embeddings land on the same device as the model.
    """
    # Where does the model currently live? Check one unet parameter.
    unet_on_gpu = next(pipe.unet.parameters()).is_cuda
    if unet_on_gpu and device == "cpu":
        logging.warning(f'Prompt manager was requested to be placed on the CPU, but the unet is already on the GPU. We have to adjust the prompt manager, to the GPU.')
        device = "cuda"
    return PromptManipulation(pipeline=pipe, device=device)

def _get_rescaled_resolution(self, user_config, side_x, side_y):
resolution = {"width": side_x, "height": side_y}
Expand Down Expand Up @@ -454,7 +458,7 @@ def _controlnet_pipeline(self, image: Image, user_config: dict, pipe, generator,
generator=generator,
num_inference_steps=user_config.get("tile_steps", 32),
).images[0]
self.pipeline_manager.to_cpu(pipe)
self.pipeline_manager.to_cpu(pipe, user_config['model_id'])
return new_image

def _refiner_pipeline(self, images: Image, user_config: dict, prompt: str = None, negative_prompt: str = None, random_seed = False):
Expand Down

0 comments on commit 03aae82

Please sign in to comment.