feat(ui, worker): FLUX LoRAs in linear UI #6886

Merged · 9 commits · Sep 19, 2024
60 changes: 58 additions & 2 deletions invokeai/app/invocations/flux_lora_loader.py
@@ -1,14 +1,23 @@
-from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
+from typing import Optional
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    Classification,
+    invocation,
+    invocation_output,
+)
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.model_manager.config import BaseModelType


 @invocation_output("flux_lora_loader_output")
 class FluxLoRALoaderOutput(BaseInvocationOutput):
     """FLUX LoRA Loader Output"""

-    transformer: TransformerField = OutputField(
+    transformer: Optional[TransformerField] = OutputField(
         default=None, description=FieldDescriptions.transformer, title="FLUX Transformer"
     )

@@ -19,6 +28,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput):
     tags=["lora", "model", "flux"],
     category="model",
     version="1.0.0",
+    classification=Classification.Prototype,
 )
 class FluxLoRALoaderInvocation(BaseInvocation):
     """Apply a LoRA model to a FLUX transformer."""
@@ -51,3 +61,49 @@ def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
         )

         return FluxLoRALoaderOutput(transformer=transformer)
+
+
+@invocation(
+    "flux_lora_collection_loader",
+    title="FLUX LoRA Collection Loader",
+    tags=["lora", "model", "flux"],
+    category="model",
+    version="1.0.0",
+    classification=Classification.Prototype,
+)
+class FLUXLoRACollectionLoader(BaseInvocation):
+    """Applies a collection of LoRAs to a FLUX transformer."""
+
+    loras: LoRAField | list[LoRAField] = InputField(
+        description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
+    )
+
+    transformer: Optional[TransformerField] = InputField(
+        default=None,
+        description=FieldDescriptions.transformer,
+        input=Input.Connection,
+        title="Transformer",
+    )
+
+    def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput:
+        output = FluxLoRALoaderOutput()
+        loras = self.loras if isinstance(self.loras, list) else [self.loras]
+        added_loras: list[str] = []
+
+        for lora in loras:
+            if lora.lora.key in added_loras:
+                continue
+
+            if not context.models.exists(lora.lora.key):
+                raise Exception(f"Unknown lora: {lora.lora.key}!")
+
+            assert lora.lora.base is BaseModelType.Flux
+
+            added_loras.append(lora.lora.key)
+
+            if self.transformer is not None:
+                if output.transformer is None:
+                    output.transformer = self.transformer.model_copy(deep=True)
+                output.transformer.loras.append(lora)
+
+        return output
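
For readers tracing the collection loader's logic: it deduplicates LoRAs by model key and only deep-copies the incoming transformer field once, on the first LoRA applied. Below is a minimal, self-contained sketch of that pattern; Lora and Transformer are hypothetical stand-ins for InvokeAI's LoRAField and TransformerField, not the real models.

from copy import deepcopy
from dataclasses import dataclass, field


@dataclass
class Lora:  # hypothetical stand-in for LoRAField
    key: str
    weight: float


@dataclass
class Transformer:  # hypothetical stand-in for TransformerField
    loras: list[Lora] = field(default_factory=list)


def apply_lora_collection(transformer: Transformer, loras: list[Lora]) -> Transformer:
    """Mirror of the loader's dedup-and-append loop."""
    output: Transformer | None = None
    seen: set[str] = set()
    for lora in loras:
        if lora.key in seen:
            continue  # same model twice: keep the first, like the added_loras check
        seen.add(lora.key)
        if output is None:
            output = deepcopy(transformer)  # copy once, like model_copy(deep=True)
        output.loras.append(lora)
    return output if output is not None else transformer


loras = [Lora("alpha", 0.75), Lora("beta", 0.5), Lora("alpha", 1.0)]
result = apply_lora_collection(Transformer(), loras)
assert [l.key for l in result.loras] == ["alpha", "beta"]  # duplicate "alpha" skipped
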
61 changes: 61 additions & 0 deletions invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts
@@ -0,0 +1,61 @@
+import type { RootState } from 'app/store/store';
+import { getPrefixedId } from 'features/controlLayers/konva/util';
+import { zModelIdentifierField } from 'features/nodes/types/common';
+import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import type { Invocation, S } from 'services/api/types';
+
+export const addFLUXLoRAs = (
+  state: RootState,
+  g: Graph,
+  denoise: Invocation<'flux_denoise'>,
+  modelLoader: Invocation<'flux_model_loader'>
+): void => {
+  const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux');
+  const loraCount = enabledLoRAs.length;
+
+  if (loraCount === 0) {
+    return;
+  }
+
+  const loraMetadata: S['LoRAMetadataField'][] = [];
+
+  // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies
+  // each LoRA to the FLUX transformer.
+  const loraCollector = g.addNode({
+    id: getPrefixedId('lora_collector'),
+    type: 'collect',
+  });
+  const loraCollectionLoader = g.addNode({
+    type: 'flux_lora_collection_loader',
+    id: getPrefixedId('flux_lora_collection_loader'),
+  });
+
+  g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras');
+  // Use model loader as transformer input
+  g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer');
+  // Reroute transformer connections through the LoRA collection loader
+  g.deleteEdgesTo(denoise, ['transformer']);
+
+  g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer');
+
+  for (const lora of enabledLoRAs) {
+    const { weight } = lora;
+    const parsedModel = zModelIdentifierField.parse(lora.model);
+
+    const loraSelector = g.addNode({
+      type: 'lora_selector',
+      id: getPrefixedId('lora_selector'),
+      lora: parsedModel,
+      weight,
+    });
+
+    loraMetadata.push({
+      model: parsedModel,
+      weight,
+    });
+
+    g.addEdge(loraSelector, 'lora', loraCollector, 'item');
+  }
+
+  g.upsertMetadata({ loras: loraMetadata });
+};
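
The builder above creates a fan-in topology: one lora_selector node per enabled LoRA feeds a shared collect node, whose collected output drives the flux_lora_collection_loader, which is spliced between the model loader and the denoise node. A rough sketch of the resulting fragment as plain data follows; node IDs, model keys, and field shapes are illustrative placeholders (real IDs come from getPrefixedId), not the exact payload the Graph helper emits.

# Illustrative only: simplified node/edge shapes, placeholder IDs and keys.
graph_fragment = {
    "nodes": {
        "lora_selector:1": {
            "type": "lora_selector",
            "lora": {"key": "some-lora-key", "base": "flux", "type": "lora"},
            "weight": 0.75,
        },
        "lora_collector:1": {"type": "collect"},
        "flux_lora_collection_loader:1": {"type": "flux_lora_collection_loader"},
    },
    # (source node, source field, destination node, destination field)
    "edges": [
        ("lora_selector:1", "lora", "lora_collector:1", "item"),
        ("lora_collector:1", "collection", "flux_lora_collection_loader:1", "loras"),
        ("flux_model_loader:1", "transformer", "flux_lora_collection_loader:1", "transformer"),
        ("flux_lora_collection_loader:1", "transformer", "flux_denoise:1", "transformer"),
    ],
}
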
4 changes: 4 additions & 0 deletions invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
@@ -18,6 +18,8 @@ import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';

+import { addFLUXLoRAs } from './addFLUXLoRAs';
+
 const log = logger('system');

 export const buildFLUXGraph = async (
@@ -84,6 +86,8 @@ export const buildFLUXGraph = async (
   g.addEdge(modelLoader, 'transformer', noise, 'transformer');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');

+  addFLUXLoRAs(state, g, noise, modelLoader);
+
   g.addEdge(modelLoader, 'clip', posCond, 'clip');
   g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
   g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
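
Note the ordering in buildFLUXGraph: the direct modelLoader-to-denoise transformer edge is added first, and addFLUXLoRAs then deletes it and splices the collection loader in between. A minimal sketch of that splice follows, using a hypothetical tuple-based edge representation rather than the frontend's Graph class.

# Hypothetical edge model: (source_node, source_field, dest_node, dest_field).
Edge = tuple[str, str, str, str]


def splice_loader(edges: list[Edge], model: str, loader: str, denoise: str) -> list[Edge]:
    # Feed the original transformer into the loader...
    spliced = edges + [(model, "transformer", loader, "transformer")]
    # ...drop whatever currently feeds denoise.transformer (deleteEdgesTo equivalent)...
    spliced = [e for e in spliced if not (e[2] == denoise and e[3] == "transformer")]
    # ...and route the loader's patched transformer into denoise instead.
    spliced.append((loader, "transformer", denoise, "transformer"))
    return spliced


edges: list[Edge] = [("flux_model_loader", "transformer", "flux_denoise", "transformer")]
edges = splice_loader(edges, "flux_model_loader", "flux_lora_collection_loader", "flux_denoise")
assert ("flux_model_loader", "transformer", "flux_denoise", "transformer") not in edges
assert ("flux_lora_collection_loader", "transformer", "flux_denoise", "transformer") in edges
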