Skip to content

Commit

Permalink
refactor: remove unused function
Browse files Browse the repository at this point in the history
  • Loading branch information
v0xie committed Nov 4, 2023
1 parent 329c8ba commit bbf00a9
Showing 1 changed file with 0 additions and 47 deletions.
47 changes: 0 additions & 47 deletions extensions-builtin/Lora/network_oft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import network
from lyco_helpers import factorization
from einops import rearrange
from modules import devices


class ModuleTypeOFT(network.ModuleType):
Expand Down Expand Up @@ -54,58 +53,12 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
raise ValueError("sd_module must be Linear or Conv")

if self.is_kohya:
#self.num_blocks = self.dim
#self.block_size = self.out_dim // self.num_blocks
#self.block_size = self.dim
#self.num_blocks = self.out_dim // self.block_size
self.constraint = self.alpha * self.out_dim
self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
else:
self.constraint = None
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

if is_other_linear:
self.lin_module = self.create_module(weights.w, "oft_diag", none_ok=True)


def create_module(self, weights, key, none_ok=False):
weight = weights.get(key)

if weight is None and none_ok:
return None

is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
is_conv = type(self.sd_module) in [torch.nn.Conv2d]

if is_linear:
weight = weight.reshape(weight.shape[0], -1)
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif is_conv and key == "lora_down.weight" or key == "dyn_up":
if len(weight.shape) == 2:
weight = weight.reshape(weight.shape[0], -1, 1, 1)

if weight.shape[2] != 1 or weight.shape[3] != 1:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
else:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif is_conv and key == "lora_mid.weight":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
elif is_conv and key == "lora_up.weight" or key == "dyn_down":
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

with torch.no_grad():
if weight.shape != module.weight.shape:
weight = weight.reshape(module.weight.shape)
module.weight.copy_(weight)

module.to(device=devices.cpu, dtype=devices.dtype)
module.weight.requires_grad_(False)

return module


def merge_weight(self, R_weight, org_weight):
R_weight = R_weight.to(org_weight.device, dtype=org_weight.dtype)
if org_weight.dim() == 4:
Expand Down

0 comments on commit bbf00a9

Please sign in to comment.