Add detection for PLKSR's lk_type parameter #264

Merged · 2 commits · May 19, 2024
libs/spandrel/spandrel/architectures/PLKSR/__init__.py (49 additions, 15 deletions)
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import math
-from typing import Union
+from typing import Literal, Sequence, Union
 
 from typing_extensions import override
 
@@ -20,7 +20,11 @@ def __init__(self) -> None:
id="PLKSR",
detect=KeyCondition.has_all(
"feats.0.weight",
"feats.1.lk.conv.weight",
KeyCondition.has_any(
"feats.1.lk.conv.weight",
"feats.1.lk.convs.0.weight",
"feats.1.lk.mn_conv.weight",
),
"feats.1.refine.weight",
KeyCondition.has_any(
"feats.1.channe_mixer.0.weight",
@@ -38,26 +42,20 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
         split_ratio = 0.25
         use_ea = True
 
-        # RealPLKSR only
-        norm_groups = 4  # un-detectable
-        dropout = 0  # un-detectable
-
         dim = state_dict["feats.0.weight"].shape[0]
 
         total_feat_layers = get_seq_len(state_dict, "feats")
         scale = math.isqrt(
             state_dict[f"feats.{total_feat_layers - 1}.weight"].shape[0] // 3
         )
 
-        kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
-        split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim
-
         use_ea = "feats.1.attn.f.0.weight" in state_dict
 
-        # Yes, the normal version has this typo.
         if "feats.1.channe_mixer.0.weight" in state_dict:
+            # Yes, the normal version has this typo.
             n_blocks = total_feat_layers - 2
 
+            # ccm_type
             mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2]
             mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2]
             if mixer_0_shape == 3 and mixer_2_shape == 1:
@@ -70,29 +68,65 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
                 raise ValueError("Unknown CCM type")
             more_tags = [ccm_type]
 
+            # lk_type
+            lk_type: Literal["PLK", "SparsePLK", "RectSparsePLK"] = "PLK"
+            use_max_kernel: bool = False
+            sparse_kernels: Sequence[int] = [5, 5, 5, 5]
+            sparse_dilations: Sequence[int] = [1, 2, 3, 4]
+            with_idt: bool = False  # undetectable
+
+            if "feats.1.lk.conv.weight" in state_dict:
+                # PLKConv2d
+                lk_type = "PLK"
+                kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
+                split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim
+            elif "feats.1.lk.convs.0.weight" in state_dict:
+                # SparsePLKConv2d
+                lk_type = "SparsePLK"
+                split_ratio = state_dict["feats.1.lk.convs.0.weight"].shape[0] / dim
+                # Detecting other parameters for SparsePLKConv2d is practically impossible.
+                # We cannot detect the values of sparse_dilations at all; we only know it has the same length as sparse_kernels.
+                # Detecting the values of sparse_kernels is possible, but we don't know its length exactly, because it's `len(sparse_kernels) = len(convs) - (1 if use_max_kernel else 0)`.
+                # However, we cannot detect use_max_kernel, because the convolutions it adds when enabled look the same as the other convolutions.
+                # So I give up.
+            elif "feats.1.lk.mn_conv.weight" in state_dict:
+                # RectSparsePLKConv2d
+                lk_type = "RectSparsePLK"
+                kernel_size = state_dict["feats.1.lk.mn_conv.weight"].shape[2]
+                split_ratio = state_dict["feats.1.lk.mn_conv.weight"].shape[0] / dim
+            else:
+                raise ValueError("Unknown LK type")
+
             model = PLKSR(
                 dim=dim,
-                upscaling_factor=scale,
                 n_blocks=n_blocks,
+                upscaling_factor=scale,
+                ccm_type=ccm_type,
                 kernel_size=kernel_size,
                 split_ratio=split_ratio,
+                lk_type=lk_type,
+                use_max_kernel=use_max_kernel,
+                sparse_kernels=sparse_kernels,
+                sparse_dilations=sparse_dilations,
+                with_idt=with_idt,
                 use_ea=use_ea,
-                ccm_type=ccm_type,
             )
-        # and RealPLKSR doesn't. This makes it really convenient to detect.
         elif "feats.1.channel_mixer.0.weight" in state_dict:
+            # and RealPLKSR doesn't. This makes it really convenient to detect.
             more_tags = ["Real"]
 
             n_blocks = total_feat_layers - 3
+            kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
+            split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim
+
             model = RealPLKSR(
                 dim=dim,
                 upscaling_factor=scale,
                 n_blocks=n_blocks,
                 kernel_size=kernel_size,
                 split_ratio=split_ratio,
                 use_ea=use_ea,
-                norm_groups=norm_groups,
-                dropout=dropout,
+                norm_groups=4,  # un-detectable
             )
         else:
             raise ValueError("Unknown model type")
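As a sanity check on the shape arithmetic in load(), here is a small worked example. The shapes below are hypothetical (a 4x PLKSR checkpoint with dim=64 and the default split_ratio=0.25), not read from a real file:

import math

dim = 64                     # feats.0.weight would have shape (dim, 3, 3, 3)
last_shape = (48, 64, 3, 3)  # the final conv feeds a PixelShuffle, so it emits 3 * scale**2 channels
scale = math.isqrt(last_shape[0] // 3)   # isqrt(48 // 3) == 4
lk_shape = (16, 16, 17, 17)  # feats.1.lk.conv.weight of a PLKConv2d
kernel_size = lk_shape[2]                # 17
split_ratio = lk_shape[0] / dim          # 16 / 64 == 0.25
print(scale, kernel_size, split_ratio)   # 4 17 0.25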
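The detection itself boils down to probing which parameter name the first LK block registered, since each lk_type variant stores its convolution under a different key. A minimal standalone sketch of that branching (detect_lk_type and the toy state dicts are illustrative, not part of spandrel's public API):

def detect_lk_type(state_dict: dict) -> str:
    # Each variant of the large-kernel block leaves a distinct key behind.
    if "feats.1.lk.conv.weight" in state_dict:
        return "PLK"  # PLKConv2d: a single dense large-kernel conv
    if "feats.1.lk.convs.0.weight" in state_dict:
        return "SparsePLK"  # SparsePLKConv2d: a sequence of sparse convs
    if "feats.1.lk.mn_conv.weight" in state_dict:
        return "RectSparsePLK"  # RectSparsePLKConv2d: rectangular convs
    raise ValueError("Unknown LK type")

assert detect_lk_type({"feats.1.lk.conv.weight": object()}) == "PLK"
assert detect_lk_type({"feats.1.lk.mn_conv.weight": object()}) == "RectSparsePLK"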
tests/test_PLKSR.py (12 additions, 4 deletions)
@@ -16,6 +16,7 @@
 def test_load():
     assert_loads_correctly(
         PLKSRArch(),
+        # PLKSR
         lambda: PLKSR(),
         lambda: PLKSR(dim=32),
         lambda: PLKSR(dim=96),
@@ -26,11 +27,18 @@ def test_load():
         lambda: PLKSR(ccm_type="DCCM"),
         lambda: PLKSR(ccm_type="CCM"),
         lambda: PLKSR(ccm_type="ICCM"),
-        lambda: PLKSR(kernel_size=9),
-        lambda: PLKSR(kernel_size=27),
-        lambda: PLKSR(split_ratio=0.5),
-        lambda: PLKSR(split_ratio=0.75),
+        lambda: PLKSR(lk_type="PLK", kernel_size=9),
+        lambda: PLKSR(lk_type="PLK", kernel_size=27),
+        lambda: PLKSR(lk_type="PLK", split_ratio=0.5),
+        lambda: PLKSR(lk_type="PLK", split_ratio=0.75),
+        lambda: PLKSR(lk_type="RectSparsePLK", kernel_size=9),
+        lambda: PLKSR(lk_type="RectSparsePLK", kernel_size=27),
+        lambda: PLKSR(lk_type="RectSparsePLK", split_ratio=0.5),
+        lambda: PLKSR(lk_type="RectSparsePLK", split_ratio=0.75),
+        lambda: PLKSR(lk_type="SparsePLK", split_ratio=0.5),
+        lambda: PLKSR(lk_type="SparsePLK", split_ratio=0.75),
         lambda: PLKSR(use_ea=False),
+        # RealPLKSR
         lambda: RealPLKSR(),
         lambda: RealPLKSR(dim=32),
         lambda: RealPLKSR(dim=96),
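With this merged, checkpoints using any of the three lk_type variants should load through spandrel's regular entry point. A usage sketch (the checkpoint path is hypothetical; ModelLoader and the descriptor attributes are spandrel's existing API):

from spandrel import ModelLoader

descriptor = ModelLoader().load_from_file("4x_plksr_sparse.pth")  # hypothetical file
print(descriptor.architecture)  # identifies PLKSR
print(descriptor.scale)         # e.g. 4, recovered from the final conv's shape
print(descriptor.tags)          # includes the detected ccm_type, e.g. "DCCM"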