Skip to content

Commit

Permalink
fix(api): run torch gc alongside python (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Feb 17, 2023
1 parent 1ca0c01 commit 0ed4af1
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 19 deletions.
2 changes: 1 addition & 1 deletion api/onnx_web/chain/correct_codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ def correct_codeformer(
from codeformer import CodeFormer

device = job.get_device()
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_device())
pipe = CodeFormer(upscale=upscale.face_outscale).to(device.torch_str())
return pipe(stage_source or source)
4 changes: 2 additions & 2 deletions api/onnx_web/chain/correct_gfpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def load_gfpgan(
server: ServerContext,
_stage: StageParams,
upscale: UpscaleParams,
_device: DeviceParams,
device: DeviceParams,
):
# must be within the load function for patch to take effect
from gfpgan import GFPGANer
Expand All @@ -40,7 +40,7 @@ def load_gfpgan(
)

server.cache.set("gfpgan", cache_key, gfpgan)
run_gc()
run_gc([device])

return gfpgan

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_resrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def load_resrgan(
)

server.cache.set("resrgan", cache_key, upsampler)
run_gc()
run_gc([device])

return upsampler

Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/chain/upscale_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def load_stable_diffusion(
)

server.cache.set("diffusion", cache_key, pipe)
run_gc()
run_gc([device])

return pipe

Expand Down
8 changes: 4 additions & 4 deletions api/onnx_web/diffusion/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,16 +115,16 @@ def load_pipeline(
)

if device is not None and hasattr(scheduler, "to"):
scheduler = scheduler.to(device.torch_device())
scheduler = scheduler.to(device.torch_str())

pipe.scheduler = scheduler
server.cache.set("scheduler", scheduler_key, scheduler)
run_gc()
run_gc([device])

else:
logger.debug("unloading previous diffusion pipeline")
server.cache.drop("diffusion", pipe_key)
run_gc()
run_gc([device])

if lpw:
custom_pipeline = "./onnx_web/diffusion/lpw_stable_diffusion_onnx.py"
Expand All @@ -149,7 +149,7 @@ def load_pipeline(
)

if device is not None and hasattr(pipe, "to"):
pipe = pipe.to(device.torch_device())
pipe = pipe.to(device.torch_str())

server.cache.set("diffusion", pipe_key, pipe)
server.cache.set("scheduler", scheduler_key, scheduler)
Expand Down
17 changes: 12 additions & 5 deletions api/onnx_web/diffusion/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,11 @@ def run_txt2img_pipeline(
dest = save_image(server, output, image)
save_params(server, output, params, size, upscale=upscale)

del pipe
del image
del result
run_gc()

run_gc([job.get_device()])

logger.info("finished txt2img job: %s", dest)

Expand Down Expand Up @@ -147,9 +149,11 @@ def run_img2img_pipeline(
size = Size(*source_image.size)
save_params(server, output, params, size, upscale=upscale)

del pipe
del image
del result
run_gc()

run_gc([job.get_device()])

logger.info("finished img2img job: %s", dest)

Expand Down Expand Up @@ -200,7 +204,8 @@ def run_inpaint_pipeline(
save_params(server, output, params, size, upscale=upscale, border=border)

del image
run_gc()

run_gc([job.get_device()])

logger.info("finished inpaint job: %s", dest)

Expand All @@ -226,7 +231,8 @@ def run_upscale_pipeline(
save_params(server, output, params, size, upscale=upscale)

del image
run_gc()

run_gc([job.get_device()])

logger.info("finished upscale job: %s", dest)

Expand Down Expand Up @@ -263,6 +269,7 @@ def run_blend_pipeline(
save_params(server, output, params, size, upscale=upscale)

del image
run_gc()

run_gc([job.get_device()])

logger.info("finished blend job: %s", dest)
2 changes: 1 addition & 1 deletion api/onnx_web/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def ort_provider(self) -> Tuple[str, Any]:
def sess_options(self) -> SessionOptions:
return SessionOptions()

def torch_device(self) -> str:
def torch_str(self) -> str:
if self.device.startswith("cuda"):
return self.device
else:
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/server/device_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def job_done(f: Future):
key,
format_exception(type(err), err, err.__traceback__),
)
run_gc()
run_gc([self.devices[device]])

future.add_done_callback(job_done)

Expand Down
12 changes: 9 additions & 3 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch

from .params import SizeChart
from .params import DeviceParams, SizeChart
from .server.model_cache import ModelCache

logger = getLogger(__name__)
Expand Down Expand Up @@ -134,7 +134,13 @@ def get_size(val: Union[int, str, None]) -> SizeChart:
raise Exception("invalid size")


def run_gc(devices: List[DeviceParams] = None):
    """Run Python garbage collection and, when CUDA is available, clear
    Torch's CUDA allocator caches on each of the given devices.

    Args:
        devices: devices whose CUDA caches should be emptied. ``None`` (or
            an empty list) limits collection to the Python heap only.

    Note: the default is ``None`` rather than ``[]`` — a mutable default
    list would be shared across every call to this function.
    """
    logger.debug("running garbage collection")
    gc.collect()

    if torch.cuda.is_available():
        for device in devices or []:
            logger.debug("running Torch garbage collection for device: %s", device)
            with torch.cuda.device(device.torch_str()):
                # free cached allocator blocks first, then reclaim
                # cross-process (IPC) CUDA memory handles
                torch.cuda.empty_cache()
                torch.cuda.ipc_collect()

0 comments on commit 0ed4af1

Please sign in to comment.