Skip to content

Commit

Permalink
fix(api): track last progress within worker
Browse files (browse the repository at this point in the history)
  • Loading branch information
ssube committed Mar 18, 2023
1 parent 5106dd4 commit 588c8c7
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 38 deletions.
4 changes: 3 additions & 1 deletion api/onnx_web/server/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -457,7 +457,9 @@ def ready(context: ServerContext, pool: DevicePoolExecutor):
if path.exists(output):
return ready_reply(True)
else:
return ready_reply(True, error=True) # is a missing image really an error? yes will display the retry button
return ready_reply(
True, error=True
) # is a missing image really an error? yes will display the retry button

return ready_reply(
progress.finished,
Expand Down
76 changes: 42 additions & 34 deletions api/onnx_web/worker/context.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
from logging import getLogger
from os import getpid
from typing import Any, Callable
from typing import Any, Callable, Optional

from torch.multiprocessing import Queue, Value

Expand All @@ -17,8 +17,9 @@ class WorkerContext:
cancel: "Value[bool]"
job: str
pending: "Queue[JobCommand]"
current: "Value[int]"
active_pid: "Value[int]"
progress: "Queue[ProgressCommand]"
last_progress: Optional[ProgressCommand]

def __init__(
self,
Expand All @@ -28,25 +29,25 @@ def __init__(
logs: "Queue[str]",
pending: "Queue[JobCommand]",
progress: "Queue[ProgressCommand]",
current: "Value[int]",
active_pid: "Value[int]",
):
self.job = job
self.device = device
self.cancel = cancel
self.progress = progress
self.logs = logs
self.pending = pending
self.current = current
self.active_pid = active_pid

def is_cancelled(self) -> bool:
return self.cancel.value

def is_current(self) -> bool:
return self.get_current() == getpid()
def is_active(self) -> bool:
return self.get_active() == getpid()

def get_current(self) -> int:
with self.current.get_lock():
return self.current.value
def get_active(self) -> int:
with self.active_pid.get_lock():
return self.active_pid.value

def get_device(self) -> DeviceParams:
"""
Expand All @@ -55,7 +56,10 @@ def get_device(self) -> DeviceParams:
return self.device

def get_progress(self) -> int:
return self.progress.value
if self.last_progress is not None:
return self.last_progress.progress

return 0

def get_progress_callback(self) -> ProgressCallback:
def on_progress(step: int, timestep: int, latents: Any):
Expand All @@ -73,44 +77,48 @@ def set_progress(self, progress: int) -> None:
raise RuntimeError("job has been cancelled")
else:
logger.debug("setting progress for job %s to %s", self.job, progress)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
)

self.progress.put(
ProgressCommand(
self.job,
self.device.device,
False,
progress,
self.is_cancelled(),
False,
),
self.last_progress,
block=False,
)

def set_finished(self) -> None:
logger.debug("setting finished for job %s", self.job)
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
)
self.progress.put(
ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
False,
),
self.last_progress,
block=False,
)

def set_failed(self) -> None:
logger.warning("setting failure for job %s", self.job)
try:
self.last_progress = ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
)
self.progress.put(
ProgressCommand(
self.job,
self.device.device,
True,
self.get_progress(),
self.is_cancelled(),
True,
),
self.last_progress,
block=False,
)
except Exception:
Expand Down
2 changes: 1 addition & 1 deletion api/onnx_web/worker/pool.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_device_worker(self, device: DeviceParams) -> None:
progress=self.progress,
logs=self.logs,
pending=pending,
current=current,
active_pid=current,
)
self.context[name] = context
worker = Process(
Expand Down
4 changes: 2 additions & 2 deletions api/onnx_web/worker/worker.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,11 +29,11 @@ def worker_main(context: WorkerContext, server: ServerContext):

while True:
try:
if not context.is_current():
if not context.is_active():
logger.warning(
"worker %s has been replaced by %s, exiting",
getpid(),
context.get_current(),
context.get_active(),
)
exit(EXIT_REPLACED)

Expand Down

0 comments on commit 588c8c7

Please sign in to comment.