-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmnist_kernel.py
91 lines (77 loc) · 2.81 KB
/
mnist_kernel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import random
import sys
from functools import partial
from pathlib import Path
from typing import Union, Tuple
import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import mnist_networks
from model_manifold.inspect import constant_direction_kernel, domain_projection
from model_manifold.plot import denormalize, to_gif, save_strip
def mnist_kernel_direction(
    checkpoint_path: Union[str, Path],
    start_idx: int = -1,
    steps: int = 1000,
    data_root: Union[str, Path] = "data",
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """Evolve an MNIST test image along a random direction via ``constant_direction_kernel``.

    The network is the one built by ``mnist_networks.medium_cnn`` from
    ``checkpoint_path``. The starting direction is drawn with
    ``torch.randn_like``, so seed torch (and ``random`` when ``start_idx`` is
    ``-1``) for reproducible results.

    Args:
        checkpoint_path: Checkpoint passed to ``mnist_networks.medium_cnn``.
        start_idx: Index of the starting image in the MNIST test split.
            ``-1`` (the default) picks an index uniformly at random with
            ``random.randrange``. Other negative values count from the end of
            the split, as in ordinary Python indexing.
        steps: Number of steps forwarded to ``constant_direction_kernel``.
        data_root: Directory where the MNIST test split is stored (it is
            downloaded there if missing).

    Returns:
        ``(data_path, prob_path, pred_path, start_idx)``:
        - ``data_path`` is the image path returned by
          ``constant_direction_kernel``, passed through ``denormalize`` with the
          same normalization used to load the data.
        - ``prob_path`` and ``pred_path`` are its other two outputs, unchanged.
        - ``start_idx`` is the non-negative index of the image that was evolved.

    Raises:
        IndexError: If ``start_idx`` (other than the ``-1`` sentinel) is outside
            the test split.
    """
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    test_mnist = datasets.MNIST(
        data_root,
        train=False,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]),
    )
    network = mnist_networks.medium_cnn(checkpoint_path)

    n_images = len(test_mnist)
    if start_idx == -1:
        start_idx = random.randrange(n_images)
    elif not -n_images <= start_idx < n_images:
        raise IndexError(
            f"start_idx {start_idx} is out of range for the MNIST test split "
            f"({n_images} images)."
        )
    # Canonicalize negative indices so the returned index (used by callers to
    # name output files) always identifies the image that was evolved.
    start_idx %= n_images
    print(f"Evolve the image {start_idx} in the kernel of the local data matrix.")

    # Run on whatever device the network's parameters live on.
    device = next(network.parameters()).device
    start_image = test_mnist[start_idx][0].to(device)
    # Random starting direction with the image's shape, drawn from torch's
    # global RNG.
    v = torch.randn_like(start_image)
    # noinspection PyTypeChecker
    data_path, prob_path, pred_path = constant_direction_kernel(
        network,
        start_image,
        v,
        steps=steps,
        # NOTE(review): presumably projects each iterate back onto the valid
        # (normalized) pixel domain; confirm in model_manifold.inspect.
        post_processing=partial(domain_projection, normalization=normalize),
    )
    data_path = denormalize(data_path, normalize)
    return data_path, prob_path, pred_path, start_idx
if __name__ == "__main__":
    # Command-line entry point: evolve one MNIST test image along a random
    # direction, then save the path as a .gif (via to_gif) and as a .png strip
    # (via save_strip).
    cli = argparse.ArgumentParser(
        description=(
            "Export the path obtained evolving a valid image along a "
            "random direction in the kernel of the local data matrix as a .gif"
        ),
        usage=(
            "python3 mnist_kernel.py CHECKPOINT "
            "[--start START --seed SEED --output-dir OUTPUT-DIR]"
        ),
    )
    cli.add_argument("checkpoint", type=str, help="Path to checkpoint model")
    cli.add_argument(
        "--start", type=int, default=-1, help="Index of the starting image"
    )
    cli.add_argument("--seed", type=int, default=0, help="Random seed")
    cli.add_argument(
        "--output-dir", type=str, default="outputs", help="Output directory"
    )
    opts = cli.parse_args(sys.argv[1:])

    # Seed Python's ``random``, NumPy and torch with the same value, in this order.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(opts.seed)

    frames, probabilities, predictions, chosen_idx = mnist_kernel_direction(
        opts.checkpoint, opts.start
    )

    out_dir = Path(opts.output_dir).expanduser()
    out_dir.mkdir(parents=True, exist_ok=True)
    # Output files are named after the zero-padded index of the evolved image.
    stem = f"{chosen_idx:05d}_noise"
    gif_file = out_dir / f"{stem}.gif"
    png_file = out_dir / f"{stem}.png"
    to_gif(frames, gif_file, step=100, scale_factor=10.0)
    save_strip(frames, png_file, probabilities, predictions)