Skip to content

Commit

Permalink
Rework installer and requirement checks
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed May 24, 2023
1 parent b88e115 commit f7c31b0
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 184 deletions.
178 changes: 3 additions & 175 deletions install.py
Original file line number Diff line number Diff line change
@@ -1,182 +1,10 @@
from __future__ import annotations

import importlib
import logging
import sys
from pathlib import Path

logger = logging.getLogger(__name__)


def is_empty_line(line):
return line is None or line.strip() == "" or line.strip().startswith("#")


def get_requirements() -> list[str]:
try:
requirements_file_path = Path(__file__).parent / "requirements.txt"

if not requirements_file_path.exists():
raise FileNotFoundError("requirements.txt not found.")

with requirements_file_path.open() as file:
requirements = [
line.strip() # Stripping whitespace at start/end
for line in file.readlines()
if not is_empty_line(line)
]

except Exception as e:
raise Exception("Failed to parse requirements file.") from e

return requirements


def split_package(requirement: str) -> tuple[str, str | None, str | None, str | None]:
"""
Split a requirement string into package name, extras, comparison operator and version.
:param requirement: Requirement string. E.g., "package[extra]>=1.0.0"
:return: tuple of (package name, extras, comparison operator, version)
"""
delimiters = ["==", ">=", "<=", ">", "<", "~=", "!="]

package = requirement
extras = None
delimiter = None
version = None

for delimiter in delimiters:
if delimiter in requirement:
splits = requirement.split(delimiter)
package_name_and_extras, version = splits[0], splits[1]

# Check for extras
if "[" in package_name_and_extras:
package, extras = map(str.strip, package_name_and_extras.split("["))
extras = extras.rstrip("]")
else:
package = package_name_and_extras.strip()

break

return package, extras, delimiter, version


def get_dynamic_prompts_version() -> str | None:
"""
Get the version of dynamicprompts from the requirements.
:return: Version of dynamicprompts if found, else None.
"""
requirements = get_requirements()

dynamicprompts_requirement = next(
(r for r in requirements if r.startswith("dynamicprompts")),
None,
)

if dynamicprompts_requirement is None:
return None

_, _, _, dynamicprompts_requirement_version = split_package(
dynamicprompts_requirement,
)

return dynamicprompts_requirement_version


def check_versions():
"""Deprecated. Use check_and_install_dependencies() instead."""
return check_and_install_dependencies()


def check_and_install_dependencies():
import launch # from AUTOMATIC1111

requirements = get_requirements()

for requirement in requirements:
try:
package, _, delimiter, package_version = split_package(requirement)

if not launch.is_installed(package):
logger.info(f"Installing {package}=={package_version}...")
launch.run_pip(f"install {requirement}", f"{requirement}")
else:
module = importlib.import_module(".", package)
version = getattr(module, "__version__", None)
# handle the case where the dependency version is pinned or when no version is specified
if delimiter == "==" or version is None:
if version is not None and version != package_version:
logger.info(
f"Found {package}=={version} but expected {package_version}. Trying to update...",
)
launch.run_pip(
f"install --upgrade {requirement}",
f"{requirement}",
)
else:
# more general handling of version comparison operators will be handled in the future
pass

except Exception as e:
logger.error(f"Failed to check/update package {package}: {str(e)}")


def get_update_command() -> str | None:
"""
Get the update command for dynamicprompts.
:return: Update command for dynamicprompts if found, else None.
"""
requirements = get_requirements()

# Find the requirement line for dynamicprompts
dynamicprompts_requirement = next(
(r for r in requirements if r.startswith("dynamicprompts")),
None,
)

# If dynamicprompts requirement was not found
if dynamicprompts_requirement is None:
return None

# If found, return the pip install command
return f"{sys.executable} -m pip install '{dynamicprompts_requirement}'"


def check_correct_dynamicprompts_installed() -> bool:
"""
Check if the installed version of dynamicprompts matches the required version.
:return: True if versions match, else False.
"""
try:
import dynamicprompts
except ImportError:
logger.error("dynamicprompts module is not installed.")
return False
except Exception as e:
logger.exception("Unexpected error while importing dynamicprompts.", e)
return False

dynamicprompts_requirement_version = get_dynamic_prompts_version()

if dynamicprompts_requirement_version is None:
logger.warning("Unable to find dynamicprompts version requirement.")
return False

if dynamicprompts.__version__ != dynamicprompts_requirement_version:
update_command = get_update_command()
logger.warning(
f"Installed dynamicprompts version ({dynamicprompts.__version__}) does not match the required version ({dynamicprompts_requirement_version}). "
f"Please update manually by running: {update_command}",
)
return False

return True


if __name__ == "__main__":
check_and_install_dependencies()
from sd_dynamic_prompts.version_tools import install_requirements

install_requirements()
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
[project]
dependencies = [
"send2trash~=1.8",
"dynamicprompts[attentiongrabber,magicprompt]~=0.25.2",
]

