|
1 | 1 | import torch
|
2 | 2 | import os
|
3 |
| -import clip |
4 | 3 | import cv2
|
5 | 4 |
|
6 | 5 | import numpy as np
|
@@ -67,7 +66,7 @@ def eval_vggss_agg(
|
67 | 66 | labels, name = data['labels'], data['ids']
|
68 | 67 |
|
69 | 68 | # Inference
|
70 |
| - placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device) |
| 69 | + placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', '')) |
71 | 70 | placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
|
72 | 71 | audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
|
73 | 72 | prompt_length)
|
@@ -170,7 +169,7 @@ def eval_avsbench_agg(
|
170 | 169 | images, audios, gts, labels, name = data['images'], data['audios'], data['gts'], data['labels'], data['ids']
|
171 | 170 |
|
172 | 171 | # Inference
|
173 |
| - placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device) |
| 172 | + placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', '')) |
174 | 173 | placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
|
175 | 174 | audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
|
176 | 175 | prompt_length)
|
@@ -268,7 +267,7 @@ def eval_flickr_agg(
|
268 | 267 | labels, name = data['labels'], data['ids']
|
269 | 268 |
|
270 | 269 | # Inference
|
271 |
| - placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device) |
| 270 | + placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', '')) |
272 | 271 | placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
|
273 | 272 | audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
|
274 | 273 | prompt_length)
|
@@ -364,7 +363,7 @@ def eval_exvggss_agg(
|
364 | 363 | labels, name = data['labels'], data['ids']
|
365 | 364 |
|
366 | 365 | # Inference
|
367 |
| - placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device) |
| 366 | + placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', '')) |
368 | 367 | placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
|
369 | 368 | audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
|
370 | 369 | prompt_length)
|
@@ -446,7 +445,7 @@ def eval_exflickr_agg(
|
446 | 445 | labels, name = data['labels'], data['ids']
|
447 | 446 |
|
448 | 447 | # Inference
|
449 |
| - placeholder_tokens = clip.tokenize(prompt_template.replace('{}', '')).to(model.device) |
| 448 | + placeholder_tokens = model.get_placeholder_token(prompt_template.replace('{}', '')) |
450 | 449 | placeholder_tokens = placeholder_tokens.repeat((test_dataloader.batch_size, 1))
|
451 | 450 | audio_driven_embedding = model.encode_audio(audios.to(model.device), placeholder_tokens, text_pos_at_prompt,
|
452 | 451 | prompt_length)
|
|
0 commit comments