gg #52

Open
wants to merge 21 commits into master

310 changes: 310 additions & 0 deletions notebooks/models.ipynb
@@ -0,0 +1,310 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyNAyZfilBWm4WjENWI0eGc2",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/amansyayf/2016-solar_project/blob/master/notebooks/models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SToV-nheVjNe"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"from torchvision.models import vgg19\n",
"import torchvision.transforms as transforms\n",
"\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"source": [
"**Create Generation Network**"
],
"metadata": {
"id": "M4DvnVVQ8hUW"
}
},
{
"cell_type": "code",
"source": [
"class ConvLayer(nn.Module):\n",
" def __init__(self, in_c, out_c, kernel_size):\n",
" super().__init__()\n",
" pad = int(np.floor(kernel_size/2))\n",
" self.conv = nn.Conv2d(in_c, out_c, kernel_size = kernel_size, stride = 1, padding = pad)\n",
" def forward(self, x):\n",
" return self.conv(x)"
],
"metadata": {
"id": "eKXEwlppWi_G"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class Bottleneck(nn.Module):\n",
" def __init__(self, in_c, out_c, kernel_size = 3, stride=1):\n",
" super().__init__()\n",
" self.in_c = in_c\n",
" self.out_c = out_c\n",
" self.kernel_size = kernel_size\n",
" self.identity_block = nn.Sequential(\n",
" ConvLayer(in_c, out_c//4, kernel_size=1),\n",
" nn.InstanceNorm2d(out_c//4),\n",
" nn.ReLU(),\n",
" ConvLayer(out_c//4, out_c//4, kernel_size),\n",
" nn.InstanceNorm2d(out_c//4),\n",
" nn.ReLU(),\n",
" ConvLayer(out_c//4, out_c, kernel_size=1),\n",
" nn.InstanceNorm2d(out_c),\n",
" nn.ReLU()\n",
" )\n",
"\n",
" def residual(self, x):\n",
" if self.in_c == self.out_c:\n",
" return x\n",
" else:\n",
" layer = nn.Sequential(\n",
" ConvLayer(self.in_c, self.out_c, kernel_size=1),\n",
" nn.InstanceNorm2d(self.out_c)\n",
" )\n",
" return layer(x)\n",
"\n",
" def forward(self, x):\n",
" out = self.identity_block(x)\n",
" residual_x = self.residual(x)\n",
" out =+ residual_x\n",
" out = F.relu(out)\n",
" return out\n",
"\n"
],
"metadata": {
"id": "QzQCJNHKZf76"
},
"execution_count": null,
"outputs": []
},
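{
"cell_type": "markdown",
"source": [
"A quick shape check for `Bottleneck` (an illustrative sketch, not part of the original notebook): with `in_c != out_c` the 1x1 shortcut projects the input, so the spatial size is preserved while the channel count changes."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Smoke test on dummy data; the shapes here are assumptions for illustration.\n",
"block = Bottleneck(16, 32)\n",
"x = torch.randn(1, 16, 64, 64)\n",
"print(block(x).shape)  # expected: torch.Size([1, 32, 64, 64])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},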
{
"cell_type": "code",
"source": [
"class UpSample(nn.Module):\n",
" def __init__(self, in_channels, out_channels, scale_factor, mode='bilinear'):\n",
" super().__init__()\n",
" self.scale_factor = scale_factor\n",
" self.mode = mode\n",
" self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)\n",
" self.norm = nn.InstanceNorm2d(out_channels)\n",
"\n",
" def forward(self, x):\n",
" out = self.conv(x)\n",
" out = F.interpolate(out, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)\n",
" out = self.norm(out)\n",
" out = F.relu(out)\n",
" return out"
],
"metadata": {
"id": "qYx1-ZXRq6FV"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def upsample(scale_factor):\n",
" return nn.Upsample(scale_factor=scale_factor, mode='bilinear')"
],
"metadata": {
"id": "iUGnW6Rgq4DA"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"class HRNet(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.layer1_1 = Bottleneck(3, 16)\n",
"\n",
" self.layer2_1 = Bottleneck(16, 32)\n",
" self.downsample2_1 = nn.Conv2d(16, 32, kernel_size=3, stride = 2, padding=1)\n",
"\n",
" self.layer3_1 = Bottleneck(32, 32)\n",
" self.layer3_2 = Bottleneck(32, 32)\n",
" self.downsample3_1 = nn.Conv2d(32, 32, kernel_size=3, stride = 2, padding=1)\n",
" self.downsample3_2 = nn.Conv2d(32, 32, kernel_size=3, stride = 4, padding=1)\n",
" self.downsample3_3 = nn.Conv2d(32, 32, kernel_size=3, stride = 2, padding=1)\n",
"\n",
" self.layer4_1 = Bottleneck(64, 64)\n",
" self.layer5_1 = Bottleneck(192, 64)\n",
" self.layer6_1 = Bottleneck(64, 32)\n",
" self.layer7_1 = Bottleneck(32, 16)\n",
" self.layer8_1 = nn.Conv2d(16, 3, kernel_size=3, stride = 1, padding=1)\n",
"\n",
" def forward(self, x):\n",
" map1_1 = self.layer1_1(x)\n",
"\n",
" map2_1 = self.layer2_1(map1_1)\n",
" map2_2 = self.downsample2_1(map1_1)\n",
"\n",
" map3_1 = torch.cat((self.layer3_1(map2_1), upsample(map2_2, 2)), 1)\n",
" map3_2 = torch.cat((self.downsample3_1(map2_1), self.layer3_2(map2_2)), 1)\n",
" map3_3 = torch.cat((self.downsample3_2(map2_1), self.downsample3_3(map2_2)), 1)\n",
"\n",
" map4_1 = torch.cat((self.layer4_1(map3_1), upsample(map3_2, 2), upsample(map3_3, 4)), 1)\n",
"\n",
" out = self.layer5_1(map4_1)\n",
" out = self.layer6_1(out)\n",
" out = self.layer7_1(out)\n",
" out = self.layer8_1(out)\n",
"\n",
" return out"
],
"metadata": {
"id": "6RsVKZ_9dzQy"
},
"execution_count": null,
"outputs": []
},
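{
"cell_type": "markdown",
"source": [
"A minimal forward-pass sketch for `HRNet` (illustrative, not in the original notebook). Because the 1/2- and 1/4-resolution branches come from stride-2 and stride-4 convolutions and are later upsampled by factors 2 and 4, the input height and width are assumed to be divisible by 4."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Dummy image batch; 256x256 is an assumption, any multiple of 4 works\n",
"net = HRNet()\n",
"dummy = torch.randn(1, 3, 256, 256)\n",
"out = net(dummy)\n",
"print(out.shape)  # expected: torch.Size([1, 3, 256, 256])"
],
"metadata": {},
"execution_count": null,
"outputs": []
},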
{
"cell_type": "markdown",
"source": [
"**Create utility functiions**"
],
"metadata": {
"id": "6mqn8lrz87U1"
}
},
{
"cell_type": "code",
"source": [
"def image_loading(path, size=None):\n",
" img = Image.open(path)\n",
"\n",
" if size is not None:\n",
" img = img.resize((size, size))\n",
"\n",
" transform = transforms.Compose([\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))\n",
" ])\n",
"\n",
" img = transform(img)\n",
" img = img.unsqueeze(0)\n",
" return img"
],
"metadata": {
"id": "N7-R6ZtkuqAP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def im_convert(img):\n",
"\n",
" img = img.to('cpu').clone().detach()\n",
" img = img.numpy().squeeze(0)\n",
" img = img.transpose(1, 2, 0)\n",
" img = img * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))\n",
" img = img.clip(0, 1)\n",
" return img"
],
"metadata": {
"id": "7v3MgY0sxG3f"
},
"execution_count": null,
"outputs": []
},
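{
"cell_type": "markdown",
"source": [
"An illustrative round trip (added sketch, not from the PR): `im_convert` reverses the normalization applied in `image_loading` and returns an H x W x 3 array suitable for matplotlib. A synthetic tensor stands in for a loaded image so no file is needed."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"t = torch.randn(1, 3, 32, 32)  # stand-in for image_loading output\n",
"arr = im_convert(t)\n",
"print(arr.shape, arr.min() >= 0, arr.max() <= 1)  # (32, 32, 3) True True"
],
"metadata": {},
"execution_count": null,
"outputs": []
},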
{
"cell_type": "code",
"source": [
"def get_features(img, model, layers=None):\n",
"\n",
" if layers is None:\n",
" layers = {\n",
" '0': 'conv1_1', # style layer\n",
" '5': 'conv2_1', # style layer\n",
" '10': 'conv3_1', # style layer\n",
" '19': 'conv4_1', # style layer\n",
" '28': 'conv5_1', # style layer\n",
"\n",
" '21': 'conv4_2' # content layer\n",
" }\n",
"\n",
" features = {}\n",
" x = img\n",
" for name, layer in model._modules.items():\n",
" x = layer(x)\n",
" if name in layers:\n",
" features[layers[name]] = x\n",
"\n",
" return features\n"
],
"metadata": {
"id": "TSyBWQ9t8dcF"
},
"execution_count": null,
"outputs": []
},
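{
"cell_type": "markdown",
"source": [
"Example use of `get_features` (a sketch; assumes torchvision >= 0.13 for the `weights` argument and internet access to download pretrained weights). The layer indices in the default dictionary index into `vgg19(...).features`, so that submodule is what gets passed in."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"vgg = vgg19(weights='DEFAULT').features.eval()\n",
"for p in vgg.parameters():\n",
"  p.requires_grad_(False)  # VGG serves only as a fixed feature extractor\n",
"\n",
"feats = get_features(torch.rand(1, 3, 224, 224), vgg)\n",
"print(list(feats.keys()))  # style layers conv1_1...conv5_1 plus content layer conv4_2"
],
"metadata": {},
"execution_count": null,
"outputs": []
},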
{
"cell_type": "code",
"source": [
"def get_gram_matrix(img):\n",
"\n",
" b, c, h, w = img.size()\n",
" img = img.view(b*c, h*w)\n",
" gram = torch.mm(img, img.t())\n",
" return gram"
],
"metadata": {
"id": "B5d6r6CE9S2V"
},
"execution_count": null,
"outputs": []
},
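{
"cell_type": "markdown",
"source": [
"A hedged sketch of how the Gram matrix feeds a style loss (this loss code is illustrative, not from the PR): the Gram matrices of generated and style features are compared with an MSE criterion, and the normalization by `c*h*w` is a common choice rather than something the notebook prescribes."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Dummy feature maps stand in for VGG activations of equal shape\n",
"gen_feat = torch.rand(1, 64, 32, 32)\n",
"style_feat = torch.rand(1, 64, 32, 32)\n",
"\n",
"b, c, h, w = gen_feat.size()\n",
"style_loss = F.mse_loss(get_gram_matrix(gen_feat), get_gram_matrix(style_feat))\n",
"style_loss = style_loss / (c * h * w)  # illustrative normalization\n",
"print(style_loss.item())"
],
"metadata": {},
"execution_count": null,
"outputs": []
},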
{
"cell_type": "code",
"source": [],
"metadata": {
"id": "Va0fXQOA9wLo"
},
"execution_count": null,
"outputs": []
}
]
}