deal with conv attention shape, rotate centroids
ljleb committed Nov 13, 2023
1 parent 0d5160b commit 1920496
Showing 1 changed file with 21 additions and 20 deletions.
sd_meh/merge_methods.py
@@ -212,29 +212,22 @@ def filter_top_k(a: Tensor, k: float):
 
 
 def rotate(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs):
-    if len(a.shape) == 0 or (a == b).all():
+    if len(a.shape) == 0 or (len(a.shape) == 4 and a.shape[-1] != 1) or (a == b).all():
         return a
 
-    if len(a.shape) == 4:  # conv
-        # ideally we should stack each n x n kernel into 1D for more freedom
-        # however, this brings the number of dimensions of the covariance matrix
-        # to a very high number for some layers (> 10k x 10k)
-        # SVD is not practical in these cases
-        # so instead, we break down the conv kernel into individual input features
-        # to lock their angles and distances along with the
-        a_2d = a.permute(0, 2, 3, 1).reshape(-1, a.shape[1]).float()
-        b_2d = b.permute(0, 2, 3, 1).reshape(-1, a.shape[1]).float()
-
-        def reshape_fn(m):
-            m = m.reshape(a.shape[0], a.shape[2], a.shape[3], a.shape[1])
-            m = m.permute(0, 3, 1, 2)
-            return m.contiguous()  # apparently needed for saving
+    if len(a.shape) == 4:
+        shape_2d = (-1, a.shape[1])
     else:
-        a_2d = a.reshape(-1, a.shape[-1]).float()
-        b_2d = b.reshape(-1, b.shape[-1]).float()
+        shape_2d = (-1, a.shape[-1])
+
+    a_2d = a.reshape(*shape_2d).double()
+    b_2d = b.reshape(*shape_2d).double()
 
-        def reshape_fn(m):
-            return m.reshape_as(a)
+    a_centroid = a_2d.mean(0)
+    b_centroid = b_2d.mean(0)
+    new_centroid = sample_ellipsis(a_centroid, b_centroid, 2 * torch.pi * alpha)
+    a_2d -= a_centroid
+    b_2d -= b_centroid
 
     svd_driver = "gesvd" if a.is_cuda else None
     u, _, v_t = torch.linalg.svd(a_2d.T @ b_2d, driver=svd_driver)
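
Note on the reshape change: only 4D kernels whose trailing dimension is 1 (1x1 convolutions stored as (out, in, 1, 1)) now reach the 2D flattening, where a plain reshape(-1, a.shape[1]) is equivalent to the old permute-and-reshape path; k x k kernels with k > 1 are returned unchanged by the new guard. A minimal sketch of the shape bookkeeping, with made-up tensor shapes:

import torch

# Illustrative shapes only, not taken from any particular model:
conv_1x1 = torch.randn(320, 640, 1, 1)  # passes the guard: shape[-1] == 1
conv_3x3 = torch.randn(320, 320, 3, 3)  # skipped by the guard: shape[-1] != 1
attn = torch.randn(640, 640)            # 2D attention/linear weight

def to_2d(a: torch.Tensor) -> torch.Tensor:
    # same reshape rule as the patched rotate()
    shape_2d = (-1, a.shape[1]) if len(a.shape) == 4 else (-1, a.shape[-1])
    return a.reshape(*shape_2d).double()

print(to_2d(conv_1x1).shape)  # torch.Size([320, 640])
print(to_2d(attn).shape)      # torch.Size([640, 640])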
@@ -253,11 +246,19 @@ def reshape_fn(m):
     a_2d = weighted_sum(a_2d, b_2d @ rotation.T, beta)
 
     a_2d @= transform
-    return reshape_fn(a_2d).to(dtype=a.dtype)
+    a_2d += new_centroid
+    return a_2d.reshape_as(a).to(dtype=a.dtype)
 
 
 def fractional_matrix_power(matrix: Tensor, power: float):
     eigenvalues, eigenvectors = torch.linalg.eig(matrix.double())
     eigenvalues.pow_(power)
     result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
     return result.real.to(dtype=matrix.dtype)
+
+
+def sample_ellipsis(a, b, t):
+    return torch.column_stack((a, b)) @ torch.tensor([
+        math.sin(t),
+        math.cos(t),
+    ], dtype=a.dtype, device=a.device)
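
Note on the centroid handling: each flattened matrix has the mean of its rows (a_2d.mean(0)) subtracted before the alignment, and a new centroid is added back at the end. sample_ellipsis(a, b, t) returns a * sin(t) + b * cos(t), a point on the ellipse traced through +-a and +-b, so with t = 2 * pi * alpha the merged tensor keeps b's centroid at alpha = 0 and takes a's at alpha = 0.25. A small self-contained check, with made-up vectors:

import math
import torch

def sample_ellipsis(a, b, t):
    # a * sin(t) + b * cos(t), computed as a single matrix-vector product
    return torch.column_stack((a, b)) @ torch.tensor(
        [math.sin(t), math.cos(t)], dtype=a.dtype, device=a.device
    )

a = torch.tensor([1.0, 0.0])
b = torch.tensor([0.0, 2.0])
print(sample_ellipsis(a, b, 2 * math.pi * 0.0))   # tensor([0., 2.]) == b
print(sample_ellipsis(a, b, 2 * math.pi * 0.25))  # ~a, up to float rounding
print(sample_ellipsis(a, b, 2 * math.pi * 0.5))   # ~-b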
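
fractional_matrix_power appears unchanged here but is the other half of how alpha acts: presumably applied to the rotation found by the SVD (in the collapsed part of the diff), it computes M^p through the eigendecomposition M^p = V diag(lambda^p) V^-1 and keeps the real part. For a rotation matrix the eigenvalues are e^(+-i*theta), so raising them to the power p scales every rotation angle by p. A quick sanity check, with a made-up matrix:

import torch
from torch import Tensor

def fractional_matrix_power(matrix: Tensor, power: float):
    # M^p = V diag(lambda^p) V^-1 via the (complex) eigendecomposition
    eigenvalues, eigenvectors = torch.linalg.eig(matrix.double())
    eigenvalues.pow_(power)
    result = eigenvectors @ torch.diag(eigenvalues) @ torch.linalg.inv(eigenvectors)
    return result.real.to(dtype=matrix.dtype)

m = torch.tensor([[4.0, 0.0], [0.0, 9.0]])
half = fractional_matrix_power(m, 0.5)
print(half)         # ~[[2., 0.], [0., 3.]]
print(half @ half)  # recovers m, up to floating-point error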
