From 0d4c0a5942a150cbb407a621f4a6c85eef1f41b6 Mon Sep 17 00:00:00 2001 From: Sean Sube Date: Sun, 8 Jan 2023 12:46:52 -0600 Subject: [PATCH] fix(api): remove prompt from output name --- api/onnx_web/serve.py | 18 +++++++++--------- api/requirements.txt | 3 +-- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/api/onnx_web/serve.py b/api/onnx_web/serve.py index 656d2461f..d85aacdf0 100644 --- a/api/onnx_web/serve.py +++ b/api/onnx_web/serve.py @@ -20,9 +20,9 @@ from hashlib import sha256 from io import BytesIO from PIL import Image -from stringcase import spinalcase from struct import pack from os import environ, makedirs, path, scandir +from typing import Tuple, Union import numpy as np # defaults @@ -74,15 +74,15 @@ } -def get_and_clamp_float(args, key, default_value, max_value, min_value=0.0): +def get_and_clamp_float(args, key: str, default_value: float, max_value: float, min_value=0.0): return min(max(float(args.get(key, default_value)), min_value), max_value) -def get_and_clamp_int(args, key, default_value, max_value, min_value=1): +def get_and_clamp_int(args, key: str, default_value: int, max_value: int, min_value=1): return min(max(int(args.get(key, default_value)), min_value), max_value) -def get_from_map(args, key, values, default): +def get_from_map(args, key: str, values, default): selected = args.get(key, default) if selected in values: return values[selected] @@ -90,7 +90,7 @@ def get_from_map(args, key, values, default): return values[default] -def get_model_path(model): +def get_model_path(model: str): return safer_join(model_path, model) @@ -104,7 +104,7 @@ def get_latents_from_seed(seed: int, width: int, height: int) -> np.ndarray: return image_latents -def load_pipeline(pipeline, model, provider, scheduler): +def load_pipeline(pipeline, model: str, provider: str, scheduler): global last_pipeline_instance global last_pipeline_scheduler global last_pipeline_options @@ -141,9 +141,9 @@ def json_with_cors(data, origin='*'): return res -def make_output_path(type, params): +def make_output_path(type: str, params: Tuple[Union[str, int, float]]): sha = sha256() - sha.update(type) + sha.update(type.encode('utf-8')) for param in params: if isinstance(param, str): sha.update(param.encode('utf-8')) @@ -154,7 +154,7 @@ def make_output_path(type, params): else: print('cannot hash param: %s, %s' % (param, type(param))) - output_file = 'txt2img_%s_%s.png' % (params[0], sha.hexdigest()) + output_file = '%s_%s.png' % (type, sha.hexdigest()) output_full = safer_join(output_path, output_file) return (output_file, output_full) diff --git a/api/requirements.txt b/api/requirements.txt index f55f073f1..ea303f264 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -8,5 +8,4 @@ protobuf<4,>=3.20.2 transformers ### Server packages ### -flask -stringcase \ No newline at end of file +flask \ No newline at end of file