Skip to content

Commit

Permalink
lpaps ckpt, fixes in loading state dicts (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
v-iashin committed Feb 3, 2022
1 parent a0281ca commit 3894458
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
Binary file added specvqgan/modules/autoencoder/lpaps/lpaps.pt
Binary file not shown.
14 changes: 7 additions & 7 deletions specvqgan/modules/losses/lpaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@ def __init__(self, use_dropout=True):
for param in self.parameters():
param.requires_grad = False

def load_from_pretrained(self, name="vggishish_lpaps"):
def load_from_pretrained(self, name="lpaps"):
ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")))
print("loaded pretrained LPAPS loss from {}".format(ckpt))

@classmethod
def from_pretrained(cls, name="vggishish_lpaps"):
if name != "vggishish_lpaps":
def from_pretrained(cls, name="lpaps"):
if name != "lpaps":
raise NotImplementedError
model = cls()
ckpt = get_ckpt_path(name)
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
ckpt = get_ckpt_path(name, "specvqgan/modules/autoencoder/lpaps")
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")))
return model

def forward(self, input, target):
Expand Down Expand Up @@ -130,7 +130,7 @@ def vggishish16(self, pretrained: bool = True) -> VGGishish:
conv_layers = [64, 64, 'MP', 128, 128, 'MP', 256, 256, 256, 'MP', 512, 512, 512, 'MP', 512, 512, 512]
model = VGGishish(conv_layers, use_bn=False, num_classes=num_classes_vggsound)
if pretrained:
ckpt_path = get_ckpt_path('vggishish_lpaps', "specvqgan/modules/autoencoder/lpaps")
ckpt_path = get_ckpt_path('vggishish', "specvqgan/modules/autoencoder/lpaps")
ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
model.load_state_dict(ckpt['model'])
return model
Expand Down
9 changes: 6 additions & 3 deletions specvqgan/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,22 @@
from tqdm import tqdm

URL_MAP = {
'vggishish_lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
'lpaps': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/lpaps.pt',
'vggishish': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/vggishish16.pt',
'vggishish_mean_std_melspec_10s_22050hz': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/train_means_stds_melspec_10s_22050hz.txt',
'melception': 'https://a3s.fi/swift/v1/AUTH_a235c0f452d648828f745589cde1219a/specvqgan_public/melception-21-05-10T09-28-40.pt',
}

CKPT_MAP = {
'vggishish_lpaps': 'vggishish16.pt',
'lpaps': 'lpaps.pt',
'vggishish': 'vggishish16.pt',
'vggishish_mean_std_melspec_10s_22050hz': 'train_means_stds_melspec_10s_22050hz.txt',
'melception': 'melception-21-05-10T09-28-40.pt',
}

MD5_MAP = {
'vggishish_lpaps': '197040c524a07ccacf7715d7080a80bd',
'lpaps': 'f8d4e7dba2b870222fe2bee26f85e7c9',
'vggishish': '197040c524a07ccacf7715d7080a80bd',
'vggishish_mean_std_melspec_10s_22050hz': 'f449c6fd0e248936c16f6d22492bb625',
'melception': 'a71a41041e945b457c7d3d814bbcf72d',
}
Expand Down

0 comments on commit 3894458

Please sign in to comment.