Skip to content

Commit

Permalink
lint(api): remove global paths entirely
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed Jan 16, 2023
1 parent 4615614 commit c98c0ff
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 41 deletions.
69 changes: 33 additions & 36 deletions api/onnx_web/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flask_executor import Executor
from io import BytesIO
from PIL import Image
from os import environ, makedirs, path, scandir
from os import makedirs, path, scandir
from typing import Tuple

from .image import (
Expand Down Expand Up @@ -57,17 +57,6 @@
import json
import numpy as np

# paths
bundle_path = environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out'))
model_path = environ.get('ONNX_WEB_MODEL_PATH', path.join('..', 'models'))
output_path = environ.get('ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs'))
params_path = environ.get('ONNX_WEB_PARAMS_PATH', 'params.json')

# options
cors_origin = environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(',')
num_workers = int(environ.get('ONNX_WEB_NUM_WORKERS', 1))

# pipeline caching
available_models = []
config_params = {}
Expand Down Expand Up @@ -109,14 +98,12 @@
# TODO: load from model_path
upscale_models = [
'RealESRGAN_x4plus',
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth', # TODO: convert GFPGAN
'GFPGANv1.3',
# TODO: convert GFPGAN
# 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
]


def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', bundle_path), filename)


def url_from_rule(rule) -> str:
options = {}
for arg in rule.arguments:
Expand All @@ -125,10 +112,6 @@ def url_from_rule(rule) -> str:
return url_for(rule.endpoint, **options)


def get_model_path(model: str):
return safer_join(model_path, model)


def pipeline_from_request() -> Tuple[BaseParams, Size]:
user = request.remote_addr

Expand Down Expand Up @@ -210,37 +193,51 @@ def upscale_from_request() -> UpscaleParams:
denoise=denoise,
)

def check_paths():
if not path.exists(model_path):

def check_paths(context: ServerContext):
if not path.exists(context.model_path):
raise RuntimeError('model path must exist')

if not path.exists(output_path):
makedirs(output_path)
if not path.exists(context.output_path):
makedirs(context.output_path)


def load_models():
def load_models(context: ServerContext):
global available_models
available_models = [f.name for f in scandir(model_path) if f.is_dir()]
available_models = [f.name for f in scandir(
context.model_path) if f.is_dir()]


def load_params():
def load_params(context: ServerContext):
global config_params
with open(params_path) as f:
params_file = path.join(context.params_path, 'params.json')
with open(params_file) as f:
config_params = json.load(f)


check_paths()
load_models()
load_params()
context = ServerContext()

check_paths(context)
load_models(context)
load_params(context)

app = Flask(__name__)
app.config['EXECUTOR_MAX_WORKERS'] = num_workers
app.config['EXECUTOR_MAX_WORKERS'] = context.num_workers
app.config['EXECUTOR_PROPAGATE_EXCEPTIONS'] = True

CORS(app, origins=cors_origin)
CORS(app, origins=context.cors_origin)
executor = Executor(app)

context = ServerContext(bundle_path, model_path, output_path, params_path)

# TODO: these two use context

def get_model_path(model: str):
return safer_join(context.model_path, model)


def serve_bundle_file(filename='index.html'):
return send_from_directory(path.join('..', context.bundle_path), filename)


# routes

Expand Down Expand Up @@ -418,4 +415,4 @@ def ready():

@app.route('/api/output/<path:filename>')
def output(filename: str):
return send_from_directory(path.join('..', output_path), filename, as_attachment=False)
return send_from_directory(path.join('..', context.output_path), filename, as_attachment=False)
29 changes: 24 additions & 5 deletions api/onnx_web/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from os import path
from os import environ, path
from time import time
from struct import pack
from typing import Any, Dict, Tuple, Union
Expand Down Expand Up @@ -54,15 +54,34 @@ def __init__(self, left: int, right: int, top: int, bottom: int) -> None:
class ServerContext:
def __init__(
self,
bundle_path: str,
model_path: str,
output_path: str,
params_path: str
bundle_path: str = '.',
model_path: str = '.',
output_path: str = '.',
params_path: str = '.',
cors_origin: str = '*',
num_workers: int = 1,
) -> None:
self.bundle_path = bundle_path
self.model_path = model_path
self.output_path = output_path
self.params_path = params_path
self.cors_origin = cors_origin
self.num_workers = num_workers

@classmethod
def from_environ():
return ServerContext(
bundle_path=environ.get('ONNX_WEB_BUNDLE_PATH',
path.join('..', 'gui', 'out')),
model_path=environ.get('ONNX_WEB_MODEL_PATH',
path.join('..', 'models')),
output_path=environ.get(
'ONNX_WEB_OUTPUT_PATH', path.join('..', 'outputs')),
params_path=environ.get('ONNX_WEB_PARAMS_PATH', '.'),
# others
cors_origin=environ.get('ONNX_WEB_CORS_ORIGIN', '*').split(','),
num_workers=int(environ.get('ONNX_WEB_NUM_WORKERS', 1)),
)


class Size:
Expand Down

0 comments on commit c98c0ff

Please sign in to comment.