Skip to content

Commit

Permalink
feat(all): format all files (#174)
Browse files (browse the repository at this point in the history)
* feat(all): format all files

* feat(all): format all files

* feat(all): format all files
  • Loading branch information
xingchensong authored Dec 7, 2023
1 parent bd44df0 commit 85581f4
Show file tree
Hide file tree
Showing 24 changed files with 286 additions and 279 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
exclude: '.*\.(txt|tsv)$'
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
51 changes: 30 additions & 21 deletions itn/chinese/inverse_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@

class InverseNormalizer(Processor):

def __init__(self, cache_dir=None, overwrite_cache=False,
def __init__(self,
cache_dir=None,
overwrite_cache=False,
enable_standalone_number=True,
enable_0_to_9=False,
enable_million=False):
Expand All @@ -44,32 +46,39 @@ def __init__(self, cache_dir=None, overwrite_cache=False,
self.build_fst('zh_itn', cache_dir, overwrite_cache)

def build_tagger(self):
tagger = (add_weight(Date().tagger, 1.02)
| add_weight(Whitelist().tagger, 1.01)
| add_weight(Fraction().tagger, 1.05)
| add_weight(Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) # noqa
| add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04) # noqa
| add_weight(Time().tagger, 1.05)
| add_weight(Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).tagger, 1.06) # noqa
| add_weight(Math().tagger, 1.10)
| add_weight(LicensePlate().tagger, 1.0)
| add_weight(Char().tagger, 100)).optimize()
tagger = (
add_weight(Date().tagger, 1.02)
| add_weight(Whitelist().tagger, 1.01)
| add_weight(Fraction().tagger, 1.05)
| add_weight(
Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05) # noqa
| add_weight(Money(enable_0_to_9=self.enable_0_to_9).tagger,
1.04) # noqa
| add_weight(Time().tagger, 1.05)
| add_weight(
Cardinal(self.convert_number, self.enable_0_to_9,
self.enable_million).tagger, 1.06) # noqa
| add_weight(Math().tagger, 1.10)
| add_weight(LicensePlate().tagger, 1.0)
| add_weight(Char().tagger, 100)).optimize()

tagger = tagger.star
# remove the last space
self.tagger = tagger @ self.build_rule(delete(' '), '', '[EOS]')

def build_verbalizer(self):
verbalizer = (Cardinal(self.convert_number, self.enable_0_to_9, self.enable_million).verbalizer # noqa
| Char().verbalizer
| Date().verbalizer
| Fraction().verbalizer
| Math().verbalizer
| Measure(enable_0_to_9=self.enable_0_to_9).verbalizer
| Money(enable_0_to_9=self.enable_0_to_9).verbalizer
| Time().verbalizer
| LicensePlate().verbalizer
| Whitelist().verbalizer).optimize()
verbalizer = (
Cardinal(self.convert_number, self.enable_0_to_9,
self.enable_million).verbalizer # noqa
| Char().verbalizer
| Date().verbalizer
| Fraction().verbalizer
| Math().verbalizer
| Measure(enable_0_to_9=self.enable_0_to_9).verbalizer
| Money(enable_0_to_9=self.enable_0_to_9).verbalizer
| Time().verbalizer
| LicensePlate().verbalizer
| Whitelist().verbalizer).optimize()
postprocessor = PostProcessor(remove_interjections=True).processor

self.verbalizer = (verbalizer @ postprocessor).star
80 changes: 41 additions & 39 deletions itn/chinese/rules/cardinal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@

class Cardinal(Processor):

def __init__(self, enable_standalone_number=True, enable_0_to_9=True,
def __init__(self,
enable_standalone_number=True,
enable_0_to_9=True,
enable_million=False):
super().__init__('cardinal')
self.number = None
Expand All @@ -32,10 +34,10 @@ def __init__(self, enable_standalone_number=True, enable_0_to_9=True,
self.build_verbalizer()

def build_tagger(self):
zero = string_file('itn/chinese/data/number/zero.tsv') # 0
zero = string_file('itn/chinese/data/number/zero.tsv') # 0
digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9
sign = string_file('itn/chinese/data/number/sign.tsv') # + -
dot = string_file('itn/chinese/data/number/dot.tsv') # .
sign = string_file('itn/chinese/data/number/sign.tsv') # + -
dot = string_file('itn/chinese/data/number/dot.tsv') # .

addzero = insert('0')
digits = zero | digit # 0 ~ 9
Expand All @@ -52,33 +54,33 @@ def build_tagger(self):
| add_weight(addzero**2, 1.0)))
# 一千一百一十一 => 1111, 一千零一十一 => 1011, 一千零一 => 1001
# 一千一 => 1100, 一千 => 1000
thousand = ((hundred | teen | tens | digits) + delete('千') + (
hundred
| add_weight(zero + (tens | teen), 0.1)
| add_weight(addzero + zero + digit, 0.5)
| add_weight(digit + addzero**2, 0.8)
| add_weight(addzero**3, 1.0)))
thousand = ((hundred | teen | tens | digits) + delete('千') +
(hundred
| add_weight(zero + (tens | teen), 0.1)
| add_weight(addzero + zero + digit, 0.5)
| add_weight(digit + addzero**2, 0.8)
| add_weight(addzero**3, 1.0)))
# 10001111, 1001111, 101111, 11111, 10111, 10011, 10001, 10000
if self.enable_million:
ten_thousand = ((thousand | hundred | teen | tens | digits)
+ delete('万')
+ (thousand
| add_weight(zero + hundred, 0.1)
| add_weight(addzero + zero + (tens | teen), 0.5)
| add_weight(addzero + addzero + zero + digit, 0.5)
| add_weight(digit + addzero**3, 0.8)
| add_weight(addzero**4, 1.0)))
ten_thousand = (
(thousand | hundred | teen | tens | digits) + delete('万') +
(thousand
| add_weight(zero + hundred, 0.1)
| add_weight(addzero + zero + (tens | teen), 0.5)
| add_weight(addzero + addzero + zero + digit, 0.5)
| add_weight(digit + addzero**3, 0.8)
| add_weight(addzero**4, 1.0)))
else:
ten_thousand = ((teen | tens | digits)
+ delete('万')
+ (thousand
| add_weight(zero + hundred, 0.1)
| add_weight(addzero + zero + (tens | teen), 0.5)
| add_weight(addzero + addzero + zero + digit, 0.5)
| add_weight(digit + addzero**3, 0.8)
| add_weight(addzero**4, 1.0)))
ten_thousand |= (thousand | hundred) + accep("万") + delete("零").ques + (
thousand | hundred | tens | teen | digits).ques
ten_thousand = (
(teen | tens | digits) + delete('万') +
(thousand
| add_weight(zero + hundred, 0.1)
| add_weight(addzero + zero + (tens | teen), 0.5)
| add_weight(addzero + addzero + zero + digit, 0.5)
| add_weight(digit + addzero**3, 0.8)
| add_weight(addzero**4, 1.0)))
ten_thousand |= (thousand | hundred) + accep("万") + delete(
"零").ques + (thousand | hundred | tens | teen | digits).ques
# 个/十/百/千/万
number = digits | teen | tens | hundred | thousand | ten_thousand
# 兆/亿
Expand Down Expand Up @@ -107,31 +109,31 @@ def build_tagger(self):
# 十/百/千/万
number_exclude_0_to_9 = teen | tens | hundred | thousand | ten_thousand
# 兆/亿
number_exclude_0_to_9 = (
((number_exclude_0_to_9 | digits) + accep('兆') + delete('零').ques).ques +
((number_exclude_0_to_9 | digits) + accep('亿') + delete('零').ques).ques +
number_exclude_0_to_9
)
number_exclude_0_to_9 = (((number_exclude_0_to_9 | digits) +
accep('兆') + delete('零').ques).ques +
((number_exclude_0_to_9 | digits) +
accep('亿') + delete('零').ques).ques +
number_exclude_0_to_9)
# 负的xxx 1.11, 1.01
number_exclude_0_to_9 |= (
(number_exclude_0_to_9 | digits) +
(dot + digits.plus).plus
)
number_exclude_0_to_9 |= ((number_exclude_0_to_9 | digits) +
(dot + digits.plus).plus)
# 五六万,三五千,六七百,三四十
# 十七八美元 => $17~18, 四十五六岁 => 45-6岁,
# 三百七八公里 => 370-80km, 三百七八十千克 => 370-80kg
number_exclude_0_to_9 |= special_2number
number_exclude_0_to_9 |= add_weight(special_3number, -100.0)

self.number_exclude_0_to_9 = (sign.ques + number_exclude_0_to_9).optimize() # noqa
self.number_exclude_0_to_9 = (sign.ques +
number_exclude_0_to_9).optimize() # noqa

# cardinal string like 127.0.0.1, used in ID, IP, etc.
cardinal = digits.plus + (dot + digits.plus).plus
# float number like 1.11
cardinal |= (number + dot + digits.plus)
# cardinal string like 110 or 12306 or 13125617878, used in phone,
# 340621199806051223, used in ID card
cardinal |= (digits**3 | digits**4 | digits**5 | digits**11 | digits**18)
cardinal |= (digits**3 | digits**4 | digits**5 | digits**11
| digits**18)
# cardinal string like 23
if self.enable_standalone_number:
if self.enable_0_to_9:
Expand Down
6 changes: 3 additions & 3 deletions itn/chinese/rules/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def __init__(self):

def build_tagger(self):
digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9
zero = string_file('itn/chinese/data/number/zero.tsv') # 0
zero = string_file('itn/chinese/data/number/zero.tsv') # 0

yyyy = digit + (digit | zero)**3 # 二零零八年
yyy = digit + (digit | zero)**2 # 公元一六八年
yy = (digit | zero)**2 # 零八年奥运会
yyy = digit + (digit | zero)**2 # 公元一六八年
yy = (digit | zero)**2 # 零八年奥运会
mm = string_file('itn/chinese/data/date/mm.tsv')
dd = string_file('itn/chinese/data/date/dd.tsv')

Expand Down
7 changes: 3 additions & 4 deletions itn/chinese/rules/fraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,15 @@ def __init__(self):

def build_tagger(self):
number = Cardinal().number
sign = string_file('itn/chinese/data/number/sign.tsv') # + -
sign = string_file('itn/chinese/data/number/sign.tsv') # + -

# NOTE(xcsong): the default weight is 1.0; lowering the weight of `sign`
# to -1.0 gives the explicit-sign path higher priority. For example,
# with the sign weight set to:
# 1.0, 负二分之三 -> { sign: "" denominator: "-2" numerator: "3" }
# -1.0,负二分之三 -> { sign: "-" denominator: "2" numerator: "3" }
tagger = (insert('sign: "') + add_weight(sign, -1.0).ques +
insert('" denominator: "') + number +
delete('分之') + insert('" numerator: "') +
number + insert('"'))
insert('" denominator: "') + number + delete('分之') +
insert('" numerator: "') + number + insert('"'))
self.tagger = self.add_tokens(tagger)

