LoRA/LyCORIS loading for SD and SDXL in the ONNX pipeline #1321

ssube commented Aug 29, 2023

Feature request

I have some code for loading LoRA (and LyCORIS, LoHA, etc.) weights into ONNX models at runtime, without converting them to or from PyTorch again or writing anything to disk. It doesn't look like that is currently supported, so I would like to contribute the feature if possible, but I'm not sure of the best way to go about it.

This code is working and has been tested with many models and LoRA files from Civitai. A few of the more exotic LyCORIS variants are not fully supported yet, but LoHA seems to work pretty well. I have SDXL support on a branch that is almost ready to merge, which is why I am opening this issue now.

The LoRA safetensors do not need to be converted to ONNX ahead of time. Support is limited to LoRA weights trained by the kohya-ss scripts or some descendant of them; the cloneofsimo scripts remove the node names from the LoRA, which makes it very difficult to match the weights up with the correct ONNX nodes.

Motivation

This is already supported for the other, non-ONNX pipelines, and it would be great to have as part of the ONNX pipeline as well. diffusers has a LoRA loader mixin, which seems like the right way to structure this code, but that's the part I'm not really sure about.

Blending the base model and LoRA weights in memory, without writing anything back to disk, was very important for a few users with HDDs. By externalizing the weight initializers in memory, it's possible to do regular numpy math on them before loading the model into ORT: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/diffusers/load.py#L258-L276
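
Roughly, the pattern looks like this (a minimal sketch, not the exact onnx-web helper; blended_model and the placeholder location string are illustrative):

from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions

def buffer_external_data_tensors(model):
    # Pull each initializer's raw data out of the protobuf into an OrtValue,
    # leaving behind a placeholder that claims the data lives in an external file.
    external_data = []
    for tensor in model.graph.initializer:
        if tensor.HasField("raw_data"):
            array = numpy_helper.to_array(tensor)
            external_data.append((tensor.name, OrtValue.ortvalue_from_numpy(array)))
            set_external_data(tensor, location="in-memory")  # never written to disk
            tensor.ClearField("raw_data")
    return model, external_data

model, external_data = buffer_external_data_tensors(blended_model)  # blended_model: onnx.ModelProto
names, values = zip(*external_data)

opts = SessionOptions()
opts.add_external_initializers(list(names), list(values))
session = InferenceSession(
    model.SerializeToString(),
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)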

The state dicts are loaded from the LoRA file as usual, but then I use the node names from both the LoRA and the ONNX model to find the correct MatMul initializer, which is always an input to that node: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L358-L364
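
In sketch form, locating the weight for one LoRA key looks something like this (the key-to-node-name remapping from lora.py is simplified away, and the path and key are just examples):

import onnx

def find_matmul_weight(model, node_name):
    # Find the MatMul node whose name matches the remapped LoRA key,
    # then return the initializer tensor that feeds it.
    initializers = {init.name: init for init in model.graph.initializer}
    for node in model.graph.node:
        if node.op_type == "MatMul" and node_name in node.name:
            for input_name in node.input:
                if input_name in initializers:
                    return initializers[input_name]
    return None

model = onnx.load("unet/model.onnx")                       # example path
weight = find_matmul_weight(model, "attentions.0.to_q")    # example key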

The protobuf-based structure of ONNX model graphs makes this a little awkward, but since the number of nodes does not change, the initializers can be replaced in place, which works just fine: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py#L429-L430
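
The in-place swap itself is short. Assuming weight is the TensorProto found above and compute_blended_weights is a stand-in for the actual LoRA math, it is roughly:

from onnx import numpy_helper

base = numpy_helper.to_array(weight)
blended = compute_blended_weights(base)  # placeholder for the per-key LoRA math

# CopyFrom overwrites the tensor data inside the existing initializer, so the
# node count and graph structure are unchanged.
weight.CopyFrom(numpy_helper.from_array(blended.astype(base.dtype), weight.name))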

See ssube/onnx-web#213 and microsoft/onnxruntime#15024 for some more background and context. If you have any questions, I'm happy to chat about this on Discord as well.

Your contribution

I have the code and it works; it just needs some cleanup, and I'm not sure where you would want it to live.

The code is in here: https://github.com/ssube/onnx-web/blob/main/api/onnx_web/convert/diffusion/lora.py (SDXL is still on a branch)

It currently looks something like:

# LoRA blending
if loras is not None and len(loras) > 0:
    lora_names, lora_weights = zip(*loras)
    lora_models = [
        path.join(server.model_path, "lora", name) for name in lora_names
    ]
    logger.info(
        "blending base model %s with LoRA models: %s", model, lora_models
    )

    # blend and load text encoder
    text_encoder = text_encoder or path.join(model, "text_encoder", ONNX_MODEL)
    text_encoder = blend_loras(
        server,
        text_encoder,
        list(zip(lora_models, lora_weights)),
        "text_encoder",
        1 if params.is_xl() else None,
    )
    (text_encoder, text_encoder_data) = buffer_external_data_tensors(
        text_encoder
    )
    text_encoder_names, text_encoder_values = zip(*text_encoder_data)
    text_encoder_opts = device.sess_options(cache=False)
    text_encoder_opts.add_external_initializers(
        list(text_encoder_names), list(text_encoder_values)
    )

    if params.is_xl():
        text_encoder_session = InferenceSession(
            text_encoder.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],
            sess_options=text_encoder_opts,
        )
        text_encoder_session._model_path = path.join(model, "text_encoder")
        components["text_encoder"] = ORTModelTextEncoder(
            text_encoder_session, text_encoder
        )
    else:
        components["text_encoder"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                text_encoder.SerializeToString(),
                provider=device.ort_provider("text-encoder"),
                sess_options=text_encoder_opts,
            )
        )

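    # SDXL uses a second text encoder, which needs the LoRA weights blended in as well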
    if params.is_xl():
        text_encoder2 = path.join(model, "text_encoder_2", ONNX_MODEL)
        text_encoder2 = blend_loras(
            server,
            text_encoder2,
            list(zip(lora_models, lora_weights)),
            "text_encoder",
            2,
        )
        (text_encoder2, text_encoder2_data) = buffer_external_data_tensors(
            text_encoder2
        )
        text_encoder2_names, text_encoder2_values = zip(*text_encoder2_data)
        text_encoder2_opts = device.sess_options(cache=False)
        text_encoder2_opts.add_external_initializers(
            list(text_encoder2_names), list(text_encoder2_values)
        )

        text_encoder2_session = InferenceSession(
            text_encoder2.SerializeToString(),
            providers=[device.ort_provider("text-encoder")],
            sess_options=text_encoder2_opts,
        )
        text_encoder2_session._model_path = path.join(model, "text_encoder_2")
        components["text_encoder_2"] = ORTModelTextEncoder(
            text_encoder2_session, text_encoder2
        )

    # blend and load unet
    unet = path.join(model, unet_type, ONNX_MODEL)
    blended_unet = blend_loras(
        server,
        unet,
        list(zip(lora_models, lora_weights)),
        "unet",
    )
    (unet_model, unet_data) = buffer_external_data_tensors(blended_unet)
    unet_names, unet_values = zip(*unet_data)
    unet_opts = device.sess_options(cache=False)
    unet_opts.add_external_initializers(list(unet_names), list(unet_values))

    if params.is_xl():
        unet_session = InferenceSession(
            unet_model.SerializeToString(),
            providers=[device.ort_provider("unet")],
            sess_options=unet_opts,
        )
        unet_session._model_path = path.join(model, "unet")
        components["unet"] = ORTModelUnet(unet_session, unet_model)
    else:
        components["unet"] = OnnxRuntimeModel(
            OnnxRuntimeModel.load_model(
                unet_model.SerializeToString(),
                provider=device.ort_provider("unet"),
                sess_options=unet_opts,
            )
        )

pipeline_class = available_pipelines.get(pipeline, OnnxStableDiffusionPipeline)
logger.debug("loading pretrained SD pipeline for %s", pipeline_class.__name__)
pipe = pipeline_class.from_pretrained(
    model,
    provider=device.ort_provider(),
    sess_options=device.sess_options(),
    safety_checker=None,
    torch_dtype=torch_dtype,
    **components,
)

There are a few different steps that I can break down into individual functions. Looking at the LoraLoaderMixin from https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py#L855, blend_loras is roughly equivalent to load_lora_into_unet and load_lora_into_text_encoder, although ORT doesn't differentiate between the two models. Once you locate the correct MatMul node and its initializer, the math is all normal and follows the algorithm described in https://github.com/KohakuBlueleaf/LyCORIS/blob/main/Algo.md.
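
For reference, the basic LoRA update per MatMul weight is just the following (shapes and transposes depend on how the layer was exported, and the values here are illustrative; LoHA and the other LyCORIS variants factor the update differently):

import numpy as np

# example shapes: base weight (768, 768), rank-4 LoRA factors
base = np.random.randn(768, 768).astype(np.float32)
rank = 4
down = np.random.randn(rank, 768).astype(np.float32)  # lora_down
up = np.random.randn(768, rank).astype(np.float32)    # lora_up
alpha = 4.0      # per-layer alpha stored in the LoRA file
strength = 0.8   # user-selected blend weight

# W' = W + strength * (alpha / rank) * (up @ down)
blended = base + strength * (alpha / rank) * (up @ down)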
