Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

新加规则,转换错误问题 #191

Open
jeeveenn opened this issue Mar 27, 2024 · 1 comment
Open

新加规则,转换错误问题 #191

jeeveenn opened this issue Mar 27, 2024 · 1 comment

Comments

@jeeveenn
Copy link

jeeveenn commented Mar 27, 2024

您好，我们新增了一个规则，用于把高铁车次中的“高”转换为 G、“动”转换为 D（高->G、动->D）。在 rules 中新增的规则如下：

from pynini import string_file, string_map
from pynini.lib.pynutil import delete, insert

from tn.processor import Processor


def _load_tsv_entries(path):
    """Read a TSV rule file into entries for ``pynini.string_map``.

    The file is read in text mode with universal newlines, so CRLF and LF
    files give the same entries, and a leading UTF-8 byte-order mark is
    dropped. This matters for hand-edited data files. The garbled output
    reported for ``明天的动二`` (text overwriting the start of the terminal
    line) is consistent with a carriage return left at the end of the
    ``head`` value, e.g. from a TSV saved with Windows line endings and
    loaded with ``string_file``.

    Args:
        path: Path of a UTF-8 TSV file. Each non-empty line is either a single
            string or tab-separated ``input<TAB>output`` columns.

    Returns:
        A list holding a plain string for each one-column line and a tuple of
        the columns for every other line.
    """
    entries = []
    with open(path, encoding='utf-8-sig') as tsv:
        for line in tsv:
            # Defensive: strip any line terminator left after newline
            # translation. Tabs and spaces are kept, since they may be part
            # of an entry.
            line = line.rstrip('\r\n')
            if not line:
                continue
            fields = line.split('\t')
            entries.append(fields[0] if len(fields) == 1 else tuple(fields))
    return entries


class FlightTrainCode(Processor):
    """Inverse-normalize train codes, e.g. ``动二`` -> ``D2``.

    The tagger matches a prefix character listed in ``train_tou.tsv`` (e.g.
    ``高`` -> ``G``, ``动`` -> ``D``) followed by one to four spoken digits
    whose first digit is non-zero. It records them as the ``head`` and
    ``number`` fields of a ``flighttraincode`` token. The verbalizer writes
    the head immediately followed by the number.
    """

    def __init__(self):
        super().__init__(name='flighttraincode')
        self.build_tagger()
        self.build_verbalizer()

    def build_tagger(self):
        """Build ``self.tagger``: ``head: "<prefix>" number: "<digits>"``."""
        digit = string_file('itn/chinese/data/number/digit.tsv')  # 1 ~ 9
        zero = string_file('itn/chinese/data/number/zero.tsv')  # 0
        # Rules that map train/flight prefix characters to letters,
        # e.g. 高 -> G. Loaded via _load_tsv_entries instead of string_file so
        # a stray '\r' from CRLF line endings cannot end up in the "head"
        # value.
        train_head = string_map(
            _load_tsv_entries('itn/chinese/data/train/train_tou.tsv'))

        # 1 to 4 digits; the leading digit is never 0 (digit covers 1 ~ 9).
        digit_4 = digit + (digit | zero)**3
        digit_3 = digit + (digit | zero)**2
        digit_2 = digit + (digit | zero)
        digit_1 = digit

        tagger = (insert('head: "') + train_head + insert('"') +
                  insert(' number: "') +
                  (digit_4 | digit_3 | digit_2 | digit_1) + insert('"'))
        self.tagger = self.add_tokens(tagger)

    def build_verbalizer(self):
        """Build ``self.verbalizer``: emit the head value, then the number."""
        head = delete('head: "') + self.SIGMA + delete('"')
        number = delete(' number: "') + self.SIGMA + delete('"')
        verbalizer = head + number
        self.verbalizer = self.delete_tokens(verbalizer)

对 itn/chinese/inverse_normalizer.py 的修改如下：

from tn.processor import Processor
from itn.chinese.rules.cardinal import Cardinal
from itn.chinese.rules.char import Char
from itn.chinese.rules.date import Date
from itn.chinese.rules.fraction import Fraction
from itn.chinese.rules.math import Math
from itn.chinese.rules.measure import Measure
from itn.chinese.rules.money import Money
from itn.chinese.rules.whitelist import Whitelist
from itn.chinese.rules.time import Time
from itn.chinese.rules.postprocessor import PostProcessor
from itn.chinese.rules.license_plate import LicensePlate
from itn.chinese.rules.flighttraincode import FlightTrainCode

from pynini.lib.pynutil import add_weight, delete
from importlib_resources import files


class InverseNormalizer(Processor):
    """Chinese inverse text normalizer built from the imported rule classes.

    Every rule contributes a weighted tagger and a verbalizer. They are
    unioned into one tagger and one verbalizer and compiled (or loaded from
    the cache) under the name ``zh_itn``.
    """

    def __init__(self,
                 cache_dir=None,
                 overwrite_cache=False,
                 enable_standalone_number=True,
                 enable_0_to_9=False,
                 enable_million=False):
        """Store the rule options and build the FSTs.

        Args:
            cache_dir: Where ``build_fst`` keeps its cached FSTs. Defaults to
                ``files("itn")``.
            overwrite_cache: Forwarded to ``build_fst``; presumably forces a
                rebuild instead of reusing the cache (TODO confirm).
            enable_standalone_number: Stored as ``convert_number`` and passed
                to ``Cardinal``.
            enable_0_to_9: Passed to ``Cardinal``, ``Measure`` and ``Money``.
            enable_million: Passed to ``Cardinal``.
        """
        super().__init__(name='inverse_normalizer', ordertype='itn')
        self.convert_number = enable_standalone_number
        self.enable_0_to_9 = enable_0_to_9
        self.enable_million = enable_million
        self.build_fst('zh_itn',
                       files("itn") if cache_dir is None else cache_dir,
                       overwrite_cache)

    def build_tagger(self):
        """Union the weighted rule taggers, allow any sequence of them, and
        drop the trailing space at end of input."""
        # (rule tagger, weight). Weights are added with add_weight as path
        # costs; Char has by far the largest one, so presumably it only wins
        # where no other rule applies.
        weighted_rules = [
            (Date().tagger, 1.02),
            (Whitelist().tagger, 1.01),
            (Fraction().tagger, 1.05),
            (Measure(enable_0_to_9=self.enable_0_to_9).tagger, 1.05),
            (Money(enable_0_to_9=self.enable_0_to_9).tagger, 1.04),
            (Time().tagger, 1.05),
            (Cardinal(self.convert_number, self.enable_0_to_9,
                      self.enable_million).tagger, 1.06),
            (Math().tagger, 1.10),
            (LicensePlate().tagger, 1.0),
            (FlightTrainCode().tagger, 1.07),
            (Char().tagger, 100),
        ]

        combined = None
        for rule_tagger, cost in weighted_rules:
            weighted = add_weight(rule_tagger, cost)
            # Left-to-right union, same grouping as a chained `a | b | c`.
            combined = weighted if combined is None else combined | weighted

        tagger = combined.optimize().star
        # remove the last space
        self.tagger = tagger @ self.build_rule(delete(' '), '', '[EOS]')

    def build_verbalizer(self):
        """Union the rule verbalizers, pipe them through the post-processor,
        and allow any number of tokens."""
        processors = [
            Cardinal(self.convert_number, self.enable_0_to_9,
                     self.enable_million),
            Char(),
            Date(),
            Fraction(),
            Math(),
            Measure(enable_0_to_9=self.enable_0_to_9),
            Money(enable_0_to_9=self.enable_0_to_9),
            Time(),
            LicensePlate(),
            FlightTrainCode(),
            Whitelist(),
        ]

        merged = processors[0].verbalizer
        for processor in processors[1:]:
            merged = merged | processor.verbalizer
        merged = merged.optimize()

        postprocessor = PostProcessor(remove_interjections=True).processor
        self.verbalizer = (merged @ postprocessor).star

识别结果出现了顺序错乱的问题。输入：
python -m itn --text "明天的动二" --overwrite_cache
输出如下:

" number: "2" }明" } char { value: "天" } char { value: "的" } flighttraincode { head: "D 
2 天的D 
@pengzhendong
Copy link
Member

先看 tagger 的输出合不合理,再看 verbalizer

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants