Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Vladimir Mandic <[email protected]>
  • Loading branch information
vladmandic committed Sep 30, 2024
1 parent ac58b21 commit 214b9a2
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions extensions-builtin/Lora/lora_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ def __init__(self):
self.LayerNorm_load_state_dict = None
self.MultiheadAttention_forward = None
self.MultiheadAttention_load_state_dict = None
self.Linear4bit_forward = None
# optional quant forwards
self.Linear4bit_forward = None # bitsandbytes
self.QLinear_forward = None # optimum.quanto
self.QConv2d_forward = None # optimum.quanto

def apply(self):
if self.active or shared.opts.lora_force_diffusers:
Expand All @@ -28,8 +31,8 @@ def apply(self):
pass
try:
from optimum import quanto # pylint: disable=no-name-in-module
self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward) # pylint: disable=attribute-defined-outside-init
self.QConv2d_forward = patches.patch(__name__, quanto.nn.QConv2d, 'forward', networks.network_QConv2d_forward) # pylint: disable=attribute-defined-outside-init
self.QLinear_forward = patches.patch(__name__, quanto.nn.QLinear, 'forward', networks.network_QLinear_forward)
self.QConv2d_forward = patches.patch(__name__, quanto.nn.QConv2d, 'forward', networks.network_QConv2d_forward)
except Exception:
pass
self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
Expand Down Expand Up @@ -57,8 +60,8 @@ def undo(self):
pass
try:
from optimum import quanto # pylint: disable=no-name-in-module
self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
self.QConv2d_forward = patches.undo(__name__, quanto.nn.QConv2d, 'forward') # pylint: disable=E1128, attribute-defined-outside-init
self.QLinear_forward = patches.undo(__name__, quanto.nn.QLinear, 'forward') # pylint: disable=E1128
self.QConv2d_forward = patches.undo(__name__, quanto.nn.QConv2d, 'forward') # pylint: disable=E1128
except Exception:
pass
self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') # pylint: disable=E1128
Expand Down

0 comments on commit 214b9a2

Please sign in to comment.