Skip to content

Commit

Permalink
Add RealPLKSR LayerNorm (#313)
Browse files Browse the repository at this point in the history
* add realplksr layernorm

* Update __init__.py

* update snapshots
  • Loading branch information
the-database authored Jan 3, 2025
1 parent a1db3f5 commit 3a32e72
Show file tree
Hide file tree
Showing 10 changed files with 122 additions and 12 deletions.
34 changes: 29 additions & 5 deletions libs/spandrel/spandrel/architectures/PLKSR/__arch/RealPLKSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
from spandrel.util import store_hyperparameters


class LayerNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.bias = nn.Parameter(torch.zeros(dim))
self.eps = eps

def forward(self, x: torch.Tensor) -> torch.Tensor:
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
return self.weight[:, None, None] * x + self.bias[:, None, None]


class DCCM(nn.Sequential):
"Doubled Convolutional Channel Mixer"

Expand Down Expand Up @@ -56,11 +70,15 @@ def __init__(
dim: int,
kernel_size: int,
split_ratio: float,
norm_groups: int,
use_ea: bool = True,
norm_groups: int = 4,
use_layer_norm: bool = False,
):
super().__init__()

# Layer Norm
self.layer_norm = LayerNorm(dim) if use_layer_norm else nn.Identity()

# Local Texture
self.channel_mixer = DCCM(dim)

Expand All @@ -80,11 +98,16 @@ def __init__(
self.refine = nn.Conv2d(dim, dim, 1, 1, 0)
trunc_normal_(self.refine.weight, std=0.02)

# Group Normalization
self.norm = nn.GroupNorm(norm_groups, dim)
if not use_layer_norm:
self.norm = nn.GroupNorm(norm_groups, dim)
nn.init.constant_(self.norm.bias, 0)
nn.init.constant_(self.norm.weight, 1.0)
else:
self.norm = nn.Identity()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x_skip = x
x = self.layer_norm(x)
x = self.channel_mixer(x)
x = self.lk(x)
x = self.attn(x)
Expand Down Expand Up @@ -114,6 +137,7 @@ def __init__(
norm_groups: int = 4,
dropout: float = 0,
dysample: bool = False,
layer_norm: bool = False,
):
super().__init__()

Expand All @@ -128,11 +152,11 @@ def __init__(
self.feats = nn.Sequential(
*[nn.Conv2d(in_ch, dim, 3, 1, 1)]
+ [
PLKBlock(dim, kernel_size, split_ratio, norm_groups, use_ea)
PLKBlock(dim, kernel_size, split_ratio, use_ea, norm_groups, layer_norm)
for _ in range(n_blocks)
]
+ [nn.Dropout2d(dropout)]
+ [nn.Conv2d(dim, 3 * upscaling_factor**2, 3, 1, 1)]
+ [nn.Conv2d(dim, out_ch * upscaling_factor**2, 3, 1, 1)]
)
trunc_normal_(self.feats[0].weight, std=0.02)
trunc_normal_(self.feats[-1].weight, std=0.02)
Expand Down
13 changes: 10 additions & 3 deletions libs/spandrel/spandrel/architectures/PLKSR/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import math
from typing import Literal, Sequence, Union
from collections.abc import Sequence
from typing import Literal, Union

from typing_extensions import override

Expand Down Expand Up @@ -41,6 +42,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
kernel_size = 17
split_ratio = 0.25
use_ea = True
supports_half = True

dim = state_dict["feats.0.weight"].shape[0]

Expand Down Expand Up @@ -118,10 +120,14 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
n_blocks = total_feat_layers - 3
kernel_size = state_dict["feats.1.lk.conv.weight"].shape[2]
split_ratio = state_dict["feats.1.lk.conv.weight"].shape[0] / dim

use_layer_norm = "feats.1.layer_norm.bias" in state_dict
use_dysample = "to_img.init_pos" in state_dict
if use_dysample:
more_tags.append("DySample")
if use_layer_norm:
more_tags.append("LayerNorm")
else:
supports_half = False

model = RealPLKSR(
dim=dim,
Expand All @@ -132,6 +138,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
use_ea=use_ea,
norm_groups=4, # un-detectable
dysample=use_dysample,
layer_norm=use_layer_norm,
)
else:
raise ValueError("Unknown model type")
Expand All @@ -142,7 +149,7 @@ def load(self, state_dict: StateDict) -> ImageModelDescriptor[_PLKSR]:
architecture=self,
purpose="Restoration" if scale == 1 else "SR",
tags=[f"{dim}dim", f"{n_blocks}nb", f"{kernel_size}ks", *more_tags],
supports_half=False,
supports_half=supports_half,
supports_bfloat16=True,
scale=scale,
input_channels=3,
Expand Down
54 changes: 50 additions & 4 deletions tests/__snapshots__/test_PLKSR.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'12nb',
Expand All @@ -33,7 +33,7 @@
scale=2,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
Expand All @@ -55,7 +55,7 @@
scale=3,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
Expand All @@ -77,7 +77,7 @@
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=False,
supports_half=True,
tags=list([
'64dim',
'28nb',
Expand Down Expand Up @@ -154,3 +154,49 @@
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_RealPLKSR_LayerNorm_2x
ImageModelDescriptor(
architecture=PLKSRArch(
id='PLKSR',
name='PLKSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=2,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64dim',
'28nb',
'17ks',
'Real',
'LayerNorm',
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
# name: test_RealPLKSR_LayerNorm_4x
ImageModelDescriptor(
architecture=PLKSRArch(
id='PLKSR',
name='PLKSR',
),
input_channels=3,
output_channels=3,
purpose='SR',
scale=4,
size_requirements=SizeRequirements(minimum=0, multiple_of=1, square=False),
supports_bfloat16=True,
supports_half=True,
tags=list([
'64dim',
'28nb',
'17ks',
'Real',
'LayerNorm',
]),
tiling=<ModelTiling.SUPPORTED: 1>,
)
# ---
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
33 changes: 33 additions & 0 deletions tests/test_PLKSR.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_load():
lambda: RealPLKSR(split_ratio=0.75),
lambda: RealPLKSR(use_ea=False),
lambda: RealPLKSR(dysample=True),
lambda: RealPLKSR(layer_norm=True),
)


Expand All @@ -68,6 +69,12 @@ def test_size_requirements():
)
assert_size_requirements(file.load_model())

file = ModelFile.from_url(
"https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
name="4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
)
assert_size_requirements(file.load_model())


def test_PLKSR_official_x4(snapshot):
file = ModelFile.from_url(
Expand Down Expand Up @@ -172,3 +179,29 @@ def test_RealPLKSR_DySample(snapshot):
model,
[TestImage.SR_16, TestImage.SR_32, TestImage.SR_64],
)


def test_RealPLKSR_LayerNorm_4x(snapshot):
file = ModelFile.from_url(
"https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
name="4x_DF2K_Redux_RealPLKSRLayerNorm_50k.safetensors",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RealPLKSR)
assert_image_inference(
file, model, [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]
)


def test_RealPLKSR_LayerNorm_2x(snapshot):
file = ModelFile.from_url(
"https://github.com/the-database/traiNNer-redux/releases/download/pretrained-models/2x_DF2K_Redux_RealPLKSRLayerNorm_450k.safetensors",
name="2x_DF2K_Redux_RealPLKSRLayerNorm_450k.safetensors",
)
model = file.load_model()
assert model == snapshot(exclude=disallowed_props)
assert isinstance(model.model, RealPLKSR)
assert_image_inference(
file, model, [TestImage.SR_16, TestImage.SR_32, TestImage.SR_64]
)

0 comments on commit 3a32e72

Please sign in to comment.