Visualizing individual channels
This section of torch-dreams was heavily inspired by *Feature Visualization* by Olah et al. The idea is to optimize the input image so that it maximizes the activations of one chosen channel in a given layer of the neural network.
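The core idea is gradient ascent on the pixels rather than on the weights. The sketch below shows it on a toy NumPy "layer" (a 1x1 convolution with random weights, an assumption made only so that the gradient can be written by hand). It is not torch-dreams' implementation, which works on a real network through PyTorch autograd. The helper names `channel_activation`, `channel_gradient` and `visualize_channel` are hypothetical.

```python
import numpy as np

# Toy stand-in for one layer: a 1x1 convolution with random weights, so every
# output channel is a weighted sum of the input channels at each pixel.
rng = np.random.default_rng(0)
IN_CH, OUT_CH, H, W = 3, 8, 16, 16
weights = rng.normal(size=(OUT_CH, IN_CH))


def channel_activation(image, channel):
    """Mean activation of one output channel for an image shaped (IN_CH, H, W)."""
    out = np.einsum("oi,ihw->ohw", weights, image)  # the 1x1 convolution
    return out[channel].mean()


def channel_gradient(channel):
    """Gradient of channel_activation with respect to every input pixel.

    The toy layer is linear, so the gradient is the channel's weight for each
    input channel, spread evenly over the H * W pixels that the mean averages.
    """
    per_pixel = weights[channel][:, None, None] / (H * W)
    return np.broadcast_to(per_pixel, (IN_CH, H, W))


def visualize_channel(channel, init, steps=200, lr=50.0):
    """Gradient ascent on the input image, clipped to the pixel range [0, 1]."""
    image = init.copy()
    for _ in range(steps):
        image = np.clip(image + lr * channel_gradient(channel), 0.0, 1.0)
    return image


noise = rng.uniform(size=(IN_CH, H, W))  # start from random noise
result = visualize_channel(3, noise)
print(channel_activation(noise, 3), "->", channel_activation(result, 3))
```

Because this toy layer is linear, the optimum has an obvious form. Each pixel is pushed to 1 where the channel's weight is positive and to 0 where it is negative. A real network is nonlinear, which is why the optimized images show structured patterns rather than flat colors.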
```python
import cv2
import numpy as np
import matplotlib.pyplot as plt
from torch_dreams.dreamer import dreamer
import torchvision.models as models

model = models.inception_v3(pretrained=True)
dreamy_boi = dreamer(model)

# feel free to use more layers for experiments
layers_to_use = [model.Mixed_6c.branch7x7_1.conv]
```
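The config further down starts from an image file named `noise.jpg`, but this page does not show how that file was created. The sketch below is one hypothetical way to make it with the `numpy` and `cv2` modules imported above. The function names, the 512x512 size and the choice of uniform RGB noise are all assumptions.

```python
import numpy as np


def make_noise_image(height=512, width=512, seed=0):
    """Uniform random noise as a uint8 array of shape (height, width, 3)."""
    rng = np.random.default_rng(seed)
    return rng.integers(0, 256, size=(height, width, 3), dtype=np.uint8)


def save_noise_image(path="noise.jpg", **kwargs):
    """Write a noise image to disk with OpenCV (hypothetical helper)."""
    import cv2  # imported here so the array helper above works without OpenCV

    # cv2.imwrite expects BGR channel order, which does not matter for noise
    cv2.imwrite(path, make_noise_image(**kwargs))
```

Calling `save_noise_image("noise.jpg")` writes the file that `image_path` points to. A fixed `seed` makes the starting image, and therefore the result, reproducible between runs.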
The next step is to define a `custom_func` that lets us selectively optimize a single channel. We could have written a simpler function, but wrapping it in a factory pays off once we run the optimization on many channels.
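```python
def make_custom_func(layer_number=0, channel_number=0):
    # Returns a loss that torch-dreams maximizes: the mean activation of one
    # channel in one of the layers listed in layers_to_use
    def custom_func(layer_outputs):
        loss = layer_outputs[layer_number][channel_number].mean()
        return loss
    return custom_func
```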
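To see what the factory returns without loading Inception, you can call it on stand-in data. The sketch below uses fake NumPy activations, not real torch-dreams outputs. It assumes, as the indexing in `make_custom_func` implies, that each entry of `layer_outputs` is indexed by channel first, i.e. shaped (channels, height, width).

```python
import numpy as np


def make_custom_func(layer_number=0, channel_number=0):
    # Same factory as above: the returned closure remembers which layer and
    # which channel it should score.
    def custom_func(layer_outputs):
        return layer_outputs[layer_number][channel_number].mean()
    return custom_func


# Fake activations for two layers, shaped (channels, height, width).
# Every value in channel c of layer 0 equals c, so the expected mean is obvious.
layer0 = np.arange(4, dtype=float)[:, None, None] * np.ones((4, 5, 5))
layer1 = np.full((6, 3, 3), 10.0)
fake_outputs = [layer0, layer1]

print(make_custom_func(0, 2)(fake_outputs))  # 2.0
print(make_custom_func(1, 5)(fake_outputs))  # 10.0
```

The same `.mean()` call works on PyTorch tensors, which is what the real layer outputs are.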
To optimize channel 7 (zero-indexed) of the first layer in `layers_to_use`, we define the custom loss as:
```python
my_custom_func = make_custom_func(layer_number=0, channel_number=7)
```
Now let's run the optimization and see how it turns out:
```python
config = {
    "image_path": "noise.jpg",
    "layers": layers_to_use,
    "octave_scale": 1.1,
    "num_octaves": 20,
    "iterations": 100,
    "lr": 0.04,
    "max_rotation": 0.7,
    "custom_func": my_custom_func,
}

out = dreamy_boi.deep_dream(config)
plt.imshow(out)
plt.show()
```
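The `octave_scale` and `num_octaves` entries come from the DeepDream recipe of optimizing the image at several scales, from smallest to largest. The sketch below assumes that torch-dreams follows the common scheme and has not been checked against the library's resizing code. In that scheme, octave `i` is the input shrunk by `octave_scale ** (num_octaves - 1 - i)`, so the last octave is full size. `octave_sizes` is a hypothetical helper.

```python
def octave_sizes(height, width, octave_scale=1.1, num_octaves=20):
    """Image size (h, w) for each octave, smallest first (hypothetical helper)."""
    sizes = []
    for i in range(num_octaves):
        factor = octave_scale ** (num_octaves - 1 - i)
        sizes.append((round(height / factor), round(width / factor)))
    return sizes


sizes = octave_sizes(512, 512, octave_scale=1.1, num_octaves=20)
print(sizes[0], sizes[-1])  # (84, 84) (512, 512)
```

With a scale of 1.1 and 20 octaves, the smallest octave is about one sixth of the full size (1.1^19 ≈ 6.1). The channel's coarse structure is therefore optimized first, and finer detail is added as the image grows.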
Now let's scale this up to visualize many more channels:
```python
channels = [7, 31, 115, 120, 74, 123, 56, 69, 148, 16, 47, 4]

# one loss function per channel, all on the first layer in layers_to_use
custom_funcs = [make_custom_func(channel_number=c) for c in channels]

all_outs = []
for func in custom_funcs:  # this might take a while to run
    config["custom_func"] = func
    out = dreamy_boi.deep_dream(config)
    all_outs.append(out)
```
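The factory is what makes building one loss per channel safe. Suppose you wrote `lambda outs: ...c...` inside the comprehension instead. Every lambda would look up `c` when it is called, not when it is created, so all of them would score the last channel in the list. The sketch below demonstrates this in plain Python. `score_channel` and `make_scorer` are hypothetical stand-ins for indexing real activations.

```python
channels = [7, 31, 115, 120]


def score_channel(channel, layer_outputs):
    # Hypothetical stand-in: just report which channel would be scored.
    return channel


# Pitfall: c is looked up when each lambda runs, after the comprehension ends.
lambdas = [lambda outs: score_channel(c, outs) for c in channels]


def make_scorer(channel):
    # Factory: each call creates a new scope that freezes its own channel.
    def scorer(outs):
        return score_channel(channel, outs)
    return scorer


scorers = [make_scorer(c) for c in channels]

print([f(None) for f in lambdas])  # [120, 120, 120, 120]
print([f(None) for f in scorers])  # [7, 31, 115, 120]
```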
Now let's arrange all the generated images into a single figure. If everything went well, you should end up with the banner shown at the top of this page.
```python
fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(16 * 2, 6 * 2))
fig.suptitle("model: inceptionv3 \n layer: Mixed_6c.branch7x7_1.conv \n", fontsize=25)

# one subplot per channel, filled row by row
for axis, channel, img in zip(ax.flat, channels, all_outs):
    axis.imshow(img)
    axis.axis("off")
    axis.set_facecolor("white")
    axis.set_title("channel:" + str(channel), fontsize=20)

fig.savefig("inceptionv3_new_channels.jpg")
```
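The figure above hard-codes a 2 x 6 grid, which fits the 12 channels exactly. If you change the channel list, a helper like the hypothetical one below picks a grid with a fixed number of columns. It also tells you how many leftover subplots to hide, for example with `ax.flat[k].axis("off")`. It covers only the layout arithmetic and is not part of torch-dreams.

```python
import math


def grid_shape(num_images, ncols=6):
    """Rows and columns for a grid with at most ncols images per row."""
    if num_images < 1:
        raise ValueError("need at least one image")
    ncols = min(ncols, num_images)
    nrows = math.ceil(num_images / ncols)
    return nrows, ncols


def unused_axes(num_images, ncols=6):
    """Number of grid cells left empty after placing num_images images."""
    nrows, ncols = grid_shape(num_images, ncols)
    return nrows * ncols - num_images


print(grid_shape(12), unused_axes(12))  # (2, 6) 0
print(grid_shape(14), unused_axes(14))  # (3, 6) 4
```

The returned `nrows` and `ncols` can be passed straight to `plt.subplots`. The `zip` loop above then stops at the last image, leaving the remaining cells for you to hide.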