input_segmentation.py
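# Segment the neural input (spikePow) into words: force-align the model's CTC
# emissions with the target transcript, convert the aligned frame spans back
# to spikePow indices, and pickle the per-word [start, end] indices.
#
# Usage (--ckpt is the path to a trained checkpoint):
#   python input_segmentation.py --ckpt path/to/model.ckpt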
import os
import pickle
from argparse import ArgumentParser

import torchaudio.functional as F
from omegaconf import OmegaConf
from tqdm import tqdm

from src.data import B2T_DataModule
from src.model import joint_Model

def align(emission, targets):
    alignments, scores = F.forced_align(emission, targets, blank=0)
    alignments, scores = (
        alignments[0],
        scores[0],
    )  # remove batch dimension for simplicity
    scores = scores.exp()  # convert back to probability
    return alignments, scores

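# Note: torchaudio's F.forced_align expects batched inputs (emission of shape
# (1, num_frames, num_tokens), targets of shape (1, target_length)); recent
# versions support only batch size 1, which is why align() strips the batch
# dimension before returning.
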
def unflatten(list_, lengths):
    assert len(list_) == sum(lengths)
    i = 0
    ret = []
    for l in lengths:
        ret.append(list_[i : i + l])
        i += l
    return ret

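# Example: unflatten([s0, s1, s2], [2, 1]) returns [[s0, s1], [s2]], i.e. the
# flat list is regrouped into sublists of the given lengths.
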
def main():
    parser = ArgumentParser()
    parser.add_argument("--ckpt", type=str)
    args = parser.parse_args()

    ckpt_path = args.ckpt

    # Load only the CTC alignment decoder from the checkpoint.
    model = joint_Model.load_from_checkpoint(
        ckpt_path, decoders_conf=["ctc_al"], strict=False
    )
    model = model.cuda()

    input_dir = "./dataset/train"

    cfg = OmegaConf.create(
        {
            "train_data_dir": None,
            "val_data_dir": input_dir,
            "test_data_dir": None,
            "train_batch_size": 1,
            "valid_batch_size": 1,  # no padding
            # "debugging": True,
            "num_workers": 2,
            "word_level": False,
        }
    )

    dm = B2T_DataModule(cfg)
    dm.setup("")
    dataloader = dm.val_dataloader()

    word_sp_indices = []

    for batch in tqdm(dataloader):
        emission = (
            model(
                spikePow=batch["spikePow"].cuda(),
                spikePow_mask=batch["spikePow_mask"].cuda(),
                spikePow_lens=batch["spikePow_lens"].cuda(),
            )["ctc_al"][0]
            .detach()
            .cpu()
        )

        # Drop the final target token (presumably an end-of-sequence marker).
        targets_ids = batch["sent_ids"][:, :-1]
        target_text = batch["sent"][0]
        TRANSCRIPT = target_text.split("|")

        aligned_tokens, alignment_scores = align(emission, targets_ids)

        token_spans = F.merge_tokens(aligned_tokens, alignment_scores)
        # Discard the word-separator token (id 1, presumably "|").
        token_spans = [i for i in token_spans if i.token != 1]

        # Group the token spans into one span list per word.
        word_spans = unflatten(token_spans, [len(word) for word in TRANSCRIPT])

        # index of the start and end of each word in spikePow
        word_indices = []
        for w_span in word_spans:
            # convert to index of spikePow; each emission frame appears to
            # cover 4 spikePow timesteps
            s_start = w_span[0].start * 4
            s_end = w_span[-1].end * 4 + 3
            word_indices.append([s_start, s_end])

        word_sp_indices.append(word_indices)

    # The output file is a pickle despite the .npy extension.
    with open(os.path.join(input_dir, "word_sp_indices.npy"), "wb") as handle:
        pickle.dump(word_sp_indices, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == "__main__":
    main()
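
# A minimal sketch of reading the result back (same path and pickle format as
# written above):
#
#   with open("./dataset/train/word_sp_indices.npy", "rb") as f:
#       word_sp_indices = pickle.load(f)
#   # word_sp_indices[i][j] == [start, end] of word j in utterance i's spikePow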