Add files via upload
dnwjddl authored Feb 4, 2024
1 parent 93af690 commit 8b4fce9
Showing 36 changed files with 18,871 additions and 1 deletion.
70 changes: 69 additions & 1 deletion README.md
# CoBra 🐍 : Complementary Branch Fusing Class and Semantic Knowledge for Robust Weakly Supervised Semantic Segmentation

[[Project Page]](https://anno9555.github.io/)

![voc](./img/fig1.png)

We propose **Complementary Branch (CoBra)**, a novel dual-branch framework whose two distinct architectures provide valuable complementary knowledge of class (from the CNN) and semantics (from the vision transformer) to each branch. In particular, we learn a **Class-Aware Projection (CAP)** for the CNN branch and a **Semantic-Aware Projection (SAP)** for the vision transformer branch to explicitly fuse this complementary knowledge and facilitate a new type of extra patch-level supervision. Extensive experiments qualitatively and quantitatively investigate how the CNN and vision transformer complement each other on the PASCAL VOC 2012 dataset, showing state-of-the-art WSSS results.
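The two projections can be pictured as lightweight heads mapping each branch's patch features into a shared embedding space. The sketch below uses toy shapes and plain `Conv2d`/`Linear` heads purely for illustration; the repo's actual backbones and projection heads differ.

```python
import torch
import torch.nn as nn

# Toy stand-ins: CoBra uses a full CNN and a vision transformer as branches.
B, C_cnn, C_vit, H, W, D = 2, 64, 96, 14, 14, 128

cnn_feat = torch.randn(B, C_cnn, H, W)   # CNN branch feature map
vit_feat = torch.randn(B, H * W, C_vit)  # ViT patch tokens (class token removed)

cap = nn.Conv2d(C_cnn, D, kernel_size=1)  # Class-Aware Projection (CNN side)
sap = nn.Linear(C_vit, D)                 # Semantic-Aware Projection (ViT side)

cnn_embed = cap(cnn_feat).flatten(2)       # [B, D, 196]
vit_embed = sap(vit_feat).transpose(1, 2)  # [B, D, 196]
# Each branch's embedding is then supervised with patch-level positives and
# negatives selected from the *other* branch (see cll_v1 / cll_v2).
```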

## :book: Contents
<!--ts-->
* [Prerequisite](#Prerequisite)
* [Usage](#Usage)
* [Pretrained Weight](#Pretrained-Weight)
* [About CoBra](#About-CoBra)
<!--te-->


## 🔧 Prerequisite
- Download [PASCAL VOC2012 devkit](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/)
- Ubuntu 20.04, with Python 3.10 and the following Python dependencies:
```bash
pip install -r requirements.txt
```

## 💻 Usage
**Step 1:**
- Run the `run.sh` script to train CoBra; it produces the **Seed** and the components needed to build a better **Mask**.
- We train [IRNet](https://github.com/jiwoon-ahn/irn) to generate masks that refine the Seed.
- Change ```title``` and ```pascal_dataset_path``` in the run shell script.
```bash
bash run.sh
```
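For reference, the two variables to edit near the top of `run.sh` might look like the fragment below. The experiment name and dataset path are placeholders, not values from the repo.

```shell
# run.sh -- edit before launching
title="cobra_voc12"                            # experiment name (used for logs/checkpoints)
pascal_dataset_path="/data/VOCdevkit/VOC2012"  # root of the extracted VOC2012 devkit
```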

### 🏋️ Pretrained Weight
<table style="margin: auto">
<tr>
<td align="center">CAK Branch</td>
<td align="center"><i>ep19_cnn_checkpoint.pth</i></td>
<td><a href="https://drive.google.com/file/d/1X0kn_imyesfKlguBWqoysar5_4RWMFZ1/view?usp=sharing">link</a></td>
</tr>
<tr>
<td align="center">SAK Branch</td>
<td align="center"><i>ep19_tran_checkpoint.pth</i></td>
<td><a href="https://drive.google.com/file/d/1GAEO-Qta_iUnR1ptZL7z5ZTiCuTA9QWx/view?usp=sharing">link</a></td>
</tr>
</table>


**Step 2:** For the segmentation part, we use DeepLabV2 with ResNet-101 and MiT-B2 backbones.
- [DeepLabV2](https://github.com/kazuto1011/deeplab-pytorch)


## 🐍 About CoBra
<center><img src="./img/main.png" width="700px" height="700px" title="CoBra overview"/>
</center>

### Results
- Qualitative Results on Pascal VOC 2012 dataset.
![voc](./img/VOC_results.png)


- Qualitative Results on MS-COCO 2014 dataset.
![coco](./img/COCO_results.png)


## :scroll: Acknowledgement
This repository has been developed based on the [IRNet](https://github.com/jiwoon-ahn/irn) repository. Thanks for the good work!

Binary file added img/COCO_results.png
Binary file added img/VOC_results.png
Binary file added img/fig1.png
Binary file added img/fig2 .png
Binary file added img/main.png
136 changes: 136 additions & 0 deletions loss_modules.py
import torch
from torch.nn import functional as F

def cll_v1(args, cam_cnn, transformer_embed, label_bg):
    '''
    Semantic-Aware Projection (SAP) loss: positives/negatives for each
    transformer patch embedding are selected via the CNN CAM similarities.
    '''
    # ********* Hyperparams ********* #
    top_num = args.top_bot_k[0]
    bottom_num = args.top_bot_k[1]
    tau = args.tau

    B, C, N, _ = cam_cnn.shape
    N2 = N * N  # number of patches, e.g. 14 * 14 = 196

    scores = F.softmax(cam_cnn * label_bg, dim=1)           # [B, 21, 14, 14]; softmax over classes
    pseudo_score, pseudo_label = torch.max(scores, dim=1)   # [B, 14, 14]; best class per pixel
    cam_cnn = cam_cnn.reshape(B, C, -1)                     # [B, C, 196]
    pseudo_label = pseudo_label.reshape(B, -1)              # [B, 196]

    # Per-patch similarity rows: the CAM of each patch's pseudo class
    cam = [cam_cnn[i, pseudo_label[i]] for i in range(B)]   # [B, 196, 196]
    cam = torch.stack(cam, dim=0)

    top_values, top_indices = torch.topk(
        cam, k=top_num, dim=-1, largest=True)       # [B, 196, k]
    bottom_values, bottom_indices = torch.topk(
        cam, k=bottom_num, dim=-1, largest=False)   # [B, 196, k]

    transformer_embed = transformer_embed.transpose(1, 2)   # [B, 196, 128]

    pos_init = []
    neg_init = []
    for i in range(B):
        pos_init.append(transformer_embed[i, top_indices[i]])
        neg_init.append(transformer_embed[i, bottom_indices[i]])

    pos = torch.stack(pos_init, dim=0)  # [B, 196, k, 128]
    neg = torch.stack(neg_init, dim=0)  # [B, 196, k, 128]

    # Computing loss: -log(X / (X + Y)) per anchor patch
    loss = torch.zeros(1).cuda()
    for i in range(N2):
        main_vector_tf = transformer_embed[:, i].unsqueeze(-1)  # anchor embedding

        # X: numerator (positive similarities)
        pos_inner = pos[:, i] @ main_vector_tf  # [B, k, 1]
        X = torch.exp(pos_inner.squeeze(-1) / tau)

        # Y: denominator (summed negative similarities);
        # tau belongs inside exp, matching cll_v2
        neg_inner = neg[:, i] @ main_vector_tf  # [B, k, 1]
        Y = torch.sum(torch.exp(neg_inner.squeeze(-1) / tau),
                      dim=-1, keepdim=True)

        # X / (X + Y)
        loss += torch.sum(-torch.log(X / (X + Y)))

    return loss / (N2 * top_num * B)
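The `-log(X / (X + Y))` term above is an InfoNCE-style contrastive loss. A self-contained numeric sketch of the per-anchor computation, on toy tensors (the vectors are unit-normalized here only to keep the toy numbers stable; the repo's code works on raw projections):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
tau = 0.1

# One anchor patch embedding plus its selected positives/negatives.
anchor = F.normalize(torch.randn(128), dim=0)
pos = F.normalize(torch.randn(20, 128), dim=1)  # top-k patches from the other branch
neg = F.normalize(torch.randn(20, 128), dim=1)  # bottom-k patches

X = torch.exp((pos @ anchor) / tau)              # [20] positive terms
Y = torch.exp((neg @ anchor) / tau).sum()        # summed negative mass
loss = (-torch.log(X / (X + Y))).sum() / len(X)  # average over positives
```

Since `X / (X + Y)` is strictly between 0 and 1, the loss is always positive; minimizing it pulls the anchor toward its positives and away from its negatives.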


def cll_v2(args, attn_weights, cnn_embed):
    '''
    Class-Aware Projection (CAP) loss: positives/negatives for each CNN
    patch embedding are selected via the transformer's patch-to-patch attention.
    '''
    # ********* Hyperparams ********* #
    top_num = args.top_bot_k[2]
    bottom_num = args.top_bot_k[3]
    tau = args.tau
    # ***************************** #

    attn_weights = attn_weights[:, 1:, 1:]  # patch-to-patch attention, class token excluded
    B, N, _ = attn_weights.shape

    top_values, top_indices = torch.topk(
        attn_weights, k=top_num, dim=-1, largest=True)       # [B, 196, k]
    bottom_values, bottom_indices = torch.topk(
        attn_weights, k=bottom_num, dim=-1, largest=False)   # [B, 196, k]

    cnn_embed = cnn_embed.transpose(1, 2)  # [B, 196, 128]

    pos_init = []
    neg_init = []
    for i in range(B):
        pos_init.append(cnn_embed[i, top_indices[i]])
        neg_init.append(cnn_embed[i, bottom_indices[i]])

    pos = torch.stack(pos_init, dim=0)  # [B, 196, k, 128]
    neg = torch.stack(neg_init, dim=0)  # [B, 196, k, 128]

    # Computing loss: -log(X / (X + Y)) per anchor patch
    loss = torch.zeros(1).cuda()
    for i in range(N):
        main_vector_tf = cnn_embed[:, i].unsqueeze(-1)  # anchor embedding for the whole batch

        # X: numerator (positive similarities)
        pos_inner = pos[:, i] @ main_vector_tf  # [B, k, 1]
        X = torch.exp(pos_inner.squeeze(-1) / tau)

        # Y: denominator (summed negative similarities)
        neg_inner = neg[:, i] @ main_vector_tf  # [B, k, 1]
        Y = torch.sum(torch.exp(neg_inner.squeeze(-1) / tau),
                      dim=-1, keepdim=True)

        # X / (X + Y)
        loss += torch.sum(-torch.log(X / (X + Y)))

    loss /= N * B * top_num
    return loss

def loss_CAM(cam_cnn_224, cam_tf_224, label):
    '''
    Inter-branch CAM consistency: mean L1 distance between the CNN and
    transformer CAMs, computed only over the classes present in each image.
    '''
    cam_cnn_224_classes = cam_cnn_224[:, 1:]  # drop the background channel
    cam_tf_224_classes = cam_tf_224[:, 1:]
    loss_interCAM = 0
    for i in range(len(cam_cnn_224_classes)):
        valid_cat = torch.nonzero(label[i])[:, 0]  # classes present in image i
        cam_cnn_224_class = cam_cnn_224_classes[i, valid_cat]
        cam_tf_224_class = cam_tf_224_classes[i, valid_cat]
        # Background consistency is computed but unused in the current formulation.
        loss_bg = torch.mean(torch.abs(cam_cnn_224[i][0] - cam_tf_224[i][0]))
        loss_inter = torch.mean(torch.abs(cam_cnn_224_class - cam_tf_224_class))
        # loss = (loss_bg + loss_inter) / 2  # alternative: include the background term
        loss_interCAM += loss_inter
    return loss_interCAM / len(cam_cnn_224)
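A toy check of that consistency term, inlined on random CAMs (shapes and class indices below are illustrative, not the repo's data):

```python
import torch

torch.manual_seed(0)
cam_cnn = torch.rand(2, 20, 224, 224)  # foreground CAMs from the CNN branch (toy)
cam_tf = torch.rand(2, 20, 224, 224)   # foreground CAMs from the transformer branch (toy)
label = torch.zeros(2, 20)
label[:, [3, 7]] = 1                   # pretend classes 3 and 7 appear in both images

# Per-image mean L1 distance between branch CAMs, on labelled classes only.
loss = 0.0
for i in range(len(cam_cnn)):
    valid_cat = torch.nonzero(label[i])[:, 0]  # classes present in image i
    loss += torch.mean(torch.abs(cam_cnn[i, valid_cat] - cam_tf[i, valid_cat]))
loss /= len(cam_cnn)
```

On independent random CAMs the loss sits well above zero; it reaches zero exactly when the two branches' CAMs agree on every labelled class.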