Commit c5371a5

pylint fixes

AI-Casanova committed Nov 18, 2023
1 parent f2b66b8 commit c5371a5

Showing 4 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion modules/merging/merge.py
@@ -309,7 +309,7 @@ def simple_merge_key(progress, key, thetas, *args, **kwargs):
progress.update()


-def merge_key(
+def merge_key(  # pylint: disable=inconsistent-return-statements
key: str,
thetas: Dict,
weight_matcher: WeightClass,
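The pragma added here silences pylint's inconsistent-return-statements check (R1710), which fires when some branches of a function return a value while others fall off the end. A minimal illustration of the flagged pattern (not code from this repository):

    def pick(flag):
        if flag:
            return "value"
        # This branch returns None implicitly, so pylint reports R1710
        # unless every path returns a value (or none of them do).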
24 changes: 12 additions & 12 deletions modules/merging/merge_methods.py
@@ -23,7 +23,7 @@
EPSILON = 1e-10 # Define a small constant EPSILON to prevent division by zero


-def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:
+def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Basic Merge:
alpha 0 returns Primary Model
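As a quick sanity check of the formula in this hunk, a standalone sketch of the same interpolation (plain PyTorch, independent of the repository's modules):

    import torch

    def weighted_sum(a, b, alpha):
        # Linear interpolation: alpha=0 keeps a, alpha=1 keeps b.
        return (1 - alpha) * a + alpha * b

    a, b = torch.zeros(3), torch.ones(3)
    assert torch.equal(weighted_sum(a, b, 0.0), a)  # Primary Model
    assert torch.equal(weighted_sum(a, b, 1.0), b)  # Secondary Model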
@@ -32,7 +32,7 @@ def weighted_sum(a: Tensor, b: Tensor, alpha: float, **kwargs) -> Tensor:
return (1 - alpha) * a + alpha * b


-def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
The inverse of a Weighted Sum Merge
Returns Primary Model when alpha*beta = 0
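The "inverse" claim in the docstring can be checked numerically: with beta = 1, weighted_subtraction recovers the Primary Model from a weighted_sum merge. A self-contained sketch:

    import torch

    def weighted_sum(a, b, alpha):
        return (1 - alpha) * a + alpha * b

    def weighted_subtraction(a, b, alpha, beta):
        # Solves m = (1 - alpha*beta) * x + alpha*beta * b for x.
        return (a - alpha * beta * b) / (1 - alpha * beta)

    a, b = torch.randn(4), torch.randn(4)
    merged = weighted_sum(a, b, 0.3)
    assert torch.allclose(weighted_subtraction(merged, b, 0.3, 1.0), a)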
@@ -45,7 +45,7 @@ def weighted_subtraction(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return (a - alpha * beta * b) / (1 - alpha * beta)


-def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Takes a slice of Secondary Model and pastes it into Primary Model
Alpha sets the width of the slice
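The body of tensor_sum is collapsed in this hunk. Going only by the docstring, a rough sketch of the slice-and-paste idea; the flattening, the wrap-around handling, and the use of beta as an offset are assumptions, not the repository's implementation:

    import torch

    def tensor_sum_sketch(a, b, alpha, beta):
        # Assumed: alpha = fractional width of the slice, beta = fractional offset.
        n = a.numel()
        start, end = int(beta * n), int((beta + alpha) * n)
        out, src = a.flatten().clone(), b.flatten()
        if end <= n:
            out[start:end] = src[start:end]
        else:  # slice wraps past the end of the tensor
            out[start:] = src[start:]
            out[:end - n] = src[:end - n]
        return out.reshape_as(a)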
@@ -65,14 +65,14 @@ def tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return tt


-def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:
+def add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Classic Add Difference Merge
"""
return a + alpha * (b - c)


-def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Stacked Basic Merge:
Equivalent to Merging Primary and Secondary @ alpha
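That equivalence is easy to verify: sum_twice is two nested weighted_sum calls, as a short numeric check shows (standalone, mirroring the formula at the end of this hunk):

    import torch

    def weighted_sum(a, b, alpha):
        return (1 - alpha) * a + alpha * b

    def sum_twice(a, b, c, alpha, beta):
        return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c

    a, b, c = torch.randn(4), torch.randn(4), torch.randn(4)
    staged = weighted_sum(weighted_sum(a, b, 0.4), c, 0.25)
    assert torch.allclose(sum_twice(a, b, c, 0.4, 0.25), staged)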
@@ -81,7 +81,7 @@ def sum_twice(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return (1 - beta) * ((1 - alpha) * a + alpha * b) + beta * c


-def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Weights Secondary and Tertiary at alpha and beta respectively
Fills in the rest with Primary
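Note that the three coefficients sum to 1, so this is an affine combination; if alpha + beta > 1, the Primary weight goes negative. A quick check of the documented endpoints:

    import torch

    def triple_sum(a, b, c, alpha, beta):
        return (1 - alpha - beta) * a + alpha * b + beta * c

    a, b, c = torch.randn(4), torch.randn(4), torch.randn(4)
    assert torch.allclose(triple_sum(a, b, c, 0.0, 0.0), a)  # all Primary
    assert torch.allclose(triple_sum(a, b, c, 1.0, 0.0), b)  # all Secondary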
@@ -90,7 +90,7 @@ def triple_sum(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return (1 - alpha - beta) * a + alpha * b + beta * c


-def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:
+def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Subtract Primary and Secondary from Tertiary
Compare the remainders via Euclidean distance
@@ -111,7 +111,7 @@ def euclidean_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, **kwargs) -> Tensor:
return c + distance / torch.linalg.norm(distance) * target_norm


-def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Similar to Add Difference but with geometric mean instead of arithmetic mean
"""
@@ -121,7 +121,7 @@ def multiply_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return c + difference.to(c.dtype)


-def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def top_k_tensor_sum(a: Tensor, b: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Redistributes the largest weights of Secondary Model into Primary Model
"""
@@ -173,7 +173,7 @@ def ratio_to_region(width: float, offset: float, n: int) -> Tuple[int, int, bool]:
return round(start), round(end), inverted


-def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
Weighted Sum where A and B are similar and Add Difference where A and B are dissimilar
"""
@@ -186,7 +186,7 @@ def similarity_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
return (1 - similarity) * ab_diff + similarity * ab_sum


-def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs):
+def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs):  # pylint: disable=unused-argument
"""
From the creator:
It's Primary high-passed + Secondary low-passed. Takes the Fourier transform of the weights of
@@ -218,7 +218,7 @@ def distribution_crossover(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs):
return x_values.reshape_as(a)


-def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:
+def ties_add_difference(a: Tensor, b: Tensor, c: Tensor, alpha: float, beta: float, **kwargs) -> Tensor:  # pylint: disable=unused-argument
"""
An implementation of arXiv:2306.01708
"""
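For reference, arXiv:2306.01708 is "TIES-Merging: Resolving Interference When Merging Models": trim each task delta to its largest-magnitude entries, elect a per-parameter sign, and average only the deltas that agree with it. A compact sketch of those three steps (a simplified multi-delta form, not this repository's implementation):

    import torch

    def ties_sketch(deltas, k=0.2):
        """Merge task deltas (model minus base) TIES-style; k = kept fraction."""
        trimmed = []
        for d in deltas:
            # 1. Trim: zero out all but the top-k fraction by magnitude.
            kth = max(1, int((1 - k) * d.numel()))
            threshold = d.abs().flatten().kthvalue(kth).values
            trimmed.append(torch.where(d.abs() >= threshold, d, torch.zeros_like(d)))
        stacked = torch.stack(trimmed)
        # 2. Elect sign: majority sign of the summed trimmed deltas.
        sign = torch.sign(stacked.sum(dim=0))
        # 3. Disjoint merge: average entries whose sign matches the elected one.
        agree = (torch.sign(stacked) == sign) & (stacked != 0)
        return (stacked * agree).sum(dim=0) / agree.sum(dim=0).clamp(min=1)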
14 changes: 7 additions & 7 deletions modules/merging/merge_rebasin.py
@@ -31,23 +31,23 @@ def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:


def sdunet_permutation_spec() -> PermutationSpec:
-conv = lambda name, p_in, p_out: {
+conv = lambda name, p_in, p_out: {  # pylint: disable=unnecessary-lambda-assignment
f"{name}.weight": (
p_out,
p_in,
),
f"{name}.bias": (p_out,),
}
-norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)}
+norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)}  # pylint: disable=unnecessary-lambda-assignment
dense = (
-lambda name, p_in, p_out, bias=True: {
+lambda name, p_in, p_out, bias=True: {  # pylint: disable=unnecessary-lambda-assignment
f"{name}.weight": (p_out, p_in),
f"{name}.bias": (p_out,),
}
if bias
else {f"{name}.weight": (p_out, p_in)}
)
-skip = lambda name, p_in, p_out: {
+skip = lambda name, p_in, p_out: {  # pylint: disable=unnecessary-lambda-assignment
f"{name}": (
p_out,
p_in,
@@ -57,7 +57,7 @@ def sdunet_permutation_spec() -> PermutationSpec:
}

# Unet Res blocks
-easyblock = lambda name, p_in, p_out: {
+easyblock = lambda name, p_in, p_out: {  # pylint: disable=unnecessary-lambda-assignment
**norm(f"{name}.in_layers.0", p_in),
**conv(f"{name}.in_layers.2", p_in, f"P_{name}_inner"),
**dense(
@@ -68,15 +68,15 @@ def sdunet_permutation_spec() -> PermutationSpec:
}

# Text Encoder blocks
-easyblock2 = lambda name, p: {
+easyblock2 = lambda name, p: {  # pylint: disable=unnecessary-lambda-assignment
**norm(f"{name}.norm1", p),
**conv(f"{name}.conv1", p, f"P_{name}_inner"),
**norm(f"{name}.norm2", f"P_{name}_inner"),
**conv(f"{name}.conv2", f"P_{name}_inner", p),
}

# This is for blocks that use a residual connection, but change the number of channels via a Conv.
-shortcutblock = lambda name, p_in, p_out: {
+shortcutblock = lambda name, p_in, p_out: {  # pylint: disable=unnecessary-lambda-assignment
**norm(f"{name}.norm1", p_in),
**conv(f"{name}.conv1", p_in, f"P_{name}_inner"),
**norm(f"{name}.norm2", f"P_{name}_inner"),
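All seven changes in this file silence the same check: pylint's unnecessary-lambda-assignment (C3001), which flags assigning a lambda to a name instead of defining a function. The pragmas keep the compact lambda table style; the alternative pylint prefers would look like this (illustration only, not a change made by the commit):

    # Flagged by C3001:
    norm = lambda name, p: {f"{name}.weight": (p,), f"{name}.bias": (p,)}

    # What pylint would rather see:
    def norm(name, p):
        return {f"{name}.weight": (p,), f"{name}.bias": (p,)}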
1 change: 0 additions & 1 deletion modules/merging/merge_utils.py
@@ -109,4 +109,3 @@ def step_weights_and_bases(self, ratio):

def set_it(self, it):
self.it = it
-return
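The line removed here addresses pylint's useless-return (R1711): a bare return as the last statement of a function is redundant, because Python returns None implicitly. Before and after, schematically:

    # Flagged by R1711:
    def set_it(self, it):
        self.it = it
        return

    # Identical behavior, no warning:
    def set_it(self, it):
        self.it = it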
