Improve SRFormer parameter detection #68

Merged · 1 commit · Nov 30, 2023
202 changes: 83 additions & 119 deletions src/spandrel/architectures/SRFormer/__init__.py
@@ -1,13 +1,11 @@
 import math
-import re
 
-from torch import nn
-
 from ...__helpers.model_descriptor import (
     ImageModelDescriptor,
     SizeRequirements,
     StateDict,
 )
+from ...architectures.__arch_helpers.state import get_seq_len
 from .arch.SRFormer import SRFormer
 
 
@@ -22,138 +20,106 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[SRFormer]:
     window_size = 7
     mlp_ratio = 4.0
     qkv_bias = True
-    qk_scale = None
-    drop_rate = 0.0
-    attn_drop_rate = 0.0
-    drop_path_rate = 0.1
-    norm_layer = nn.LayerNorm
+    qk_scale = None  # cannot be deduced from state_dict
+    drop_rate = 0.0  # cannot be deduced from state_dict
+    attn_drop_rate = 0.0  # cannot be deduced from state_dict
+    drop_path_rate = 0.1  # cannot be deduced from state_dict
     ape = False
     patch_norm = True
     upscale = 2
     img_range = 1.0
     upsampler = ""
     resi_connection = "1conv"
 
-    state = state_dict
-    state_keys = list(state_dict.keys())
-
-    if "conv_before_upsample.0.weight" in state_keys:
-        if "conv_up1.weight" in state_keys:
-            upsampler = "nearest+conv"
-        else:
-            upsampler = "pixelshuffle"
-    elif "upsample.0.weight" in state_keys:
-        upsampler = "pixelshuffledirect"
-    else:
-        upsampler = ""
-
-    num_feat_pre_layer = state_dict.get("conv_before_upsample.weight", None)
-    num_feat_layer = state_dict.get("conv_before_upsample.0.weight", None)
-    num_feat = (
-        num_feat_layer.shape[1]
-        if num_feat_layer is not None and num_feat_pre_layer is not None
-        else 64
-    )
-
-    num_in_ch = state_dict["conv_first.weight"].shape[1]
-    in_chans = num_in_ch
-    if "conv_last.weight" in state_keys:
-        num_out_ch = state_dict["conv_last.weight"].shape[0]
-    else:
-        num_out_ch = num_in_ch
-
-    upscale = 1
-    if upsampler == "nearest+conv":
-        upsample_keys = [x for x in state_keys if "conv_up" in x and "bias" not in x]
-
-        for upsample_key in upsample_keys:
-            upscale *= 2
-    elif upsampler == "pixelshuffle":
-        upsample_keys = [
-            x
-            for x in state_keys
-            if "upsample" in x and "conv" not in x and "bias" not in x
-        ]
-        for upsample_key in upsample_keys:
-            shape = state_dict[upsample_key].shape[0]
-            upscale *= math.sqrt(shape // num_feat)
-        upscale = int(upscale)
-    elif upsampler == "pixelshuffledirect":
-        upscale = int(math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch))
-
-    max_layer_num = 0
-    max_block_num = 0
-    for key in state_keys:
-        result = re.match(r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key)
-        if result:
-            layer_num, block_num = result.groups()
-            max_layer_num = max(max_layer_num, int(layer_num))
-            max_block_num = max(max_block_num, int(block_num))
-
-    depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
-
-    if (
-        "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
-        in state_keys
-    ):
-        num_heads_num = state_dict[
-            "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
-        ].shape[-1]
-        num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
-    else:
-        num_heads = depths
-
-    embed_dim = state_dict["conv_first.weight"].shape[0]
-
-    mlp_ratio = float(
-        state_dict["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0] / embed_dim
-    )
-
-    # TODO: could actually count the layers, but this should do
-    # TOOD: confirm this is correct and the same as SwinIR
-    if "layers.0.conv.4.weight" in state_keys:
-        resi_connection = "3conv"
-    else:
-        resi_connection = "1conv"
-
-    window_size = int(
-        math.sqrt(
-            state_dict[
-                "layers.0.residual_group.blocks.0.attn.aligned_relative_position_index"
-            ].shape[0]
-        )
-    )
-
-    if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
-        img_size = int(
-            (
-                math.sqrt(
-                    state_dict["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
-                )
-                * 16
-            )
-            + 1
-        )
-
-    in_nc = num_in_ch
-    out_nc = num_out_ch
-    num_feat = num_feat
-    embed_dim = embed_dim
-    num_heads = num_heads
-    depths = depths
-    window_size = window_size
-    mlp_ratio = mlp_ratio
-    scale = upscale
-    upsampler = upsampler
-    img_size = img_size
-    img_range = img_range
-    resi_connection = resi_connection
+    in_chans = state_dict["conv_first.weight"].shape[1]
+    embed_dim = state_dict["conv_first.weight"].shape[0]
+
+    ape = "absolute_pos_embed" in state_dict
+    patch_norm = "patch_embed.norm.weight" in state_dict
+    qkv_bias = "layers.0.residual_group.blocks.0.attn.q.bias" in state_dict
+
+    mlp_ratio = float(
+        state_dict["layers.0.residual_group.blocks.0.mlp.fc1.weight"].shape[0]
+        / embed_dim
+    )
+
+    # depths & num_heads
+    num_layers = get_seq_len(state_dict, "layers")
+    depths = [6] * num_layers
+    num_heads = [6] * num_layers
+    for i in range(num_layers):
+        depths[i] = get_seq_len(state_dict, f"layers.{i}.residual_group.blocks")
+        num_heads[i] = state_dict[
+            f"layers.{i}.residual_group.blocks.0.attn.relative_position_bias_table"
+        ].shape[1]
+
+    if "conv_hr.weight" in state_dict:
+        upsampler = "nearest+conv"
+        upscale = 4  # only supported scale
+    elif "conv_before_upsample.0.weight" in state_dict:
+        upsampler = "pixelshuffle"
+
+        num_feat = 64  # hard-coded constant
+        upscale = 1
+        for i in range(0, get_seq_len(state_dict, "upsample"), 2):
+            shape = state_dict[f"upsample.{i}.weight"].shape[0]
+            upscale *= int(math.sqrt(shape // num_feat))
+    elif "upsample.0.weight" in state_dict:
+        upsampler = "pixelshuffledirect"
+        upscale = int(math.sqrt(state_dict["upsample.0.weight"].shape[0] // in_chans))
+    else:
+        upsampler = ""
+        upscale = 1  # it's technically undefined, but we'll use 1
+
+    if "conv_after_body.weight" in state_dict:
+        resi_connection = "1conv"
+    else:
+        resi_connection = "3conv"
+
+    window_size = (
+        int(
+            math.sqrt(
+                state_dict[
+                    "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
+                ].shape[0]
+            )
+        )
+        + 1
+    )
+
+    # Unfortunately, we cannot detect img_size and patch_size, but we can detect
+    # patches_resolution. What we know:
+    #   patches_resolution = img_size // patch_size
+    #   if window_size > patches_resolution:
+    #       attn_mask[0] = patches_resolution**2 // window_size**2
+    # We will assume that we already know the patch_size (we don't, we'll
+    # assume the default value).
+    if "layers.0.residual_group.blocks.1.attn_mask" in state_dict:
+        attn_mask_0 = state_dict["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
+        patches_resolution = int(math.sqrt(attn_mask_0 * window_size * window_size))
+    else:
+        # we only know that window_size <= patches_resolution
+        # assume window_size == patches_resolution
+        patches_resolution = window_size
+
+    # if APE is enabled, we know that absolute_pos_embed.shape[1] == patches_resolution**2
+    if ape:
+        patches_resolution = int(math.sqrt(state_dict["absolute_pos_embed"].shape[1]))
+
+    img_size = patch_size * patches_resolution
+    # Further, img_size is actually rounded up to the nearest multiple of window_size
+    # before calculating patches_resolution. We have to do a bit of guessing to get
+    # the actual img_size...
+    for nice_number in [512, 256, 128, 96, 64, 48, 32, 24, 16]:
+        rounded = nice_number
+        if rounded % window_size != 0:
+            rounded = rounded + (window_size - rounded % window_size)
+        if rounded // patch_size == patches_resolution:
+            img_size = nice_number
+            break
 
     model = SRFormer(
         img_size=img_size,
         patch_size=patch_size,
         in_chans=in_chans,
         num_feat=num_feat,
         embed_dim=embed_dim,
         depths=depths,
         num_heads=num_heads,
@@ -164,10 +130,9 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[SRFormer]:
         drop_rate=drop_rate,
         attn_drop_rate=attn_drop_rate,
         drop_path_rate=drop_path_rate,
-        norm_layer=norm_layer,
         ape=ape,
         patch_norm=patch_norm,
-        upscale=scale,
+        upscale=upscale,
         upsampler=upsampler,
         img_range=img_range,
         resi_connection=resi_connection,
@@ -183,21 +148,20 @@ def load(state_dict: StateDict) -> ImageModelDescriptor[SRFormer]:
     tags = [
         size_tag,
         f"s{img_size}w{window_size}",
-        f"{num_feat}nf",
         f"{embed_dim}dim",
         f"{resi_connection}",
     ]
 
     return ImageModelDescriptor(
         model,
-        state,
+        state_dict,
         architecture="SRFormer",
-        purpose="Restoration" if scale == 1 else "SR",
+        purpose="Restoration" if upscale == 1 else "SR",
         tags=tags,
         supports_half=False,  # Too much weirdness to support this at the moment
         supports_bfloat16=True,
-        scale=scale,
-        input_channels=in_nc,
-        output_channels=out_nc,
+        scale=upscale,
+        input_channels=in_chans,
+        output_channels=in_chans,
         size_requirements=SizeRequirements(minimum=16),
     )
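
For anyone tracing the new detection logic above, here is a minimal standalone sketch of the key deductions. The seq_len helper below mirrors what spandrel's get_seq_len is assumed to do (count consecutive integer indices under a key prefix; the real helper in __arch_helpers.state may differ in detail). The shape arithmetic follows the diff, and the checked values are the expected shapes for the snapshot models further down, not tensors read from real checkpoints.

    import math

    # Sketch of what get_seq_len is assumed to do: count how many consecutive
    # indices i have at least one key starting with "{prefix}.{i}.".
    def seq_len(state_dict: dict, prefix: str) -> int:
        i = 0
        while any(k.startswith(f"{prefix}.{i}.") for k in state_dict):
            i += 1
        return i

    def detect_window_size(bias_table_rows: int) -> int:
        # SRFormer's permuted self-attention bias table has
        # (2 * permuted_window_size - 1) ** 2 == (window_size - 1) ** 2 rows
        # (permuted_window_size = window_size // 2), which the new code
        # inverts as sqrt(rows) + 1.
        return int(math.sqrt(bias_table_rows)) + 1

    def detect_pixelshuffle_scale(upsample_out_channels: list, num_feat: int = 64) -> int:
        # Every other "upsample.{i}.weight" is a conv with num_feat * r**2
        # output channels; the total scale is the product of the stage factors r.
        scale = 1
        for out_ch in upsample_out_channels:
            scale *= int(math.sqrt(out_ch // num_feat))
        return scale

    # Worked checks against the snapshot tags below (expected shapes):
    assert seq_len({"layers.0.x": 0, "layers.1.x": 0}, "layers") == 2
    assert detect_window_size(441) == 22  # 21**2 rows -> the "w22" models
    assert detect_window_size(225) == 16  # 15**2 rows -> the "w16" light model
    assert detect_pixelshuffle_scale([256, 256]) == 4  # two 2x stages -> scale 4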
1 change: 0 additions & 1 deletion src/spandrel/architectures/SRFormer/arch/SRFormer.py
@@ -1010,7 +1010,6 @@ def __init__(
         img_range=1.0,
         upsampler="",
         resi_connection="1conv",
-        **kwargs,
     ):
         super().__init__()
         num_in_ch = in_chans
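
One consequence of removing **kwargs from the constructor: a mis-detected or misspelled hyperparameter now raises a TypeError instead of being silently swallowed. A toy illustration (illustrative classes, not the real SRFormer signature):

    class Lenient:
        def __init__(self, upscale=2, **kwargs):  # extra kwargs vanish silently
            self.upscale = upscale

    class Strict:
        def __init__(self, upscale=2):  # unknown kwargs now fail loudly
            self.upscale = upscale

    Lenient(upscale=4, window_sizes=22)  # typo accepted, bug hidden
    Strict(upscale=4)                    # fine
    # Strict(upscale=4, window_sizes=22) would raise:
    # TypeError: __init__() got an unexpected keyword argument 'window_sizes'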
55 changes: 55 additions & 0 deletions tests/__snapshots__/test_SRFormer.ambr
@@ -0,0 +1,55 @@
# serializer version: 1
# name: test_SRFormerLight_SRx3_DIV2K
ImageModelDescriptor(
architecture='SRFormer',
input_channels=3,
output_channels=3,
purpose='SR',
scale=3,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'small',
's64w16',
'60dim',
'1conv',
]),
)
# ---
# name: test_SRFormer_SRx2_DF2K
ImageModelDescriptor(
architecture='SRFormer',
input_channels=3,
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'medium',
's64w22',
'180dim',
'1conv',
]),
)
# ---
# name: test_SRFormer_SRx4_DF2K
ImageModelDescriptor(
architecture='SRFormer',
input_channels=3,
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=16, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
tags=list([
'medium',
's64w22',
'180dim',
'1conv',
]),
)
# ---
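
A quick sanity check on the x3 light snapshot: a light model presumably takes the pixelshuffledirect branch (as in SwinIR's light variants), where upsample.0 is a single conv producing num_out_ch * scale**2 output channels. Using the expected, not measured, shape for a 3-channel x3 model:

    import math

    in_chans = 3
    upsample_0_out_channels = 27  # expected shape[0]: 3 channels * 3**2
    scale = int(math.sqrt(upsample_0_out_channels // in_chans))
    assert scale == 3  # matches scale=3 in test_SRFormerLight_SRx3_DIV2K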
Binary file added tests/images/outputs/16x16/SRFormer_SRx2_DF2K.png
Binary file added tests/images/outputs/16x16/SRFormer_SRx4_DF2K.png
Binary file added tests/images/outputs/32x32/SRFormer_SRx2_DF2K.png
Binary file added tests/images/outputs/32x32/SRFormer_SRx4_DF2K.png
Binary file added tests/images/outputs/64x64/SRFormer_SRx2_DF2K.png
Binary file added tests/images/outputs/64x64/SRFormer_SRx4_DF2K.png