-
Notifications
You must be signed in to change notification settings - Fork 0
/
segformer.py
57 lines (46 loc) · 2.29 KB
/
segformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from transformers import SegformerModel, SegformerConfig, SegformerFeatureExtractor, SegformerForSemanticSegmentation
import json
import torch
class segformersegmentation(torch.nn.Module):
    """SegFormer semantic-segmentation wrapper with a replaced classification head.

    In ``"test"`` mode the network is built from a saved SegFormer config JSON
    (no checkpoint is downloaded; presumably the caller loads trained weights
    afterwards — TODO confirm) and ``forward`` returns the raw decoder logits.
    In any other mode the ``pretrained_name`` checkpoint is loaded and
    ``forward`` bilinearly upsamples the logits to ``size``.

    Args:
        mode: ``"test"`` to build the model from ``config_json``; any other
            value loads the pretrained checkpoint and enables upsampling.
        size: Target spatial size handed to ``torch.nn.Upsample``.
        config_json: Path to a SegFormer config JSON file; required when
            ``mode == "test"``.
        num_classes: Output channels of the replacement classifier.
        pretrained_name: Hugging Face checkpoint loaded outside test mode.

    Raises:
        ValueError: If ``mode == "test"`` and ``config_json`` is ``None``.
    """

    def __init__(self, mode='train', size=640, config_json=None, *,
                 num_classes=5,
                 pretrained_name="nvidia/segformer-b2-finetuned-ade-512-512"):
        super().__init__()
        self.mode = mode
        self.size = size
        if self.mode == "test":
            if config_json is None:
                raise ValueError("config_json is required when mode == 'test'")
            # Parse the opened file handle, not the path string.
            with open(config_json, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            config = SegformerConfig(**config_dict)
            self.segpretrained = SegformerForSemanticSegmentation(config)
        else:
            self.segpretrained = SegformerForSemanticSegmentation.from_pretrained(pretrained_name)
        # Only applied in forward() when mode != "test".
        self.upsample = torch.nn.Upsample(size=self.size, mode='bilinear', align_corners=False)
        # Swap the checkpoint's head for a fresh num_classes-way 1x1 conv. Its
        # input width must equal the decode head's fused feature width, which
        # the model config records as decoder_hidden_size (768 for this
        # checkpoint, matching the previously hard-coded value).
        in_channels = self.segpretrained.config.decoder_hidden_size
        self.segpretrained.decode_head.classifier = torch.nn.Conv2d(
            in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Return segmentation logits for ``x``.

        Args:
            x: Pixel values for the SegFormer model (presumably a
                ``(batch, 3, H, W)`` float tensor — TODO confirm).

        Returns:
            Classifier logits; upsampled to ``size`` unless ``mode == "test"``.
        """
        # Call the module (not .forward) so registered hooks still run.
        out_logits = self.segpretrained(x).logits
        if self.mode != 'test':
            out_logits = self.upsample(out_logits)
        return out_logits
class segformersegmentationmitb3(torch.nn.Module):
    """SegFormer (MiT-B3 checkpoint by default) segmentation wrapper with a replaced head.

    In ``"test"`` mode the network is built from a saved SegFormer config JSON
    (no checkpoint is downloaded; presumably the caller loads trained weights
    afterwards — TODO confirm) and ``forward`` returns the raw decoder logits.
    In any other mode the ``pretrained_name`` checkpoint is loaded and
    ``forward`` bilinearly upsamples the logits to ``size``.

    Args:
        mode: ``"test"`` to build the model from ``config_json``; any other
            value loads the pretrained checkpoint and enables upsampling.
        size: Target spatial size handed to ``torch.nn.Upsample``.
        config_json: Path to a SegFormer config JSON file; required when
            ``mode == "test"``.
        num_classes: Output channels of the replacement classifier.
        pretrained_name: Hugging Face checkpoint loaded outside test mode.

    Raises:
        ValueError: If ``mode == "test"`` and ``config_json`` is ``None``.
    """

    def __init__(self, mode='train', size=640, config_json=None, *,
                 num_classes=5,
                 pretrained_name="nvidia/segformer-b3-finetuned-ade-512-512"):
        super().__init__()
        self.mode = mode
        self.size = size
        if self.mode == "test":
            if config_json is None:
                raise ValueError("config_json is required when mode == 'test'")
            # Parse the opened file handle, not the path string.
            with open(config_json, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
            config = SegformerConfig(**config_dict)
            self.segpretrained = SegformerForSemanticSegmentation(config)
        else:
            self.segpretrained = SegformerForSemanticSegmentation.from_pretrained(pretrained_name)
        # Only applied in forward() when mode != "test".
        self.upsample = torch.nn.Upsample(size=self.size, mode='bilinear', align_corners=False)
        # Swap the checkpoint's head for a fresh num_classes-way 1x1 conv. Its
        # input width must equal the decode head's fused feature width, which
        # the model config records as decoder_hidden_size (768 for this
        # checkpoint, matching the previously hard-coded value).
        in_channels = self.segpretrained.config.decoder_hidden_size
        self.segpretrained.decode_head.classifier = torch.nn.Conv2d(
            in_channels, num_classes, kernel_size=(1, 1), stride=(1, 1))

    def forward(self, x):
        """Return segmentation logits for ``x``.

        Args:
            x: Pixel values for the SegFormer model (presumably a
                ``(batch, 3, H, W)`` float tensor — TODO confirm).

        Returns:
            Classifier logits; upsampled to ``size`` unless ``mode == "test"``.
        """
        # Call the module (not .forward) so registered hooks still run.
        out_logits = self.segpretrained(x).logits
        if self.mode != 'test':
            out_logits = self.upsample(out_logits)
        return out_logits
# model = SegformerForSemanticSegmentation.from_pretrained("")
# x = torch.rand((2,3,640,640))
# print(model)
# model = segformersegmentation(mode='train', size=640)
# print(model)