
Visualizing individual channels

mayukh edited this page Dec 10, 2020 · 2 revisions


This section of torch-dreams was heavily inspired by Feature Visualization by Olah et al. The idea is to optimize the input image so that it maximizes the activations of a specific channel of a given layer in the neural network.
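To make the idea concrete, here is a toy NumPy sketch of activation maximization by gradient ascent on the input. It is not the torch-dreams implementation. The "layer" is a hypothetical 1x1 convolution, so its gradient can be written by hand instead of using autograd. The normalized step and the clipping to [0, 1] are also assumptions, chosen to keep the example small and stable. The loss mirrors the mean-of-a-channel objective used later on this page.

```python
import numpy as np

def layer_forward(image, weights):
    """Toy layer (hypothetical): a 1x1 convolution mapping (3, H, W) -> (C, H, W)."""
    # out[c, h, w] = sum_k weights[c, k] * image[k, h, w]
    return np.einsum("ck,khw->chw", weights, image)

def channel_loss(image, weights, channel):
    """Mean activation of one output channel."""
    return layer_forward(image, weights)[channel].mean()

def channel_loss_grad(image, weights, channel):
    """Analytic gradient of channel_loss w.r.t. the image (the toy layer is linear)."""
    _, h, w = image.shape
    return np.broadcast_to(weights[channel][:, None, None], image.shape) / (h * w)

def visualize_channel(weights, channel, shape=(3, 8, 8), iterations=200, lr=0.05, seed=0):
    """Gradient ascent on the input image to maximize one channel's mean activation."""
    rng = np.random.default_rng(seed)
    image = rng.uniform(0.0, 1.0, size=shape)  # start from random noise
    for _ in range(iterations):
        grad = channel_loss_grad(image, weights, channel)
        # normalize the step so its size does not depend on the gradient scale (assumption)
        image = image + lr * grad / (np.abs(grad).max() + 1e-8)
        # keep pixel values in a valid range (assumption)
        image = np.clip(image, 0.0, 1.0)
    return image

weights = np.array([[1.0, -0.5, 0.25],   # channel 0 likes red and blue, dislikes green
                    [-1.0, 1.0, 0.0]])   # channel 1 likes green, dislikes red
image = visualize_channel(weights, channel=0)
```

Because the toy layer is linear, the optimized image simply becomes the color the channel responds to: input channels with a positive weight are pushed to 1 and those with a negative weight to 0. A real CNN layer is non-linear, which is why the patterns it produces are far richer.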

import cv2
import numpy as np
import matplotlib.pyplot as plt
from torch_dreams.dreamer import dreamer
import torchvision.models as models

model = models.inception_v3(pretrained=True)
dreamy_boi = dreamer(model)

"""
feel free to use more layers for experiments 
"""
layers_to_use = [model.Mixed_6c.branch7x7_1.conv]

The next step is to define a custom_func that lets us selectively optimize a single channel. A simpler function would do for one channel, but this factory will come in handy when we run the optimization on many channels.

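def make_custom_func(layer_number=0, channel_number=0):
    # layer_outputs holds one output per entry in layers_to_use
    def custom_func(layer_outputs):
        # mean activation of one channel of the chosen layer
        loss = layer_outputs[layer_number][channel_number].mean()
        return loss
    return custom_func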
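Why a factory instead of a plain lambda? When functions are created in a loop, Python closures look up the loop variable when they are called, not when they are defined, so every naive lambda ends up reading the last channel. The sketch below repeats the factory so it runs on its own; NumPy arrays stand in for the tensors torch-dreams passes in, and their (channels, height, width) shape is an assumption.

```python
import numpy as np

def make_custom_func(layer_number=0, channel_number=0):
    def custom_func(layer_outputs):
        # mean activation of one channel of one hooked layer
        return layer_outputs[layer_number][channel_number].mean()
    return custom_func

# dummy outputs for a single hooked layer with 4 channels of size 2x2
layer_outputs = [np.arange(16, dtype=float).reshape(4, 2, 2)]

# each factory-made function remembers its own channel_number
funcs = [make_custom_func(channel_number=c) for c in range(4)]

# naive lambdas all read c after the comprehension has finished, i.e. c == 3
naive = [lambda outs: outs[0][c].mean() for c in range(4)]

print([float(f(layer_outputs)) for f in funcs])  # [1.5, 5.5, 9.5, 13.5]
print([float(f(layer_outputs)) for f in naive])  # [13.5, 13.5, 13.5, 13.5]
```

A default argument (lambda outs, c=c: ...) would also bind the value early; the factory does the same thing and reads more clearly once it takes several parameters.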

Now, to optimize channel 7 (channels are zero-indexed) of the first layer in layers_to_use, we define the custom loss as:

my_custom_func = make_custom_func(layer_number= 0, channel_number = 7)

Now let's run the optimization and see how it turns out:

config = {
    "image_path": "noise.jpg",
    "layers": layers_to_use,
    "octave_scale": 1.1,  
    "num_octaves": 20,  
    "iterations": 100,  
    "lr": 0.04,
    "max_rotation": 0.7,
    "custom_func":  my_custom_func,
}

out = dreamy_boi.deep_dream(config)
plt.imshow(out)
plt.show()
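The octave settings are easiest to picture as a list of image sizes. This page does not spell out how torch-dreams resizes the image, so the sketch below assumes the common DeepDream convention: the image is processed from the smallest octave up to full resolution, each octave octave_scale times larger than the previous one. octave_sizes is a hypothetical helper, not part of torch-dreams.

```python
def octave_sizes(height, width, octave_scale=1.1, num_octaves=20):
    """Coarse-to-fine image sizes under the assumed DeepDream octave convention."""
    sizes = []
    for i in range(num_octaves):
        # the last octave (i == num_octaves - 1) is the full-resolution image
        factor = octave_scale ** (num_octaves - 1 - i)
        sizes.append((round(height / factor), round(width / factor)))
    return sizes

sizes = octave_sizes(512, 512)
print(sizes[0], sizes[-1])  # (84, 84) (512, 512)
```

Under this assumption, octave_scale=1.1 with num_octaves=20 makes the smallest octave roughly 1/6 of the full size (1.1^19 is about 6.1), so the early octaves shape coarse structure and the later ones add fine detail.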

Now let us scale this up to visualize a lot more channels:

channels = [7,31,115,120,74, 123, 56, 69, 148, 16, 47, 4]

custom_funcs = []
for c in channels:
    custom_funcs.append(make_custom_func(channel_number=c))

all_outs = []
for func in custom_funcs:  # this might take some time to run
    config["custom_func"] = func
    out = dreamy_boi.deep_dream(config)
    all_outs.append(out)

Finally, let's make a figure out of all the generated images. If everything went well, you'll end up with the banner shown at the top of this page.

fig, ax = plt.subplots(nrows=2, ncols=6, figsize=(16*2, 6*2))
fig.suptitle('model: inceptionv3 \n layer: Mixed_6c.branch7x7_1.conv \n', fontsize=25)

# one panel per channel, filled row by row
for axis, channel, image in zip(ax.flat, channels, all_outs):
    axis.imshow(image)
    axis.axis("off")
    axis.set_facecolor("white")
    axis.set_title("channel:" + str(channel), fontsize=20)

fig.savefig("inceptionv3_new_channels.jpg")
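The 2x6 grid above is hard-coded for 12 channels. If you change the channels list, a small hypothetical helper like the one below can pick the grid shape instead; it is plain Python and not part of torch-dreams or matplotlib.

```python
import math

def grid_shape(num_images, max_cols=6):
    """Return (nrows, ncols) for a subplot grid holding num_images panels."""
    if num_images < 1:
        raise ValueError("need at least one image")
    ncols = min(max_cols, num_images)
    nrows = math.ceil(num_images / ncols)
    return nrows, ncols

print(grid_shape(12))  # (2, 6)
print(grid_shape(13))  # (3, 6)
print(grid_shape(4))   # (1, 4)
```

The result could be passed to plt.subplots as nrows, ncols = grid_shape(len(channels)) together with squeeze=False, which keeps ax an array even for a single panel so ax.flat keeps working; any leftover panels can be hidden with axis("off").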