Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Question about missing key of model weights #20

Open
NKUShaw opened this issue Sep 9, 2024 · 0 comments
Open

Question about missing key of model weights #20

NKUShaw opened this issue Sep 9, 2024 · 0 comments

Comments

@NKUShaw
Copy link

NKUShaw commented Sep 9, 2024

Try to run the code: "
import argparse
import torch
import torchvision

from ddpm import script_utils

python ddpmsample.py --model_path='./ddpm_logs/cifar10-ddpm-2024-09-09-02-56-iteration-13000-optim.pth' --save_dir='outputs' --use_labels=True

def main():
args = create_argparser().parse_args()
device = args.device

try:
    diffusion = script_utils.get_diffusion_from_args(args).to(device)
    diffusion.load_state_dict(torch.load(args.model_path))
    if args.use_labels:
        for label in range(10):
            y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label
            samples = diffusion.sample(args.num_images // 10, device, y=y)
            
            for image_id in range(len(samples)):
                image = ((samples[image_id] + 1) / 2).clip(0, 1)
                image = torchvision.transforms.Resize((128, 128))(image)  
                torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png")
    else:
        samples = diffusion.sample(args.num_images, device)
        
        for image_id in range(len(samples)):
            image = ((samples[image_id] + 1) / 2).clip(0, 1)
            image = torchvision.transforms.Resize((128, 128))(image) 
            torchvision.utils.save_image(image, f"{args.save_dir}/{image_id}.png")
except KeyboardInterrupt:
    print("Keyboard interrupt, generation finished early")

def create_argparser():
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
print(device)
defaults = dict(num_images=100,
device=device,
schedule_low=1e-4,
schedule_high=0.02,
)
defaults.update(script_utils.diffusion_defaults())

parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default='./ddpm_logs/cifar10-ddpm-2024-09-09-02-56-iteration-279000-model.pth')
parser.add_argument("--save_dir", type=str, default='./outputs')
script_utils.add_dict_to_argparser(parser, defaults)
return parser

if name == "main":
main()
"

Traceback (most recent call last):
File "XXX/MyMethod/ddpmsample.py", line 54, in
main()
File "XXX/MyMethod/ddpmsample.py", line 14, in main
diffusion.load_state_dict(torch.load(args.model_path))
File "XXX/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GaussianDiffusion:
Missing key(s) in state_dict: "model.downs.0.class_bias.weight", "model.downs.1.class_bias.weight", "model.downs.3.class_bias.weight", "model.downs.4.class_bias.weight", "model.downs.6.class_bias.weight", "model.downs.7.class_bias.weight", "model.downs.9.class_bias.weight", "model.downs.10.class_bias.weight", "model.ups.0.class_bias.weight", "model.ups.1.class_bias.weight", "model.ups.2.class_bias.weight", "model.ups.4.class_bias.weight", "model.ups.5.class_bias.weight", "model.ups.6.class_bias.weight", "model.ups.8.class_bias.weight", "model.ups.9.class_bias.weight", "model.ups.10.class_bias.weight", "model.ups.12.class_bias.weight", "model.ups.13.class_bias.weight", "model.ups.14.class_bias.weight", "model.mid.0.class_bias.weight", "model.mid.1.class_bias.weight", "ema_model.downs.0.class_bias.weight", "ema_model.downs.1.class_bias.weight", "ema_model.downs.3.class_bias.weight", "ema_model.downs.4.class_bias.weight", "ema_model.downs.6.class_bias.weight", "ema_model.downs.7.class_bias.weight", "ema_model.downs.9.class_bias.weight", "ema_model.downs.10.class_bias.weight", "ema_model.ups.0.class_bias.weight", "ema_model.ups.1.class_bias.weight", "ema_model.ups.2.class_bias.weight", "ema_model.ups.4.class_bias.weight", "ema_model.ups.5.class_bias.weight", "ema_model.ups.6.class_bias.weight", "ema_model.ups.8.class_bias.weight", "ema_model.ups.9.class_bias.weight", "ema_model.ups.10.class_bias.weight", "ema_model.ups.12.class_bias.weight", "ema_model.ups.13.class_bias.weight", "ema_model.ups.14.class_bias.weight", "ema_model.mid.0.class_bias.weight", "ema_model.mid.1.class_bias.weight".

what happend

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant