diff --git a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py index 5b7ea009..3b04bf58 100644 --- a/libs/spandrel/spandrel/architectures/PLKSR/__init__.py +++ b/libs/spandrel/spandrel/architectures/PLKSR/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations import math -from typing import Union +from typing import Literal, Sequence, Union from typing_extensions import override @@ -20,7 +20,11 @@ def __init__(self) -> None: id="PLKSR", detect=KeyCondition.has_all( "feats.0.weight", - "feats.1.lk.conv.weight", + KeyCondition.has_any( + "feats.1.lk.conv.weight", + "feats.1.lk.convs.0.weight", + "feats.1.lk.mn_conv.weight", + ), "feats.1.refine.weight", KeyCondition.has_any( "feats.1.channe_mixer.0.weight", @@ -38,10 +42,6 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: split_ratio = 0.25 use_ea = True - # RealPLKSR only - norm_groups = 4 # un-detectable - dropout = 0 # un-detectable - dim = state_dict["feats.0.weight"].shape[0] total_feat_layers = get_seq_len(state_dict, "feats") @@ -49,15 +49,13 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: state_dict[f"feats.{total_feat_layers - 1}.weight"].shape[0] // 3 ) - kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] - split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim - use_ea = "feats.1.attn.f.0.weight" in state_dict - # Yes, the normal version has this typo. if "feats.1.channe_mixer.0.weight" in state_dict: + # Yes, the normal version has this typo. n_blocks = total_feat_layers - 2 + # ccm_type mixer_0_shape = state_dict["feats.1.channe_mixer.0.weight"].shape[2] mixer_2_shape = state_dict["feats.1.channe_mixer.2.weight"].shape[2] if mixer_0_shape == 3 and mixer_2_shape == 1: @@ -70,20 +68,57 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: raise ValueError("Unknown CCM type") more_tags = [ccm_type] + # lk_type + lk_type: Literal["PLK", "SparsePLK", "RectSparsePLK"] = "PLK" + use_max_kernel: bool = False + sparse_kernels: Sequence[int] = [5, 5, 5, 5] + sparse_dilations: Sequence[int] = [1, 2, 3, 4] + with_idt: bool = False # undetectable + + if "feats.1.lk.conv.weight" in state_dict: + # PLKConv2d + lk_type = "PLK" + kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] + split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim + elif "feats.1.lk.convs.0.weight" in state_dict: + # SparsePLKConv2d + lk_type = "SparsePLK" + split_ratio = state_dict["feats.1.lk.convs.0.weight"].shape[0] / dim + # Detecting other parameters for SparsePLKConv2d is praticaly impossible. + # We cannot detect the values of sparse_dilations at all, we only know it has the same length as sparse_kernels. + # Detecting the values of sparse_kernels is possible, but we don't know its length exactly, because it's `len(sparse_kernels) = len(convs) - (1 if use_max_kernel else 0)`. + # However, we cannot detect use_max_kernel, because the convolutions it adds when enabled look the same as the other convolutions. + # So I give up. + elif "feats.1.lk.mn_conv.weight" in state_dict: + # RectSparsePLKConv2d + lk_type = "RectSparsePLK" + kernel_size = state_dict["feats.1.lk.mn_conv.weight"].shape[2] + split_ratio = state_dict["feats.1.lk.mn_conv.weight"].shape[0] / dim + else: + raise ValueError("Unknown LK type") + model = PLKSR( dim=dim, - upscaling_factor=scale, n_blocks=n_blocks, + upscaling_factor=scale, + ccm_type=ccm_type, kernel_size=kernel_size, split_ratio=split_ratio, + lk_type=lk_type, + use_max_kernel=use_max_kernel, + sparse_kernels=sparse_kernels, + sparse_dilations=sparse_dilations, + with_idt=with_idt, use_ea=use_ea, - ccm_type=ccm_type, ) - # and RealPLKSR doesn't. This makes it really convenient to detect. elif "feats.1.channel_mixer.0.weight" in state_dict: + # and RealPLKSR doesn't. This makes it really convenient to detect. more_tags = ["Real"] n_blocks = total_feat_layers - 3 + kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2] + split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim + model = RealPLKSR( dim=dim, upscaling_factor=scale, @@ -91,8 +126,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]: kernel_size=kernel_size, split_ratio=split_ratio, use_ea=use_ea, - norm_groups=norm_groups, - dropout=dropout, + norm_groups=4, # un-detectable ) else: raise ValueError("Unknown model type") diff --git a/tests/test_PLKSR.py b/tests/test_PLKSR.py index d4029e70..63c271c7 100644 --- a/tests/test_PLKSR.py +++ b/tests/test_PLKSR.py @@ -16,6 +16,7 @@ def test_load(): assert_loads_correctly( PLKSRArch(), + # PLKSR lambda: PLKSR(), lambda: PLKSR(dim=32), lambda: PLKSR(dim=96), @@ -26,11 +27,18 @@ def test_load(): lambda: PLKSR(ccm_type="DCCM"), lambda: PLKSR(ccm_type="CCM"), lambda: PLKSR(ccm_type="ICCM"), - lambda: PLKSR(kernel_size=9), - lambda: PLKSR(kernel_size=27), - lambda: PLKSR(split_ratio=0.5), - lambda: PLKSR(split_ratio=0.75), + lambda: PLKSR(lk_type="PLK", kernel_size=9), + lambda: PLKSR(lk_type="PLK", kernel_size=27), + lambda: PLKSR(lk_type="PLK", split_ratio=0.5), + lambda: PLKSR(lk_type="PLK", split_ratio=0.75), + lambda: PLKSR(lk_type="RectSparsePLK", kernel_size=9), + lambda: PLKSR(lk_type="RectSparsePLK", kernel_size=27), + lambda: PLKSR(lk_type="RectSparsePLK", split_ratio=0.5), + lambda: PLKSR(lk_type="RectSparsePLK", split_ratio=0.75), + lambda: PLKSR(lk_type="SparsePLK", split_ratio=0.5), + lambda: PLKSR(lk_type="SparsePLK", split_ratio=0.75), lambda: PLKSR(use_ea=False), + # RealPLKSR lambda: RealPLKSR(), lambda: RealPLKSR(dim=32), lambda: RealPLKSR(dim=96),