Add a setting for clearing CUDA cache after each inference (#2930)
* Add a setting for clearing CUDA cache after each inference

* change implementation
joeyballentine committed Jun 2, 2024
1 parent 8caf78b commit 22222de
Showing 8 changed files with 60 additions and 15 deletions.
6 changes: 4 additions & 2 deletions backend/src/api/node_context.py
```diff
@@ -148,9 +148,11 @@ def storage_dir(self) -> Path:
         """
 
     @abstractmethod
-    def add_cleanup(self, fn: Callable[[], None]) -> None:
+    def add_cleanup(
+        self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
+    ) -> None:
         """
-        Registers a function that will be called when the chain execution is finished.
+        Registers a function that will be called when the chain execution is finished (if set to chain mode) or after node execution is finished (node mode).
         Registering the same function (object) twice will only result in the function being called once.
         """
```
```diff
@@ -220,7 +220,11 @@ def align_image_to_reference_node(
     alignment_passes: int,
     blur_strength: float,
 ) -> np.ndarray:
-    context.add_cleanup(safe_cuda_cache_empty)
+    exec_options = get_settings(context)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     multiplier = precision.value / 1000
     return align_images(
         context,
```
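`safe_cuda_cache_empty` is imported by these nodes but not shown in this diff. Judging from the name and its usage, it is presumably a guarded wrapper around `torch.cuda.empty_cache()`; a minimal sketch of such a helper (the actual chaiNNer implementation may differ):

```python
import gc

import torch


def safe_cuda_cache_empty() -> None:
    # Sketch only: best-effort CUDA cache wipe that never raises.
    try:
        if torch.cuda.is_available():
            gc.collect()  # release dangling tensor references first
            torch.cuda.empty_cache()  # hand cached blocks back to the driver
    except Exception:
        pass  # freeing the cache is an optimization; failure is not fatal
```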
```diff
@@ -79,11 +79,15 @@ def guided_upscale_node(
     iterations: float,
     split_mode: SplitMode,
 ) -> np.ndarray:
-    context.add_cleanup(safe_cuda_cache_empty)
+    exec_options = get_settings(context)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     return pix_transform_auto_split(
         source=source,
         guide=guide,
-        device=get_settings(context).device,
+        device=exec_options.device,
         params=Params(iteration=int(iterations * 1000)),
         split_mode=split_mode,
     )
```
```diff
@@ -114,7 +114,9 @@ def inpaint_node(
     ), "Input image and mask must have the same resolution"
 
     exec_options = get_settings(context)
-
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
 
     return inpaint(img, mask, model, exec_options)
```
```diff
@@ -268,7 +268,10 @@ def upscale_image_node(
 ) -> np.ndarray:
     exec_options = get_settings(context)
 
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
 
     in_nc = model.input_channels
     out_nc = model.output_channels
@@ -299,5 +302,4 @@ def inner_upscale(img: np.ndarray) -> np.ndarray:
     if not use_custom_scale or scale == 1 or in_nc != out_nc:
         # no custom scale
         custom_scale = scale
-
     return custom_scale_upscale(img, inner_upscale, scale, custom_scale, separate_alpha)
```
```diff
@@ -79,7 +79,10 @@ def wavelet_color_fix_node(
     )
 
     exec_options = get_settings(context)
-    context.add_cleanup(safe_cuda_cache_empty)
+    context.add_cleanup(
+        safe_cuda_cache_empty,
+        after="node" if exec_options.force_cache_wipe else "chain",
+    )
     device = exec_options.device
 
     # convert to tensors
```
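The same three-line registration now appears in all five nodes above. A hypothetical helper, not part of this commit, could factor out the repetition; the names below come from the files in this diff:

```python
def register_cache_cleanup(context, exec_options) -> None:
    # Hypothetical consolidation of the pattern repeated in each node:
    # wipe the CUDA cache per node if force_cache_wipe is set, else once per chain.
    context.add_cleanup(
        safe_cuda_cache_empty,
        after="node" if exec_options.force_cache_wipe else "chain",
    )
```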
12 changes: 12 additions & 0 deletions backend/src/packages/chaiNNer_pytorch/settings.py
```diff
@@ -75,13 +75,24 @@
     )
 )
 
+if nvidia.is_available:
+    package.add_setting(
+        ToggleSetting(
+            label="Force CUDA Cache Wipe (not recommended)",
+            key="force_cache_wipe",
+            description="Clears PyTorch's CUDA cache after each inference. This is NOT recommended, by us or PyTorch's developers, as it basically interferes with how PyTorch is intended to work and can significantly slow down inference time. Only enable this if you're experiencing issues with VRAM allocation.",
+            default=False,
+        )
+    )
+
 
 @dataclass(frozen=True)
 class PyTorchSettings:
     use_cpu: bool
     use_fp16: bool
     gpu_index: int
     budget_limit: int
+    force_cache_wipe: bool = False
 
     # PyTorch 2.0 does not support FP16 when using CPU
     def __post_init__(self):
@@ -122,4 +133,5 @@ def get_settings(context: NodeContext) -> PyTorchSettings:
         use_fp16=settings.get_bool("use_fp16", False),
         gpu_index=settings.get_int("gpu_index", 0, parse_str=True),
         budget_limit=settings.get_int("budget_limit", 0, parse_str=True),
+        force_cache_wipe=settings.get_bool("force_cache_wipe", False),
     )
```
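End to end, the toggle travels from the settings UI to the nodes like this: the `ToggleSetting` persists a value under the key `force_cache_wipe`, `get_settings` reads it back with `get_bool` (defaulting to `False`, so existing installs are unaffected), and each node branches on the resulting dataclass field. A condensed, self-contained sketch of that flow, with a stand-in for chaiNNer's real `SettingsParser`:

```python
from dataclasses import dataclass


class FakeSettingsParser:
    # Stand-in for chaiNNer's SettingsParser, for illustration only.
    def __init__(self, raw: dict):
        self._raw = raw

    def get_bool(self, key: str, default: bool) -> bool:
        return bool(self._raw.get(key, default))


@dataclass(frozen=True)
class PyTorchSettings:
    force_cache_wipe: bool = False


settings = FakeSettingsParser({"force_cache_wipe": True})
exec_options = PyTorchSettings(
    force_cache_wipe=settings.get_bool("force_cache_wipe", False)
)
# Each node then picks the cleanup point off this flag:
after = "node" if exec_options.force_cache_wipe else "chain"
assert after == "node"
```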
26 changes: 21 additions & 5 deletions backend/src/process.py
```diff
@@ -8,7 +8,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Iterable, List, NewType, Sequence, Union
+from typing import Callable, Iterable, List, Literal, NewType, Sequence, Union
 
 from sanic.log import logger
```
```diff
@@ -342,7 +342,8 @@ def __init__(
         self.__settings = settings
         self._storage_dir = storage_dir
 
-        self.cleanup_fns: set[Callable[[], None]] = set()
+        self.chain_cleanup_fns: set[Callable[[], None]] = set()
+        self.node_cleanup_fns: set[Callable[[], None]] = set()
 
     @property
     def aborted(self) -> bool:
```
```diff
@@ -373,8 +374,15 @@ def settings(self) -> SettingsParser:
     def storage_dir(self) -> Path:
         return self._storage_dir
 
-    def add_cleanup(self, fn: Callable[[], None]) -> None:
-        self.cleanup_fns.add(fn)
+    def add_cleanup(
+        self, fn: Callable[[], None], after: Literal["node", "chain"] = "chain"
+    ) -> None:
+        if after == "chain":
+            self.chain_cleanup_fns.add(fn)
+        elif after == "node":
+            self.node_cleanup_fns.add(fn)
+        else:
+            raise ValueError(f"Unknown cleanup type: {after}")
 
 
 class Executor:
```
```diff
@@ -591,6 +599,14 @@ def get_lazy_evaluation_time():
             )
             await self.progress.suspend()
 
+            for fn in context.node_cleanup_fns:
+                try:
+                    fn()
+                except Exception as e:
+                    logger.error(f"Error running cleanup function: {e}")
+                finally:
+                    context.node_cleanup_fns.remove(fn)
+
             lazy_time_after = get_lazy_evaluation_time()
             execution_time -= lazy_time_after - lazy_time_before
```
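One caveat in the node-cleanup loop above: it calls `node_cleanup_fns.remove(fn)` while iterating over that same set, and CPython raises `RuntimeError: Set changed size during iteration` when a set is mutated mid-iteration. A sketch of an equivalent drain pattern that avoids this, with the same semantics (each function runs once and the set ends up empty):

```python
# Drain the set instead of removing entries while iterating over it.
while context.node_cleanup_fns:
    fn = context.node_cleanup_fns.pop()
    try:
        fn()
    except Exception as e:
        logger.error(f"Error running cleanup function: {e}")
```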

```diff
@@ -824,7 +840,7 @@ async def __process_nodes(self):
 
         # Run cleanup functions
         for context in self.__context_cache.values():
-            for fn in context.cleanup_fns:
+            for fn in context.chain_cleanup_fns:
                 try:
                     fn()
                 except Exception as e:
```
