diff --git a/src/kwja/datamodule/datasets/seq2seq.py b/src/kwja/datamodule/datasets/seq2seq.py
index d53e358a..2aa7d4f4 100644
--- a/src/kwja/datamodule/datasets/seq2seq.py
+++ b/src/kwja/datamodule/datasets/seq2seq.py
@@ -10,6 +10,7 @@
 from kwja.datamodule.examples import Seq2SeqExample
 from kwja.utils.constants import IGNORE_INDEX
 from kwja.utils.logging_util import track
+from kwja.utils.normalization import normalize_text
 from kwja.utils.seq2seq_format import Seq2SeqFormatter
 
 logger = logging.getLogger(__name__)
@@ -111,3 +112,14 @@ def encode(self, example: Seq2SeqExample) -> Seq2SeqModuleFeatures:
             attention_mask=example.src_attention_mask,
             seq2seq_labels=seq2seq_labels,
         )
+
+    def _postprocess_document(self, document: Document) -> Document:
+        for morpheme in document.morphemes:
+            normalized = normalize_text(morpheme.text)
+            if normalized != morpheme.text:
+                logger.warning(f"apply normalization ({morpheme.text} -> {normalized})")
+                morpheme.text = normalized
+            morpheme.reading = normalize_text(morpheme.reading)
+            morpheme.lemma = normalize_text(morpheme.lemma)
+        # propagate updates of morpheme.text to sentence.text and document.text
+        return document.reparse()
diff --git a/src/kwja/modules/components/logits_processor.py b/src/kwja/modules/components/logits_processor.py
index 2f6cc93e..d42f0b10 100644
--- a/src/kwja/modules/components/logits_processor.py
+++ b/src/kwja/modules/components/logits_processor.py
@@ -6,17 +6,7 @@
 from transformers import PreTrainedTokenizerFast
 from transformers.generation import LogitsProcessor
 
-from kwja.utils.constants import (
-    CANON_TOKEN,
-    FULL_SPACE_TOKEN,
-    HALF_SPACE_TOKEN1,
-    HALF_SPACE_TOKEN2,
-    LEMMA_TOKEN,
-    READING_TOKEN,
-    SPECIAL_TO_RARE,
-    SURF_TOKEN,
-    TRIPLE_DOT_TOKEN,
-)
+from kwja.utils.constants import CANON_TOKEN, HALF_SPACE_TOKEN, LEMMA_TOKEN, READING_TOKEN, SPECIAL_TO_RARE, SURF_TOKEN
 
 KANJI_KATAKANA_PATTERN = r"[\p{Script=Han}\p{Script=Katakana}]"
 
