-
Notifications
You must be signed in to change notification settings - Fork 6
/
Mapper.lua
62 lines (53 loc) · 1.66 KB
/
Mapper.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
require 'torch'
-- construct an object to deal with the mapping
-- 'Mapper' is registered as a global Torch class; instances are built with
-- Mapper(dictPath) and hold the alphabet<->token lookup tables.
local mapper = torch.class('Mapper')
--- Builds the two lookup tables from a dictionary file.
-- Each line of the file is one alphabet symbol; tokens are assigned
-- consecutively in file order, starting from 0.
-- @param dictPath path to the dictionary file (one symbol per line)
function mapper:__init(dictPath)
    assert(paths.filep(dictPath), dictPath ..' not found')
    self.alphabet2token = {}
    self.token2alphabet = {}
    -- populate both directions of the mapping in a single pass
    local nextToken = 0
    for symbol in io.lines(dictPath) do
        self.alphabet2token[symbol] = nextToken
        self.token2alphabet[nextToken] = symbol
        nextToken = nextToken + 1
    end
end
--- Maps a text line to a list of integer tokens.
-- The input is lowercased before lookup.
-- @param line the string to encode
-- @return array of token ids, one per input character
-- Raises if a character is missing from the dictionary: previously a
-- missing character passed nil to table.insert, silently dropping the
-- entry and desynchronizing the label length from the input length.
function mapper:encodeString(line)
    line = string.lower(line)
    local label = {}
    for i = 1, #line do
        local character = line:sub(i, i)
        local token = self.alphabet2token[character]
        -- fail loudly instead of producing a label with holes
        assert(token ~= nil,
            "unknown character '" .. character .. "' not in dictionary")
        table.insert(label, token)
    end
    return label
end
--- Greedy (best-path) CTC decoding.
-- Takes the argmax class at every timestep, then collapses consecutive
-- repeats and removes the blank symbol ('$' in the dictionary).
-- @param predictions 2-D tensor of per-timestep likelihood vectors
-- @return array of token ids
-- NOTE:
--     to compute WER the beginning and ending spaces are stripped
--     (per the original note) -- presumably done by the caller; confirm.
function mapper:decodeOutput(predictions)
    local blank = self.alphabet2token['$']
    local decoded = {}
    local previous = blank
    -- pick the most likely class at each timestep
    local _, best = torch.max(predictions, 2)
    best = best:float():squeeze()
    for step = 1, best:size(1) do
        -- CTC classes are 1-based while our tokens start at 0
        local current = best[step] - 1
        -- keep only non-blank tokens that differ from the previous one
        if current ~= blank and current ~= previous then
            decoded[#decoded + 1] = current
        end
        previous = current
    end
    return decoded
end
--- Converts a list of token ids back into a string.
-- @param tokens array of token ids (as produced by decodeOutput)
-- @return the decoded text
-- Uses a buffer plus table.concat instead of repeated '..' in a loop,
-- which is O(n^2) in the number of tokens; also drops the unused loop
-- value the original bound and then re-indexed.
function mapper:tokensToText(tokens)
    local parts = {}
    for i, token in ipairs(tokens) do
        parts[i] = self.token2alphabet[token]
    end
    return table.concat(parts)
end