Skip to content

Commit

Permalink
feat(status): initial status work
Browse files Browse the repository at this point in the history
  • Loading branch information
gadicc committed Apr 18, 2023
1 parent 7854649 commit d1cd39e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
21 changes: 21 additions & 0 deletions api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import json
from loadModel import loadModel
from send import send, getTimings, clearSession
from status import status
import os
import numpy as np
import skimage
Expand All @@ -25,6 +26,7 @@
from diffusers.models.cross_attention import CrossAttnProcessor
from utils import Storage
from hashlib import sha256
from threading import Timer


RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
Expand Down Expand Up @@ -155,6 +157,18 @@ async def inference(all_inputs: dict, response) -> dict:
if response:
send_opts.update({"response": response})

async def sendStatusAsync():
await response.send(json.dumps(status.get()) + "\n")

def sendStatus():
try:
asyncio.run(sendStatusAsync())
Timer(1.0, sendStatus).start()
except:
pass

Timer(1.0, sendStatus).start()

if model_inputs == None or call_inputs == None:
return {
"$error": {
Expand Down Expand Up @@ -448,6 +462,13 @@ def callback(step: int, timestep: int, latents: torch.FloatTensor):
)
)

else:

def callback(step: int, timestep: int, latents: torch.FloatTensor):
status.update(
"inference", step / model_inputs.get("num_inference_steps", 50)
)

with torch.inference_mode():
custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)

Expand Down
2 changes: 2 additions & 0 deletions api/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
from convert_to_diffusers import main as convert_to_diffusers
from download_checkpoint import main as download_checkpoint
from status import status

USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
Expand All @@ -20,6 +21,7 @@
MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)


# i.e. don't run during build
def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
if RUNTIME_DOWNLOADS:
Expand Down
6 changes: 6 additions & 0 deletions api/send.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import requests
import hashlib
from requests_futures.sessions import FuturesSession
from status import status as statusInstance

print()
environ = os.environ.copy()
Expand Down Expand Up @@ -92,6 +93,11 @@ async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
"payload": payload,
}

if status == "start":
statusInstance.update(type, 0.0)
elif status == "done":
statusInstance.update(type, 1.0)

if send_url and sign_key:
input = json.dumps(data, separators=(",", ":")) + sign_key
sig = hashlib.md5(input.encode("utf-8")).hexdigest()
Expand Down
14 changes: 14 additions & 0 deletions api/status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class Status:
    """Mutable holder for the current work phase and its progress.

    Holds a phase label (``type``) and a fractional completion value
    (``progress``), and exposes them as a JSON-serializable dict via
    :meth:`get`.
    """

    def __init__(self):
        # Begin in the "init" phase with nothing done yet.
        self.update("init", 0.0)

    def update(self, type, progress):
        """Record the current phase label and its progress (0.0–1.0)."""
        self.type = type
        self.progress = progress

    def get(self):
        """Return the current state as ``{"type": ..., "progress": ...}``."""
        return {"type": self.type, "progress": self.progress}


# Shared module-level instance imported by the rest of the API.
status = Status()

0 comments on commit d1cd39e

Please sign in to comment.