def build_verbalizer(self):
Expand Down
3 changes: 2 additions & 1 deletion itn/chinese/rules/license_plate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ def __init__(self):

def build_tagger(self):
digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9
province = string_file('itn/chinese/data/license_plate/province.tsv') # 皖
province = string_file(
'itn/chinese/data/license_plate/province.tsv') # 皖
license_plate = province + self.ALPHA + (self.ALPHA | digit)**5
tagger = insert('value: "') + license_plate + insert('"')
self.tagger = self.add_tokens(tagger)
21 changes: 10 additions & 11 deletions itn/chinese/rules/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,23 @@ def build_tagger(self):
units_en = string_file('itn/chinese/data/measure/units_en.tsv')
units_zh = string_file('itn/chinese/data/measure/units_zh.tsv')
digit = string_file('itn/chinese/data/number/digit.tsv') # 1 ~ 9
sign = string_file('itn/chinese/data/number/sign.tsv') # + -
sign = string_file('itn/chinese/data/number/sign.tsv') # + -
to = cross('到', '~') | cross('到百分之', '~')

units = add_weight((accep('亿') | accep('兆') | accep('万')), -0.5).ques + units_zh
units |= add_weight((cross('亿', '00M') | cross('兆', 'T') |
cross('万', 'W')), -0.5).ques + (
add_weight(units_en, -1.0)
)
units = add_weight(
(accep('亿') | accep('兆') | accep('万')), -0.5).ques + units_zh
units |= add_weight(
(cross('亿', '00M') | cross('兆', 'T') | cross('万', 'W')),
-0.5).ques + (add_weight(units_en, -1.0))

