diff --git a/api/onnx_web/convert/__main__.py b/api/onnx_web/convert/__main__.py
index 437dc0394..7e3315c68 100644
--- a/api/onnx_web/convert/__main__.py
+++ b/api/onnx_web/convert/__main__.py
@@ -22,6 +22,7 @@
 from .upscaling.resrgan import convert_upscale_resrgan
 from .upscaling.swinir import convert_upscaling_swinir
 from .utils import (
+    DEFAULT_OPSET,
     ConversionContext,
     download_progress,
     remove_prefix,
@@ -572,7 +573,7 @@ def main(args=None) -> int:
     )
     parser.add_argument(
         "--opset",
-        default=14,
+        default=DEFAULT_OPSET,
         type=int,
         help="The version of the ONNX operator set to use.",
     )
diff --git a/api/onnx_web/convert/diffusion/diffusers.py b/api/onnx_web/convert/diffusion/diffusers.py
index a59b696ad..a56f21970 100644
--- a/api/onnx_web/convert/diffusion/diffusers.py
+++ b/api/onnx_web/convert/diffusion/diffusers.py
@@ -120,6 +120,7 @@ def convert_diffusion_diffusers_cnet(
 
     pipe_cnet = UNet2DConditionModel_CNet.from_pretrained(source, subfolder="unet")
     pipe_cnet = pipe_cnet.to(device=device, dtype=dtype)
+    run_gc()
 
     if is_torch_2_0:
         pipe_cnet.set_attn_processor(AttnProcessor())
diff --git a/api/onnx_web/convert/utils.py b/api/onnx_web/convert/utils.py
index 9d22fbad2..bd219d90f 100644
--- a/api/onnx_web/convert/utils.py
+++ b/api/onnx_web/convert/utils.py
@@ -17,6 +17,7 @@
 
 from ..constants import ONNX_WEIGHTS
 from ..server import ServerContext
+from ..utils import get_boolean
 
 logger = getLogger(__name__)
 
@@ -28,6 +29,8 @@
 ModelDict = Dict[str, Union[str, int]]
 LegacyModel = Tuple[str, str, Optional[bool], Optional[bool], Optional[int]]
 
+DEFAULT_OPSET = 14
+
 
 class ConversionContext(ServerContext):
     def __init__(
@@ -36,7 +39,7 @@ def __init__(
         cache_path: Optional[str] = None,
         device: Optional[str] = None,
         half: bool = False,
-        opset: Optional[int] = None,
+        opset: int = DEFAULT_OPSET,
         token: Optional[str] = None,
         prune: Optional[List[str]] = None,
         control: bool = True,
@@ -57,6 +60,12 @@ def __init__(
 
         self.map_location = torch.device(self.training_device)
 
+    @classmethod
+    def from_environ(cls):
+        context = super().from_environ()
+        context.control = get_boolean(environ, "ONNX_WEB_CONVERT_CONTROL", True)
+        context.opset = int(environ.get("ONNX_WEB_CONVERT_OPSET", DEFAULT_OPSET))
+
 
 def download_progress(urls: List[Tuple[str, str]]):
     for url, dest in urls:
diff --git a/api/onnx_web/server/hacks.py b/api/onnx_web/server/hacks.py
index 1c33953cb..b59bb73a3 100644
--- a/api/onnx_web/server/hacks.py
+++ b/api/onnx_web/server/hacks.py
@@ -4,8 +4,8 @@
 from os import path
 from urllib.parse import urlparse
 
-from .context import ServerContext
 from ..utils import run_gc
+from .context import ServerContext
 
 logger = getLogger(__name__)
 
@@ -135,6 +135,7 @@ def patch_cache_path(server: ServerContext, url: str, **kwargs) -> str:
 def apply_patch_basicsr(server: ServerContext):
     logger.debug("patching BasicSR module")
     import basicsr.utils.download_util
+
     basicsr.utils.download_util.download_file_from_google_drive = patch_not_impl
     basicsr.utils.download_util.load_file_from_url = partial(patch_cache_path, server)
 
@@ -142,6 +143,7 @@
 def apply_patch_codeformer(server: ServerContext):
     logger.debug("patching CodeFormer module")
     import codeformer.facelib.utils.misc
+
     codeformer.facelib.utils.misc.download_pretrained_models = patch_not_impl
     codeformer.facelib.utils.misc.load_file_from_url = partial(patch_cache_path, server)
 
@@ -149,6 +151,7 @@
 def apply_patch_facexlib(server: ServerContext):
     logger.debug("patching Facexlib module")
     import facexlib.utils
+
     facexlib.utils.load_file_from_url = partial(patch_cache_path, server)
 
 