Skip to content

Commit

Permalink
feat(api): add env vars for controlnet conversion and opset
Browse files Browse the repository at this point in the history
  • Loading branch information
ssube committed May 15, 2023
1 parent a993886 commit 0d51d61
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 3 deletions.
3 changes: 2 additions & 1 deletion api/onnx_web/convert/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .upscaling.resrgan import convert_upscale_resrgan
from .upscaling.swinir import convert_upscaling_swinir
from .utils import (
DEFAULT_OPSET,
ConversionContext,
download_progress,
remove_prefix,
Expand Down Expand Up @@ -572,7 +573,7 @@ def main(args=None) -> int:
)
parser.add_argument(
"--opset",
default=14,
default=DEFAULT_OPSET,
type=int,
help="The version of the ONNX operator set to use.",
)
Expand Down
1 change: 1 addition & 0 deletions api/onnx_web/convert/diffusion/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def convert_diffusion_diffusers_cnet(
pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")

pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
run_gc()

if is_torch_2_0:
pipe_cnet.set_attn_processor(AttnProcessor())
Expand Down
11 changes: 10 additions & 1 deletion api/onnx_web/convert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..constants import ONNX_WEIGHTS
from ..server import ServerContext
from ..utils import get_boolean

logger = getLogger(__name__)

Expand All @@ -28,6 +29,8 @@
ModelDict = Dict[str, Union[str, int]]
LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]

DEFAULT_OPSET = 14


class ConversionContext(ServerContext):
def __init__(
Expand All @@ -36,7 +39,7 @@ def __init__(
cache_path: Optional[str] = None,
device: Optional[str] = None,
half: bool = False,
opset: Optional[int] = None,
opset: int = DEFAULT_OPSET,
token: Optional[str] = None,
prune: Optional[List[str]] = None,
control: bool = True,
Expand All @@ -57,6 +60,12 @@ def __init__(

self.map_location = torch.device(self.training_device)

@classmethod
def from_environ(cls):
context = super().from_environ()
context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))


def download_progress(urls: List[Tuple[str, str]]):
for url, dest in urls:
Expand Down
5 changes: 4 additions & 1 deletion api/onnx_web/server/hacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from os import path
from urllib.parse import urlparse

from .context import ServerContext
from ..utils import run_gc
from .context import ServerContext

logger = getLogger(__name__)

Expand Down Expand Up @@ -135,20 +135,23 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
def apply_patch_basicsr(server: ServerContext):
logger.debug("patching BasicSR module")
import basicsr.utils.download_util

basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)


def apply_patch_codeformer(server: ServerContext):
logger.debug("patching CodeFormer module")
import codeformer.facelib.utils.misc

codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)


def apply_patch_facexlib(server: ServerContext):
logger.debug("patching Facexlib module")
import facexlib.utils

facexlib.utils.load_file_from_url = partial(patch_cache_path, server)


Expand Down

0 comments on commit 0d51d61

Please sign in to comment.