number = Cardinal().number if self.enable_0_to_9 else \
Cardinal().number_exclude_0_to_9
# 百分之三十, 百分三十, 百分之百,百分之三十到四十, 百分之三十到百分之五十五
percent = ((sign + delete('的').ques).ques + delete('百分') +
delete('之').ques +
((Cardinal().number + (to + Cardinal().number).ques) |
((Cardinal().number + to).ques + cross('百', '100')))
+ insert('%'))
((Cardinal().number + to).ques + cross('百', '100'))) +
insert('%'))

# 十千米每小时 => 10km/h, 十一到一百千米每小时 => 11~100km/h
measure = number + (to + number).ques + units
Expand All @@ -57,9 +57,8 @@ def build_tagger(self):
tagger = insert('value: "') + (measure | percent) + insert('"')

# 每小时十千米 => 10km/h, 每小时三十到三百一十一千米 => 30~311km/h
tagger |= (
insert('denominator: "') + delete('每') + units +
insert('" numerator: "') + measure + insert('"'))
tagger |= (insert('denominator: "') + delete('每') + units +
insert('" numerator: "') + measure + insert('"'))

self.tagger = self.add_tokens(tagger)

Expand Down
6 changes: 3 additions & 3 deletions itn/chinese/rules/money.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def build_tagger(self):
# 三千三百八十元五毛八分 => ¥3380.58
tagger = (insert('value: "') + number + insert('"') +
insert(' currency: "') + (code | symbol) + insert('"') +
insert(' decimal: "') + (
insert(".") + digit + (delete("毛") | delete("角")) + (digit + delete("分")).ques
).ques + insert('"'))
insert(' decimal: "') +
(insert(".") + digit + (delete("毛") | delete("角")) +
(digit + delete("分")).ques).ques + insert('"'))
self.tagger = self.add_tokens(tagger)

