import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from functools import reduce, partial

from models.vision_transformer import VisionTransformer


class PatchPrompter(nn.Module):
    """Learns an image-sized prompt, masked so that only a border of width
    `prompt_size` around each `patch_size` x `patch_size` patch is added."""

    def __init__(self, args):
        super(PatchPrompter, self).__init__()
        self.patch_size = args.patch_size
        self.prompt_size = args.prompt_size
        self.fg_size = self.patch_size - args.prompt_size * 2

        self.patch = nn.Parameter(torch.randn([1, 3, args.image_size, args.image_size]))

    def forward(self, x):
        _, _, h, w = x.size()

        # Binary mask: 1 on the prompt border of each patch, 0 in its interior.
        fg_in_patch = torch.zeros([1, 3, self.fg_size, self.fg_size], device=x.device)
        fg_in_patch = F.pad(fg_in_patch, (self.prompt_size, self.prompt_size, self.prompt_size, self.prompt_size), "constant", 1)
        mask = fg_in_patch.repeat(1, 1, h // self.patch_size, w // self.patch_size)

        self.prompt = self.patch * mask
        return x + self.prompt


class SharedPrompter(nn.Module):
    """Learns a single patch-sized prompt that is tiled over all patches,
    again restricted to a border of width `prompt_size` in every patch."""

    def __init__(self, args):
        super(SharedPrompter, self).__init__()
        self.patch_size = args.patch_size
        self.prompt_size = args.prompt_size
        self.fg_size = self.patch_size - args.prompt_size * 2

        self.patch = nn.Parameter(torch.randn([1, 3, self.patch_size, self.patch_size]))

    def forward(self, x):
        _, _, h, w = x.size()

        # Border mask for one patch, tiled to the full image resolution.
        fg_in_patch = torch.zeros([1, 3, self.fg_size, self.fg_size], device=x.device)
        fg_in_patch = F.pad(fg_in_patch, (self.prompt_size, self.prompt_size, self.prompt_size, self.prompt_size), "constant", 1)
        mask = fg_in_patch.repeat(1, 1, h // self.patch_size, w // self.patch_size)

        # Tile the shared patch prompt over the image and apply the mask.
        patch = self.patch.repeat(1, 1, h // self.patch_size, w // self.patch_size)
        self.prompt = patch * mask
        return x + self.prompt


class PadPrompter(nn.Module):
    """Learns a frame of width `prompt_size` around the border of the image."""

    def __init__(self, args):
        super(PadPrompter, self).__init__()
        pad_size = args.prompt_size
        image_size = args.image_size

        self.base_size = image_size - pad_size * 2
        self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
        self.pad_left = nn.Parameter(torch.randn([1, 3, image_size - pad_size * 2, pad_size]))
        self.pad_right = nn.Parameter(torch.randn([1, 3, image_size - pad_size * 2, pad_size]))

    def forward(self, x):
        # Assemble the frame: zero interior, learnable borders on all four sides.
        base = torch.zeros(1, 3, self.base_size, self.base_size, device=x.device)
        prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3)
        prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2)
        prompt = torch.cat(x.size(0) * [prompt])

        return x + prompt
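

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). The `args` namespace
# only carries the attributes the prompters read above (image_size, patch_size,
# prompt_size); the concrete values here are illustrative assumptions, not
# settings taken from this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(image_size=224, patch_size=16, prompt_size=4)
    x = torch.randn(2, 3, args.image_size, args.image_size)

    # Per-patch border prompt over the whole image.
    patch_prompter = PatchPrompter(args)
    print(patch_prompter(x).shape)  # torch.Size([2, 3, 224, 224])

    # Learnable frame around the image border.
    pad_prompter = PadPrompter(args)
    print(pad_prompter(x).shape)    # torch.Size([2, 3, 224, 224])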