@@ -78,9 +68,7 @@ def __init__(
         self.ids_except_kanji_and_katakana: Set[int] = set(self.tokenizer.get_vocab().values()) - reading_candidates
 
         self.token_to_ids_except_token: Dict[str, Set[int]] = {}
-        special_tokens: List[str] = [FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN] + list(
-            SPECIAL_TO_RARE.keys()
-        )
+        special_tokens: List[str] = [HALF_SPACE_TOKEN] + list(SPECIAL_TO_RARE.keys())
         for special_token in special_tokens:
             self.token_to_ids_except_token[special_token] = set(self.tokenizer.get_vocab().values()) - {
                 self.tokenizer.convert_tokens_to_ids(special_token)
diff --git a/src/kwja/utils/constants.py b/src/kwja/utils/constants.py
index a1bf05f0..4c64f219 100644
--- a/src/kwja/utils/constants.py
+++ b/src/kwja/utils/constants.py
@@ -207,35 +207,32 @@
 READING_TOKEN: str = ""
 LEMMA_TOKEN: str = ""
 CANON_TOKEN: str = ""
-# tokens to represent full space, half space, no canonical form, and triple dots
+# tokens to represent no canonical form and half space
 NO_CANON_TOKEN: str = ""
-FULL_SPACE_TOKEN: str = ""
-HALF_SPACE_TOKEN1: str = ""
-HALF_SPACE_TOKEN2: str = ""
-TRIPLE_DOT_TOKEN: str = ""
+HALF_SPACE_TOKEN: str = ""
 # token to split input text into morphemes
-MORPHEME_SPLIT_TOKEN: str = ""
+MORPHEME_SPLIT_TOKEN: str = ""
 # tokens for unk tokens
 RARE_TO_SPECIAL: Dict[str, str] = {
-    "ゔ": "",
-    "榕": "",
-    "謄": "",
-    "丿": "",
-    "孜": "",
-    "腑": "",
-    "庖": "",
-    "┘": "",
-    "秧": "",
-    "褪": "",
-    "疥": "",
-    "鮪": "",
-    "髑髏": "",
-    "侭": "",
-    "蒟蒻": "",
-    "╹": "",
-    "厂": "",
-    "Ӧ": "",
-    "溢": "",
+    "ゔ": "",
+    "榕": "",
+    "謄": "",
+    "丿": "",
+    "孜": "",
+    "腑": "",
+    "庖": "",
+    "┘": "",
+    "秧": "",
+    "褪": "",
+    "疥": "",
+    "鮪": "",
+    "髑髏": "",
+    "侭": "",
+    "蒟蒻": "",
+    "╹": "",
+    "厂": "",
+    "Ӧ": "",
+    "溢": "",
 }
 SPECIAL_TO_RARE: Dict[str, str] = {v: k for k, v in RARE_TO_SPECIAL.items()}
 
diff --git a/src/kwja/utils/seq2seq_format.py b/src/kwja/utils/seq2seq_format.py
index a00a4d27..a4242be2 100644
--- a/src/kwja/utils/seq2seq_format.py
+++ b/src/kwja/utils/seq2seq_format.py
@@ -5,9 +5,7 @@
 
 from kwja.utils.constants import (
     CANON_TOKEN,
-    FULL_SPACE_TOKEN,
-    HALF_SPACE_TOKEN1,
-    HALF_SPACE_TOKEN2,
+    HALF_SPACE_TOKEN,
     LEMMA_TOKEN,
     MORPHEME_SPLIT_TOKEN,
     NO_CANON_TOKEN,
@@ -15,7 +13,6 @@
     READING_TOKEN,
     SPECIAL_TO_RARE,
     SURF_TOKEN,
-    TRIPLE_DOT_TOKEN,
 )
 
 
@@ -23,12 +20,7 @@
 class Seq2SeqFormatter:
     def __init__(self, tokenizer: PreTrainedTokenizerFast) -> None:
         self.tokenizer: PreTrainedTokenizerFast = tokenizer
-        self.word_to_token: Dict[str, str] = {
-            "\u3000": FULL_SPACE_TOKEN,
-            " ": HALF_SPACE_TOKEN1,
-            "␣": HALF_SPACE_TOKEN2,
-            "…": TRIPLE_DOT_TOKEN,
-        }
+        self.word_to_token: Dict[str, str] = {" ": HALF_SPACE_TOKEN}
         self.token_to_word: Dict[str, str] = {v: k for k, v in self.word_to_token.items()}
 
     def get_surfs(self, sentence: Sentence) -> List[str]:
@@ -37,7 +29,6 @@
             surf: str = morpheme.surf
             for k, v in self.word_to_token.items():
                 surf = surf.replace(k, v)
-            surf = surf.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
             for k, v in RARE_TO_SPECIAL.items():
                 surf = surf.replace(k, v)
             tokenized_surf: List[str] = [x for x in self.tokenizer.tokenize(surf) if x != "▁"]
@@ -46,14 +37,13 @@
                 decoded = decoded.replace(f"{token} ", token)
             for token in SPECIAL_TO_RARE:
                 decoded = decoded.replace(f"{token} ", token)
-            surfs.append(decoded.replace(" ", HALF_SPACE_TOKEN1))
+            surfs.append(decoded.replace(" ", HALF_SPACE_TOKEN))
         return surfs
 
     def get_src_tokens(self, sentence: Sentence) -> List[str]:
         src_text: str = MORPHEME_SPLIT_TOKEN.join(m.surf for m in sentence.morphemes)
         for k, v in self.word_to_token.items():
             src_text = src_text.replace(k, v)
-        src_text = src_text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
         for k, v in RARE_TO_SPECIAL.items():
             src_text = src_text.replace(k, v)
         return [x for x in self.tokenizer.tokenize(src_text) if x != "▁"]
@@ -61,30 +51,18 @@
     def get_tgt_tokens(self, sentence: Sentence) -> List[str]:
         seq2seq_format: str = ""
         for morpheme in sentence.morphemes:
-            if morpheme.surf == "\u3000":
-                surf: str = FULL_SPACE_TOKEN
-                reading: str = FULL_SPACE_TOKEN
-                lemma: str = FULL_SPACE_TOKEN
-                canon: str = "/"
-            elif morpheme.surf == " ":
-                surf = HALF_SPACE_TOKEN1
-                reading = HALF_SPACE_TOKEN1
-                lemma = HALF_SPACE_TOKEN1
+            if morpheme.surf == " ":
+                surf = HALF_SPACE_TOKEN
+                reading = HALF_SPACE_TOKEN
+                lemma = HALF_SPACE_TOKEN
                 canon = "/"
-            elif morpheme.surf == "…":
-                surf = TRIPLE_DOT_TOKEN
-                reading = TRIPLE_DOT_TOKEN
-                lemma = TRIPLE_DOT_TOKEN
-                canon = f"{TRIPLE_DOT_TOKEN}/{TRIPLE_DOT_TOKEN}"
             else:
                 surf = morpheme.surf
-                if morpheme.reading == "\u3000":
-                    reading = FULL_SPACE_TOKEN
-                elif "/" in morpheme.reading and len(morpheme.reading) > 1:
+                if "/" in morpheme.reading and len(morpheme.reading) > 1:
                     reading = morpheme.reading.split("/")[0]
                 else:
                     reading = morpheme.reading
-                lemma = FULL_SPACE_TOKEN if morpheme.lemma == "\u3000" else morpheme.lemma
+                lemma = morpheme.lemma
                 if morpheme.canon is not None:
                     canon = morpheme.canon
                     canon_list: List[str] = canon.split("/")
@@ -123,18 +101,13 @@ def format_to_sent(self, text: str) -> Sentence:
                     canon = canon.replace(k, v)
                 for k, v in SPECIAL_TO_RARE.items():
                     canon = canon.replace(k, v)
-                canon = (
-                    f"{self.token_to_word[TRIPLE_DOT_TOKEN]}/{self.token_to_word[TRIPLE_DOT_TOKEN]}"
-                    if canon == f"{TRIPLE_DOT_TOKEN}/{TRIPLE_DOT_TOKEN}"
-                    else canon
-                )
                 canon = f'"代表表記:{canon}"' if canon != NO_CANON_TOKEN else "NIL"
 
                 # 例外処理
-                if surf == " " and reading == "\u3000" and lemma == "\u3000":
-                    surf = "\u3000"
-                if surf == "°C":
-                    surf, lemma, canon = "℃", "℃", '"代表表記:℃/ど"'
+                if surf == " ":
+                    reading = " "
+                    lemma = " "
+                    canon = r'"代表表記:\␣/\␣"'
 
                 formatted += f"{surf} {reading} {lemma} 未定義語 15 その他 1 * 0 * 0 {canon}\n"
             except IndexError: