-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathAVSegFormer_pvt2_s4.py
84 lines (84 loc) · 2.61 KB
/
AVSegFormer_pvt2_s4.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# AVSegFormer model settings. Every component is described as plain data;
# the 'type' strings are presumably resolved by a builder/registry elsewhere
# in the project (not visible in this file).
model = {
    'type': 'AVSegFormer',
    'neck': None,  # no neck is configured
    # Visual backbone and the local checkpoint it is initialised from.
    'backbone': {
        'type': 'pvt_v2_b5',
        'init_weights_path': 'pretrained/pvt_v2_b5.pth',
    },
    # VGGish audio branch: pretrained-weights path plus flags. Both the
    # log-mel preprocessing flag and the PCA postprocessing flag are off,
    # and no PCA parameter path is supplied.
    'vggish': {
        'freeze_audio_extractor': True,
        'pretrained_vggish_model_path': 'pretrained/vggish-10086976.pth',
        'preprocess_audio_to_log_mel': False,
        'postprocess_log_mel_with_pca': False,
        'pretrained_pca_params_path': None,
    },
    'head': {
        'type': 'AVSegHead',
        # Presumably the channel widths of the four backbone feature maps --
        # TODO confirm against the pvt_v2_b5 implementation.
        'in_channels': [64, 128, 320, 512],
        'num_classes': 1,
        # NOTE(review): query_num is repeated inside 'query_generator' below;
        # the two values presumably have to stay in sync.
        'query_num': 300,
        'use_learnable_queries': True,
        'fusion_block': {'type': 'CrossModalMixer'},
        # NOTE(review): 128 features here alongside embed_dim=256; presumably
        # the sine encoding emits 2 * num_feats channels -- TODO confirm.
        'positional_encoding': {
            'type': 'SinePositionalEncoding',
            'num_feats': 128,
        },
        'query_generator': {
            'type': 'AttentionGenerator',
            'num_layers': 6,
            'query_num': 300,
        },
        # Encoder and decoder stacks use the same recipe. A fresh dict is
        # built for each name so the two entries never share objects.
        'transformer': dict(
            type='AVSTransformer',
            **{
                stack: {
                    'num_layers': 6,
                    'layer': {'dim': 256, 'ffn_dim': 2048, 'dropout': 0.1},
                }
                for stack in ('encoder', 'decoder')
            }),
    },
    'audio_dim': 128,
    'embed_dim': 256,
    'freeze_audio_backbone': True,
    'T': 5,  # presumably frames per clip -- TODO confirm
}
# Dataset settings for the S4 (Single-source) data, one entry per split.
# The train/val/test entries differ only in `split`, so they are produced
# from a single template; each split receives its own fresh dict.
dataset = {
    split: dict(
        type='S4Dataset',
        split=split,
        anno_csv='data/Single-source/s4_meta_data.csv',
        dir_img='data/Single-source/s4_data/visual_frames',
        dir_audio_log_mel='data/Single-source/s4_data/audio_log_mel',
        dir_mask='data/Single-source/s4_data/gt_masks',
        img_size=(224, 224),
        batch_size=2)
    for split in ('train', 'val', 'test')
}
# Optimizer: AdamW with a fixed learning rate of 2e-5.
optimizer = {'type': 'AdamW', 'lr': 2e-5}
# Loss settings: per-term weights (iou_loss at 1.0, mix_loss at 0.1) and the
# loss type string 'dice'.
loss = {
    'weight_dict': {'iou_loss': 1.0, 'mix_loss': 0.1},
    'loss_type': 'dice',
}
# Run settings. NOTE(review): 'num_works' is presumably the data-loader
# worker count; the key spelling is kept as-is because the code that reads
# it is not visible here.
process = {
    'num_works': 8,
    'train_epochs': 30,
    'freeze_epochs': 5,
}