-
Notifications
You must be signed in to change notification settings - Fork 0
/
basic.py
30 lines (24 loc) · 883 Bytes
/
basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from torchvision import models
from ..image import ImageBatch
from ..objectives import ChannelObjective
from ..renderer import RendererBuilder
def basic(device="cuda:0", numberOfFrames=500):
    """Render feature visualizations for one channel of a pretrained ResNet-18.

    Builds a batch of four 224x224, 3-channel images and renders them against
    a ``ChannelObjective`` targeting channel 18 (0-based) of
    ``model.layer3[1].conv2`` in an ImageNet-pretrained ResNet-18, with a live
    preview enabled.

    Args:
        device: Torch device string; both the model and the image batch are
            moved there.
        numberOfFrames: Value passed to ``renderer.render`` (presumably the
            number of optimization steps — confirm against the renderer).

    Returns:
        ``renderer.drawableImageBatch().data`` after rendering (presumably the
        rendered image tensor — confirm against the renderer implementation).
    """
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` enum. Both select the same ImageNet checkpoint
    # (IMAGENET1K_V1); fall back to the old flag on releases without the enum.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained=True)

    # Put the network on the same device as the image batch, and switch
    # BatchNorm/dropout to inference behaviour so the frozen network's
    # response does not depend on the other images in the batch. Both calls
    # are idempotent, so they are harmless if the renderer also does this.
    model = model.to(device).eval()

    # Target the second conv of the second residual block in layer3, channel
    # index 18 (0-based). Which layer/channel gives an interesting picture
    # varies a lot between architectures.
    objective = ChannelObjective(lambda m: m.layer3[1].conv2, channel=18)

    # Four 224 x 224 RGB images to optimize.
    imageBatch = ImageBatch.generate(
        width=224,
        height=224,
        batch_size=4,
        channels=3,
    ).to(device)

    renderer = (
        RendererBuilder()
        .imageBatch(imageBatch)
        .model(model)
        .objective(objective)
        .withLivePreview()
        .build()
    )
    renderer.render(numberOfFrames)
    return renderer.drawableImageBatch().data