[tool.pytest.ini_options]
minversion = "7.0"
pythonpath = [
Expand All @@ -19,7 +25,6 @@ select = [
ignore = [
"C901", # Complexity
"E501", # Line length
"B905",
]
unfixable = [
"B007", # Loop control variable not used within the loop body
Expand Down
2 changes: 0 additions & 2 deletions requirements.txt

This file was deleted.

28 changes: 22 additions & 6 deletions sd_dynamic_prompts/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
import math
from functools import lru_cache
from pathlib import Path
from string import Template

Expand All @@ -16,7 +17,6 @@
from modules.processing import fix_seed
from modules.shared import opts

from install import check_correct_dynamicprompts_installed, get_update_command
from sd_dynamic_prompts import __version__, callbacks
from sd_dynamic_prompts.element_ids import make_element_id
from sd_dynamic_prompts.generator_builder import GeneratorBuilder
Expand All @@ -40,7 +40,6 @@
if is_debug:
logger.setLevel(logging.DEBUG)

check_correct_dynamicprompts_installed()
base_dir = Path(scripts.basedir())
magicprompt_models_path = get_magicmodels_path(base_dir)

Expand Down Expand Up @@ -76,6 +75,19 @@ def get_prompts(p):
loaded_count = 0


@lru_cache(maxsize=1)
def _get_install_error_message() -> str | None:
try:
from sd_dynamic_prompts.version_tools import get_dynamicprompts_install_result

get_dynamicprompts_install_result().raise_if_incorrect()
except RuntimeError as rte:
return str(rte)
except Exception:
logger.exception("Failed to get dynamicprompts install result")
return None


class Script(scripts.Script):
def __init__(self):
global loaded_count
Expand Down Expand Up @@ -108,8 +120,8 @@ def show(self, is_img2img):
return scripts.AlwaysVisible

def ui(self, is_img2img):
correct_lib_version = check_correct_dynamicprompts_installed()
update_command = get_update_command()
install_message = _get_install_error_message()
correct_lib_version = bool(not install_message)

html_path = base_dir / "helptext.html"
html = html_path.open().read()
Expand All @@ -123,7 +135,10 @@ def ui(self, is_img2img):
jinja_help = jinja_html_path.open().read()

with gr.Group(elem_id=make_element_id("dynamic-prompting")):
with gr.Accordion("Dynamic Prompts", open=False):
title = "Dynamic Prompts"
if not correct_lib_version:
title += " [incorrect installation]"
with gr.Accordion(title, open=False):
is_enabled = gr.Checkbox(
label="Dynamic Prompts enabled",
value=correct_lib_version,
Expand All @@ -133,7 +148,8 @@ def ui(self, is_img2img):

if not correct_lib_version:
gr.HTML(
f"""<span class="warning sddp-warning">Dynamic Prompts is not installed correctly</span>. Please reinstall the dynamic prompts library by running the following command: <span class="sddp-info">{update_command}</span>""",
f"""<span class="warning sddp-warning">Dynamic Prompts is not installed correctly</span>.
{install_message}""",
)

with gr.Group(visible=correct_lib_version):
Expand Down
96 changes: 96 additions & 0 deletions sd_dynamic_prompts/version_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# NB: this file may not import anything from `sd_dynamic_prompts` because it is used by `install.py`.

from __future__ import annotations

import dataclasses
import logging
import shlex
import subprocess
import sys
from functools import lru_cache
from pathlib import Path

import tomli
from packaging.requirements import Requirement

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class InstallResult:
requirement: str
specifier: str
installed: str
correct: bool

@property
def message(self) -> str | None:
if self.correct:
return None
return (
f"You have dynamicprompts {self.installed} installed, "
f"but this extension requires {self.specifier}. "
f"Please run `install.py` from the sd-dynamic-prompts extension directory, "
f"or `{self.pip_install_command}`."
)

@property
def pip_install_command(self) -> str:
return f"pip install {self.requirement}"

def raise_if_incorrect(self) -> None:
message = self.message
if message:
raise RuntimeError(message)


@lru_cache(maxsize=1)
def get_requirements() -> tuple[str]:
toml_text = (Path(__file__).parent.parent / "pyproject.toml").read_text()
return tuple(tomli.loads(toml_text)["project"]["dependencies"])


def get_dynamic_prompts_requirement() -> Requirement | None:
for req in get_requirements():
if req.startswith("dynamicprompts"):
return Requirement(req)
return None


def get_dynamicprompts_install_result() -> InstallResult:
import dynamicprompts

dp_req = get_dynamic_prompts_requirement()
if not dp_req:
raise RuntimeError("dynamicprompts requirement not found")
return InstallResult(
requirement=str(dp_req),
specifier=str(dp_req.specifier),
installed=dynamicprompts.__version__,
correct=(dynamicprompts.__version__ in dp_req.specifier),
)


def install_requirements() -> None:
"""
Invoke pip to install the requirements for the extension.
"""
command = [
sys.executable,
"-m",
"pip",
"install",
*get_requirements(),
]
print(f"sd-dynamic-prompts installer: running {shlex.join(command)}")
subprocess.check_call(command)


def selftest() -> None:
res = get_dynamicprompts_install_result()
print(res)
res.raise_if_incorrect()


if __name__ == "__main__":
selftest()

0 comments on commit f7c31b0

Please sign in to comment.