Skip to content

Commit 7a36632

Browse files
committedDec 13, 2023
Remove clip library dependency
1 parent 524fa54 commit 7a36632

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed
 

‎Eval.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
import os
3-
import clip
43
import cv2
54

65
import numpy as np
@@ -67,7 +66,7 @@ def eval_vggss_agg(
6766
labels, name = data['labels'], data['ids']
6867

6968
# Inference
70-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device)
69+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
7170
placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
7271
audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
7372
prompt_length)
@@ -170,7 +169,7 @@ def eval_avsbench_agg(
170169
images, audios, gts, labels, name = data['images'], data['audios'], data['gts'], data['labels'], data['ids']
171170

172171
# Inference
173-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device)
172+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
174173
placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
175174
audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
176175
prompt_length)
@@ -268,7 +267,7 @@ def eval_flickr_agg(
268267
labels, name = data['labels'], data['ids']
269268

270269
# Inference
271-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device)
270+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
272271
placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
273272
audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
274273
prompt_length)
@@ -364,7 +363,7 @@ def eval_exvggss_agg(
364363
labels, name = data['labels'], data['ids']
365364

366365
# Inference
367-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device)
366+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
368367
placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
369368
audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
370369
prompt_length)
@@ -446,7 +445,7 @@ def eval_exflickr_agg(
446445
labels, name = data['labels'], data['ids']
447446

448447
# Inference
449-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device)
448+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
450449
placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
451450
audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
452451
prompt_length)

‎README.md

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ $ pip install tensorboard
3636
$ pip install transformers==4.25.1
3737
$ pip install opencv-python
3838
$ pip install tqdm
39+
$ pip install scikit-learn
3940

4041
```
4142

‎Test_PTModels.sh

+1
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ python Test_PTModels.py \
88
--vggss_path {put dataset directory} \
99
--flickr_path {put dataset directory} \
1010
--avs_path {put dataset directory} \
11+
--save_path {put save directory} \
1112
--epochs None

‎Train_ACL.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import time
66
import datetime
77
import yaml
8-
import clip
98
import shutil
109
import argparse
1110

@@ -205,7 +204,7 @@ def main(model_name, exp_name, train_config_name, data_path_dict, save_path):
205204

206205
with autocast_fn():
207206
# Train step
208-
placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(module.device)
207+
placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', ''))
209208
placeholder_tokens = placeholder_tokens.repeat((train_dataloader.batch_size, 1))
210209
audio_driven_embedding = module.encode_audio(audios.to(module.device), placeholder_tokens,
211210
text_pos_at_prompt, prompt_length).half()

‎modules/models.py

+18
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from modules.AudioToken.embedder import FGAEmbedder
1010
from modules.CLIPSeg.clipseg_for_audio import CLIPSeg
1111
from modules.mask_utils import ImageMasker, FeatureMasker
12+
from transformers import AutoTokenizer
1213

1314

1415
class ACL(nn.Module):
@@ -37,6 +38,9 @@ def __init__(self, conf_file: str, device: str):
3738
cfg = BEATsConfig(checkpoint['cfg'])
3839
self.audio_backbone = BEATs(cfg)
3940

41+
# Text Tokenizer for placeholder prompt
42+
self.tokenizer = AutoTokenizer.from_pretrained("CIDAS/clipseg-rd64-refined")
43+
4044
# Init audio projection layer
4145
self.audio_proj = FGAEmbedder(input_size=self.args.audio_proj.input_size * 3,
4246
output_size=self.args.audio_proj.output_size)
@@ -63,6 +67,20 @@ def __init__(self, conf_file: str, device: str):
6367
self.masker_i.to(self.device)
6468
self.masker_f.to(self.device)
6569

70+
def get_placeholder_token(self, prompt_text: str, context_length: int = 77):
    """
    Tokenize the placeholder prompt and zero-pad it to a fixed context length.

    Args:
        prompt_text (str): prompt text without '{}'
        context_length (int): size of the last dimension of the returned ids.
            Defaults to 77, the value that was previously hard-coded here.

    Returns:
        Tensor of input ids produced by ``self.tokenizer``, right-padded with
        zeros along the last dimension up to ``context_length`` and moved to
        ``self.device``.

    Raises:
        ValueError: if the tokenized prompt is longer than ``context_length``.
    """
    token_ids = self.tokenizer(prompt_text, return_tensors="pt").data['input_ids']

    pad_amount = context_length - token_ids.shape[-1]
    if pad_amount < 0:
        # A negative pad amount makes F.pad crop the sequence instead of
        # failing, which would silently drop the trailing tokens of the prompt.
        raise ValueError(
            f"Prompt {prompt_text!r} tokenizes to {token_ids.shape[-1]} ids, "
            f"which exceeds the context length {context_length}"
        )

    # Zero padding on the right only; the leading dimensions are left untouched.
    return F.pad(token_ids, (0, pad_amount)).to(self.device)
83+
6684
def train(self, bool: bool = True):
6785
"""
6886
Set the module in training mode.

0 commit comments

Comments
 (0)
Please sign in to comment.