Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Support for Flux #325

Merged
merged 8 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ hordelib/model_database/stable_diffusion.json
hordelib/model_database/lora.json
ComfyUI
model.ckpt
models/
coverage.lcov
profiles/
longprompts.zip
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ repos:
types-tabulate,
types-tqdm,
types-urllib3,
horde_sdk==0.14.0,
horde_model_reference==0.8.1,
horde_sdk==0.14.3,
horde_model_reference==0.9.0,
]
30 changes: 24 additions & 6 deletions hordelib/horde.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,12 @@ class HordeLib:
"upscale_sampler.sampler_name": "sampler_name",
"controlnet_apply.strength": "control_strength",
"controlnet_model_loader.control_net_name": "control_type",
# Flux
"cfg_guider.cfg": "cfg_scale",
"random_noise.noise_seed": "seed",
"k_sampler_select.sampler_name": "sampler_name",
"basic_scheduler.denoise": "denoising_strength",
"basic_scheduler.steps": "ddim_steps",
# Stable Cascade
"stable_cascade_empty_latent_image.width": "width",
"stable_cascade_empty_latent_image.height": "height",
Expand Down Expand Up @@ -856,10 +862,15 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
)

# The last LORA always connects to the sampler and clip text encoders (via the clip_skip)
if lora_index == len(payload.get("loras")) - 1:
self.generator.reconnect_input(pipeline_data, "sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "upscale_sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "clip_skip.clip", f"lora_{lora_index}")
if lora_index == len(payload.get("loras")) - 1 and SharedModelManager.manager.compvis:
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
if model_details is not None and model_details["baseline"] == "flux_1":
self.generator.reconnect_input(pipeline_data, "cfg_guider.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "basic_scheduler.model", f"lora_{lora_index}")
else:
self.generator.reconnect_input(pipeline_data, "sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "upscale_sampler.model", f"lora_{lora_index}")
self.generator.reconnect_input(pipeline_data, "clip_skip.clip", f"lora_{lora_index}")

# Translate the payload parameters into pipeline parameters
pipeline_params = {}
Expand All @@ -885,7 +896,7 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis

# We inject these parameters to ensure the HordeCheckpointLoader knows what file to load, if necessary
# We don't want to hardcode this into the pipeline.json as we export this directly from ComfyUI
# and don't want to have to rememebr to re-add those keys
# and don't want to have to rememeber to re-add those keys
if "model_loader_stage_c.ckpt_name" in pipeline_params:
pipeline_params["model_loader_stage_c.file_type"] = "stable_cascade_stage_c"
if "model_loader_stage_b.ckpt_name" in pipeline_params:
Expand Down Expand Up @@ -990,7 +1001,12 @@ def _final_pipeline_adjustments(self, payload, pipeline_data) -> tuple[dict, lis
# We do this by reconnecting the nodes in the pipeline to make the input to the vae encoder
# the source image instead of the latent noise generator
if pipeline_params.get("image_loader.image"):
self.generator.reconnect_input(pipeline_data, "sampler.latent_image", "vae_encode")
if SharedModelManager.manager.compvis:
model_details = SharedModelManager.manager.compvis.get_model_reference_info(payload["model_name"])
if isinstance(model_details, dict) and model_details.get("baseline") == "flux_1":
self.generator.reconnect_input(pipeline_data, "sampler_custom_advanced.latent_image", "vae_encode")
else:
self.generator.reconnect_input(pipeline_data, "sampler.latent_image", "vae_encode")
if pipeline_params.get("sc_image_loader.image"):
self.generator.reconnect_input(
pipeline_data,
Expand Down Expand Up @@ -1181,6 +1197,8 @@ def _get_appropriate_pipeline(self, params):
if params.get("hires_fix", False):
return "stable_cascade_2pass"
return "stable_cascade"
if model_details.get("baseline") == "flux_1":
return "flux"
if params.get("control_type"):
if params.get("return_control_map", False):
return "controlnet_annotator"
Expand Down
6 changes: 3 additions & 3 deletions hordelib/model_manager/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,13 @@ def _parse_civitai_lora_data(self, item, adhoc=False):
logger.debug(f"Rejecting LoRa {lora.get('name')} because it doesn't have a url")
return None
# We don't want to start downloading GBs of a single LoRa.
# We just ignore anything over 150Mb. Them's the breaks...
# We just ignore anything over 400Mb. Them's the breaks...
if (
lora["versions"][lora_version]["adhoc"]
and lora["versions"][lora_version]["size_mb"] > 220
and lora["versions"][lora_version]["size_mb"] > 400
and lora["id"] not in self._default_lora_ids
):
logger.debug(f"Rejecting LoRa {lora.get('name')} version {lora_version} because its size is over 220Mb.")
logger.debug(f"Rejecting LoRa {lora.get('name')} version {lora_version} because its size is over 400Mb.")
return None
if lora["versions"][lora_version]["adhoc"] and lora["nsfw"] and not self.nsfw:
logger.debug(f"Rejecting LoRa {lora.get('name')} because worker is SFW.")
Expand Down
Loading
Loading