We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Try to run the code: " import argparse import torch import torchvision
from ddpm import script_utils
def main(): args = create_argparser().parse_args() device = args.device
try: diffusion = script_utils.get_diffusion_from_args(args).to(device) diffusion.load_state_dict(torch.load(args.model_path)) if args.use_labels: for label in range(10): y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label samples = diffusion.sample(args.num_images // 10, device, y=y) for image_id in range(len(samples)): image = ((samples[image_id] + 1) / 2).clip(0, 1) image = torchvision.transforms.Resize((128, 128))(image) torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png") else: samples = diffusion.sample(args.num_images, device) for image_id in range(len(samples)): image = ((samples[image_id] + 1) / 2).clip(0, 1) image = torchvision.transforms.Resize((128, 128))(image) torchvision.utils.save_image(image, f"{args.save_dir}/{image_id}.png") except KeyboardInterrupt: print("Keyboard interrupt, generation finished early")
def create_argparser(): device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") print(device) defaults = dict(num_images=100, device=device, schedule_low=1e-4, schedule_high=0.02, ) defaults.update(script_utils.diffusion_defaults())
parser = argparse.ArgumentParser() parser.add_argument("--model_path", type=str, default='./ddpm_logs/cifar10-ddpm-2024-09-09-02-56-iteration-279000-model.pth') parser.add_argument("--save_dir", type=str, default='./outputs') script_utils.add_dict_to_argparser(parser, defaults) return parser
if __name__ == "__main__": main() "
Traceback (most recent call last): File "XXX/MyMethod/ddpmsample.py", line 54, in main() File "XXX/MyMethod/ddpmsample.py", line 14, in main diffusion.load_state_dict(torch.load(args.model_path)) File "XXX/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for GaussianDiffusion: Missing key(s) in state_dict: "model.downs.0.class_bias.weight", "model.downs.1.class_bias.weight", "model.downs.3.class_bias.weight", "model.downs.4.class_bias.weight", "model.downs.6.class_bias.weight", "model.downs.7.class_bias.weight", "model.downs.9.class_bias.weight", "model.downs.10.class_bias.weight", "model.ups.0.class_bias.weight", "model.ups.1.class_bias.weight", "model.ups.2.class_bias.weight", "model.ups.4.class_bias.weight", "model.ups.5.class_bias.weight", "model.ups.6.class_bias.weight", "model.ups.8.class_bias.weight", "model.ups.9.class_bias.weight", "model.ups.10.class_bias.weight", "model.ups.12.class_bias.weight", "model.ups.13.class_bias.weight", "model.ups.14.class_bias.weight", "model.mid.0.class_bias.weight", "model.mid.1.class_bias.weight", "ema_model.downs.0.class_bias.weight", "ema_model.downs.1.class_bias.weight", "ema_model.downs.3.class_bias.weight", "ema_model.downs.4.class_bias.weight", "ema_model.downs.6.class_bias.weight", "ema_model.downs.7.class_bias.weight", "ema_model.downs.9.class_bias.weight", "ema_model.downs.10.class_bias.weight", "ema_model.ups.0.class_bias.weight", "ema_model.ups.1.class_bias.weight", "ema_model.ups.2.class_bias.weight", "ema_model.ups.4.class_bias.weight", "ema_model.ups.5.class_bias.weight", "ema_model.ups.6.class_bias.weight", "ema_model.ups.8.class_bias.weight", "ema_model.ups.9.class_bias.weight", "ema_model.ups.10.class_bias.weight", "ema_model.ups.12.class_bias.weight", "ema_model.ups.13.class_bias.weight", "ema_model.ups.14.class_bias.weight", 
"ema_model.mid.0.class_bias.weight", "ema_model.mid.1.class_bias.weight".
What happened?
The text was updated successfully, but these errors were encountered:
No branches or pull requests
Try to run the code: "
import argparse
import torch
import torchvision
from ddpm import script_utils
python ddpmsample.py --model_path='./ddpm_logs/cifar10-ddpm-2024-09-09-02-56-iteration-13000-optim.pth' --save_dir='outputs' --use_labels=True
def main():
args = create_argparser().parse_args()
device = args.device
def create_argparser():
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
print(device)
defaults = dict(num_images=100,
device=device,
schedule_low=1e-4,
schedule_high=0.02,
)
defaults.update(script_utils.diffusion_defaults())
if __name__ == "__main__":
main()
"
Traceback (most recent call last):
File "XXX/MyMethod/ddpmsample.py", line 54, in <module>
main()
File "XXX/MyMethod/ddpmsample.py", line 14, in main
diffusion.load_state_dict(torch.load(args.model_path))
File "XXX/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for GaussianDiffusion:
Missing key(s) in state_dict: "model.downs.0.class_bias.weight", "model.downs.1.class_bias.weight", "model.downs.3.class_bias.weight", "model.downs.4.class_bias.weight", "model.downs.6.class_bias.weight", "model.downs.7.class_bias.weight", "model.downs.9.class_bias.weight", "model.downs.10.class_bias.weight", "model.ups.0.class_bias.weight", "model.ups.1.class_bias.weight", "model.ups.2.class_bias.weight", "model.ups.4.class_bias.weight", "model.ups.5.class_bias.weight", "model.ups.6.class_bias.weight", "model.ups.8.class_bias.weight", "model.ups.9.class_bias.weight", "model.ups.10.class_bias.weight", "model.ups.12.class_bias.weight", "model.ups.13.class_bias.weight", "model.ups.14.class_bias.weight", "model.mid.0.class_bias.weight", "model.mid.1.class_bias.weight", "ema_model.downs.0.class_bias.weight", "ema_model.downs.1.class_bias.weight", "ema_model.downs.3.class_bias.weight", "ema_model.downs.4.class_bias.weight", "ema_model.downs.6.class_bias.weight", "ema_model.downs.7.class_bias.weight", "ema_model.downs.9.class_bias.weight", "ema_model.downs.10.class_bias.weight", "ema_model.ups.0.class_bias.weight", "ema_model.ups.1.class_bias.weight", "ema_model.ups.2.class_bias.weight", "ema_model.ups.4.class_bias.weight", "ema_model.ups.5.class_bias.weight", "ema_model.ups.6.class_bias.weight", "ema_model.ups.8.class_bias.weight", "ema_model.ups.9.class_bias.weight", "ema_model.ups.10.class_bias.weight", "ema_model.ups.12.class_bias.weight", "ema_model.ups.13.class_bias.weight", "ema_model.ups.14.class_bias.weight", "ema_model.mid.0.class_bias.weight", "ema_model.mid.1.class_bias.weight".
What happened?
The text was updated successfully, but these errors were encountered: