diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
index 637917dd4..575ab0348 100644
--- a/extensions-builtin/Lora/lora_patches.py
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -16,7 +16,10 @@ def __init__(self):
         self.LayerNorm_load_state_dict = None
         self.MultiheadAttention_forward = None
         self.MultiheadAttention_load_state_dict = None
-        self.Linear4bit_forward = None
+        # optional quant forwards
+        self.Linear4bit_forward = None # bitsandbytes
+        self.QLinear_forward = None # optimum.quanto
+        self.QConv2d_forward = None # optimum.quanto

     def apply(self):
         if self.active or shared.opts.lora_force_diffusers:
@@ -28,8 +31,8 @@ def apply(self):
             pass
         try:
             from optimum import quanto # pylint: disable=no-name-in-module
-            self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward) # pylint: disable=attribute-defined-outside-init
-            self.QConv2d_forward = patches.patch(__name__, quanto.nn.QConv2d, 'forward', networks.network_QConv2d_forward) # pylint: disable=attribute-defined-outside-init
+            self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward)
+            self.QConv2d_forward = patches.patch(__name__, quanto.nn.QConv2d, 'forward', networks.network_QConv2d_forward)
         except Exception:
             pass
         self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
@@ -57,8 +60,8 @@ def undo(self):
             pass
         try:
             from optimum import quanto # pylint: disable=no-name-in-module
-            self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
-            self.QConv2d_forward = patches.undo(__name__, quanto.nn.QConv2d, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
+            self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128
+            self.QConv2d_forward = patches.undo(__name__, quanto.nn.QConv2d, 'forward') # pylint: disable=E1128
         except Exception:
             pass
         self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') # pylint: disable=E1128
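
Note (not part of the patch): the hunks above rely on a store-and-restore monkey-patching helper that swaps a class's forward method and can later put the original back. The sketch below is a minimal, self-contained illustration of that pattern; the patch()/undo() helpers, the key scheme, and the LoRA-aware forward shown here are assumptions for illustration only, not the repository's actual modules.patches or networks API.

    # illustrative sketch of the patch/undo pattern (assumed, not the real module)
    import torch

    _originals = {}  # (key, cls, field) -> original attribute

    def patch(key, cls, field, replacement):
        """Remember cls.field, install the replacement, and return the original callable."""
        original = getattr(cls, field)
        _originals[(key, cls, field)] = original
        setattr(cls, field, replacement)
        return original

    def undo(key, cls, field):
        """Restore the remembered attribute; return None so callers can clear their handle."""
        original = _originals.pop((key, cls, field), None)
        if original is not None:
            setattr(cls, field, original)
        return None

    def lora_linear_forward(self, x):
        # a real implementation would apply LoRA weight deltas here before delegating
        return original_linear_forward(self, x)

    # wrap torch.nn.Linear.forward, mirroring how the diff wraps quanto.nn.QLinear/QConv2d
    original_linear_forward = patch(__name__, torch.nn.Linear, 'forward', lora_linear_forward)
    print(torch.nn.Linear(4, 2)(torch.randn(1, 4)).shape)  # routed through the wrapper
    undo(__name__, torch.nn.Linear, 'forward')  # restore the stock forward

Storing the handles on self (QLinear_forward, QConv2d_forward) and wrapping the patch/undo calls in try/except keeps the optimum.quanto dependency optional: if the import fails, the quantized layers are simply skipped and the plain torch.nn layers are still patched.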