def build_verbalizer(self):
Expand Down
13 changes: 6 additions & 7 deletions itn/chinese/rules/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,10 @@ def build_tagger(self):
s = string_file('itn/chinese/data/time/second.tsv')
noon = string_file('itn/chinese/data/time/noon.tsv')

tagger = (
(insert('noon: "') + noon + insert('" ')).ques +
insert('hour: "') + h + insert('"') +
insert(' minute: "') + m + delete('分').ques + insert('"') +
(insert(' second: "') + s + insert('"')).ques)
tagger = ((insert('noon: "') + noon + insert('" ')).ques +
insert('hour: "') + h + insert('"') + insert(' minute: "') +
m + delete('分').ques + insert('"') +
(insert(' second: "') + s + insert('"')).ques)
self.tagger = self.add_tokens(tagger)

def build_verbalizer(self):
Expand All @@ -44,6 +43,6 @@ def build_verbalizer(self):
minute = delete(' minute: "') + self.SIGMA + delete('"')
second = delete(' second: "') + self.SIGMA + delete('"')
noon = delete(' noon: "') + self.SIGMA + delete('"')
verbalizer = (hour + addcolon + minute +
(addcolon + second).ques + noon.ques)
verbalizer = (hour + addcolon + minute + (addcolon + second).ques +
noon.ques)
self.verbalizer = self.delete_tokens(verbalizer)
Loading

0 comments on commit 85581f4

Please sign in to comment.