Skip to content

Commit

Permalink
split commandline args into its own file
Browse files Browse the repository at this point in the history
make launch.py use the same command line argument parser as the main program
  • Loading branch information
AUTOMATIC1111 committed Mar 25, 2023
1 parent 3ec7e19 commit 8c80136
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 833 deletions.
77 changes: 23 additions & 54 deletions launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,27 @@
import importlib.util
import shlex
import platform
import argparse
import json

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ui-settings-file", type=str, default='config.json')
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
args, _ = parser.parse_known_args(sys.argv)
from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir

script_path = os.path.dirname(__file__)
data_path = args.data_dir
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)

args, _ = cmd_args.parser.parse_known_args()

dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False
dir_repos = "repositories"

if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


def check_python_version():
is_windows = platform.system() == "Windows"
major = sys.version_info.major
Expand Down Expand Up @@ -72,23 +71,6 @@ def commit_hash():
return stored_commit_hash


def extract_arg(args, name):
return [x for x in args if x != name], name in args


def extract_opt(args, name):
opt = None
is_present = False
if name in args:
is_present = True
idx = args.index(name)
del args[idx]
if idx < len(args) and args[idx][0] != "-":
opt = args[idx]
del args[idx]
return args, is_present, opt


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
Expand Down Expand Up @@ -225,23 +207,22 @@ def list_extensions(settings_file):

disabled_extensions = set(settings.get('disabled_extensions', []))

return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]


def run_extensions_installers(settings_file):
if not os.path.isdir(dir_extensions):
if not os.path.isdir(extensions_dir):
return

for dirname_extension in list_extensions(settings_file):
run_extension_installer(os.path.join(data_path, dir_extensions, dirname_extension))
run_extension_installer(os.path.join(extensions_dir, dirname_extension))


def prepare_environment():
global skip_install

torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
Expand All @@ -260,32 +241,18 @@ def prepare_environment():
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

sys.argv += shlex.split(commandline_args)

sys.argv, _ = extract_arg(sys.argv, '-f')
sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')

This comment has been minimized.

Copy link
@micky2be

micky2be Mar 30, 2023

Contributor

Origin of the bug reported in #8935

xformers = '--xformers' in sys.argv
ngrok = '--ngrok' in sys.argv

if not skip_python_version_check:
if not args.skip_python_version_check:
check_python_version()

commit = commit_hash()

print(f"Python {sys.version}")
print(f"Commit hash: {commit}")

if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

if not skip_torch_cuda_test:
if not args.skip_torch_cuda_test:
run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

if not is_installed("gfpgan"):
Expand All @@ -297,7 +264,7 @@ def prepare_environment():
if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip")

if (not is_installed("xformers") or reinstall_xformers) and xformers:
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows":
if platform.python_version().startswith("3.10"):
run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
Expand All @@ -309,7 +276,7 @@ def prepare_environment():
elif platform.system() == "Linux":
run_pip(f"install {xformers_package}", "xformers")

if not is_installed("pyngrok") and ngrok:
if not is_installed("pyngrok") and args.ngrok:
run_pip("install pyngrok", "ngrok")

os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
Expand All @@ -329,18 +296,18 @@ def prepare_environment():

run_extensions_installers(settings_file=args.ui_settings_file)

if update_check:
if args.update_check:
version_check(commit)

if update_all_extensions:
git_pull_recursive(os.path.join(data_path, dir_extensions))
if args.update_all_extensions:
git_pull_recursive(extensions_dir)

if "--exit" in sys.argv:
print("Exiting because of --exit argument")
exit(0)

if run_tests:
exitcode = tests(test_dir)
if args.tests and not args.no_tests:
exitcode = tests(args.tests)
exit(exitcode)


Expand All @@ -354,6 +321,8 @@ def tests(test_dir):
sys.argv.append("--skip-torch-cuda-test")
if "--disable-nan-check" not in sys.argv:
sys.argv.append("--disable-nan-check")
if "--no-tests" not in sys.argv:
sys.argv.append("--no-tests")

print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")

Expand Down
Loading

0 comments on commit 8c80136

Please sign in to comment.