remove full-width space and triple dot tokens from seq2seq module
omukazu committed May 14, 2024
1 parent b0d7654 commit a658b0c
Showing 4 changed files with 49 additions and 79 deletions.
12 changes: 12 additions & 0 deletions src/kwja/datamodule/datasets/seq2seq.py
@@ -10,6 +10,7 @@
from kwja.datamodule.examples import Seq2SeqExample
from kwja.utils.constants import IGNORE_INDEX
from kwja.utils.logging_util import track
+from kwja.utils.normalization import normalize_text
from kwja.utils.seq2seq_format import Seq2SeqFormatter

logger = logging.getLogger(__name__)
@@ -111,3 +112,14 @@ def encode(self, example: Seq2SeqExample) -> Seq2SeqModuleFeatures:
attention_mask=example.src_attention_mask,
seq2seq_labels=seq2seq_labels,
)

+    def _postprocess_document(self, document: Document) -> Document:
+        for morpheme in document.morphemes:
+            normalized = normalize_text(morpheme.text)
+            if normalized != morpheme.text:
+                logger.warning(f"apply normalization ({morpheme.text} -> {normalized})")
+                morpheme.text = normalized
+            morpheme.reading = normalize_text(morpheme.reading)
+            morpheme.lemma = normalize_text(morpheme.lemma)
+        # propagate updates of morpheme.text to sentence.text and document.text
+        return document.reparse()
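Note: the hunk above leans on `normalize_text` to fold characters such as the full-width space and "…" into ordinary text, which is what makes the dedicated tokens removable downstream. Below is a minimal, self-contained sketch of that pattern; the `normalize_text` stand-in assumes an NFKC-style normalization and is not the actual `kwja.utils.normalization` implementation.

```python
import logging
import unicodedata

logger = logging.getLogger(__name__)


def normalize_text(text: str) -> str:
    # Stand-in for kwja.utils.normalization.normalize_text (assumption:
    # NFKC-style folding, e.g. full-width space -> half-width space).
    return unicodedata.normalize("NFKC", text)


def postprocess_surface(surf: str) -> str:
    # Mirrors the warning-on-change pattern in _postprocess_document above.
    normalized = normalize_text(surf)
    if normalized != surf:
        logger.warning(f"apply normalization ({surf} -> {normalized})")
    return normalized


if __name__ == "__main__":
    print(repr(postprocess_surface("\u3000")))  # NFKC maps U+3000 to a plain space
    print(repr(postprocess_surface("…")))       # NFKC maps "…" to "..."
```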
16 changes: 2 additions & 14 deletions src/kwja/modules/components/logits_processor.py
@@ -6,17 +6,7 @@
from transformers import PreTrainedTokenizerFast
from transformers.generation import LogitsProcessor

-from kwja.utils.constants import (
-    CANON_TOKEN,
-    FULL_SPACE_TOKEN,
-    HALF_SPACE_TOKEN1,
-    HALF_SPACE_TOKEN2,
-    LEMMA_TOKEN,
-    READING_TOKEN,
-    SPECIAL_TO_RARE,
-    SURF_TOKEN,
-    TRIPLE_DOT_TOKEN,
-)
+from kwja.utils.constants import CANON_TOKEN, HALF_SPACE_TOKEN, LEMMA_TOKEN, READING_TOKEN, SPECIAL_TO_RARE, SURF_TOKEN

KANJI_KATAKANA_PATTERN = r"[\p{Script=Han}\p{Script=Katakana}]"

@@ -78,9 +68,7 @@ def __init__(
self.ids_except_kanji_and_katakana: Set[int] = set(self.tokenizer.get_vocab().values()) - reading_candidates

self.token_to_ids_except_token: Dict[str, Set[int]] = {}
-        special_tokens: List[str] = [FULL_SPACE_TOKEN, HALF_SPACE_TOKEN1, HALF_SPACE_TOKEN2, TRIPLE_DOT_TOKEN] + list(
-            SPECIAL_TO_RARE.keys()
-        )
+        special_tokens: List[str] = [HALF_SPACE_TOKEN] + list(SPECIAL_TO_RARE.keys())
for special_token in special_tokens:
self.token_to_ids_except_token[special_token] = set(self.tokenizer.get_vocab().values()) - {
self.tokenizer.convert_tokens_to_ids(special_token)
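For reference, a small sketch of how a mapping like `token_to_ids_except_token` can be used to constrain decoding, with a toy vocabulary standing in for `tokenizer.get_vocab()`. The masking step is illustrative only and is an assumption, not necessarily the exact logic of the logits processor above.

```python
from typing import Dict, Set

import torch

# Toy stand-in for tokenizer.get_vocab(); after this commit the half-space
# placeholder is HALF_SPACE_TOKEN = "<extra_id_5>".
vocab: Dict[str, int] = {"<extra_id_5>": 0, "▁": 1, "a": 2, "b": 3}
HALF_SPACE_TOKEN = "<extra_id_5>"

# Same construction as token_to_ids_except_token above: every id except the token's own.
ids_except_token: Set[int] = set(vocab.values()) - {vocab[HALF_SPACE_TOKEN]}

# Illustrative use (an assumption, not kwja's exact decoding logic): suppress
# every other id so generation is forced to emit the half-space token.
logits = torch.randn(len(vocab))
logits[list(ids_except_token)] = float("-inf")
assert int(logits.argmax()) == vocab[HALF_SPACE_TOKEN]
```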
47 changes: 22 additions & 25 deletions src/kwja/utils/constants.py
@@ -207,35 +207,32 @@
READING_TOKEN: str = "<extra_id_1>"
LEMMA_TOKEN: str = "<extra_id_2>"
CANON_TOKEN: str = "<extra_id_3>"
-# tokens to represent full space, half space, no canonical form, and triple dots
+# tokens to represent no canonical form and half space
NO_CANON_TOKEN: str = "<extra_id_4>"
-FULL_SPACE_TOKEN: str = "<extra_id_5>"
-HALF_SPACE_TOKEN1: str = "<extra_id_6>"
-HALF_SPACE_TOKEN2: str = "<extra_id_7>"
-TRIPLE_DOT_TOKEN: str = "<extra_id_8>"
+HALF_SPACE_TOKEN: str = "<extra_id_5>"
# token to split input text into morphemes
MORPHEME_SPLIT_TOKEN: str = "<extra_id_9>"
MORPHEME_SPLIT_TOKEN: str = "<extra_id_6>"
# tokens for unk tokens
RARE_TO_SPECIAL: Dict[str, str] = {
"ゔ": "<extra_id_10>",
"榕": "<extra_id_11>",
"謄": "<extra_id_12>",
"丿": "<extra_id_13>",
"孜": "<extra_id_14>",
"腑": "<extra_id_15>",
"庖": "<extra_id_16>",
"┘": "<extra_id_17>",
"秧": "<extra_id_18>",
"褪": "<extra_id_19>",
"疥": "<extra_id_20>",
"鮪": "<extra_id_21>",
"髑髏": "<extra_id_22>",
"侭": "<extra_id_23>",
"蒟蒻": "<extra_id_24>",
"╹": "<extra_id_25>",
"厂": "<extra_id_26>",
"Ӧ": "<extra_id_27>",
"溢": "<extra_id_28>",
"ゔ": "<extra_id_7>",
"榕": "<extra_id_8>",
"謄": "<extra_id_9>",
"丿": "<extra_id_10>",
"孜": "<extra_id_11>",
"腑": "<extra_id_12>",
"庖": "<extra_id_13>",
"┘": "<extra_id_14>",
"秧": "<extra_id_15>",
"褪": "<extra_id_16>",
"疥": "<extra_id_17>",
"鮪": "<extra_id_18>",
"髑髏": "<extra_id_19>",
"侭": "<extra_id_20>",
"蒟蒻": "<extra_id_21>",
"╹": "<extra_id_22>",
"厂": "<extra_id_23>",
"Ӧ": "<extra_id_24>",
"溢": "<extra_id_25>",
}
SPECIAL_TO_RARE: Dict[str, str] = {v: k for k, v in RARE_TO_SPECIAL.items()}

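The renumbering above keeps the rare-character placeholders contiguous after <extra_id_6>. A quick round-trip sketch of how RARE_TO_SPECIAL and SPECIAL_TO_RARE are applied (it assumes kwja is importable; the replace loops mirror the ones in seq2seq_format.py below, and the expected values follow the remapped table in this hunk):

```python
from kwja.utils.constants import RARE_TO_SPECIAL, SPECIAL_TO_RARE


def encode_rare(text: str) -> str:
    # Replace rare strings with their <extra_id_*> placeholders.
    for rare, special in RARE_TO_SPECIAL.items():
        text = text.replace(rare, special)
    return text


def decode_rare(text: str) -> str:
    # Invert the mapping to recover the original text.
    for special, rare in SPECIAL_TO_RARE.items():
        text = text.replace(special, rare)
    return text


if __name__ == "__main__":
    original = "蒟蒻と鮪"
    assert encode_rare(original) == "<extra_id_21>と<extra_id_18>"
    assert decode_rare(encode_rare(original)) == original
```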
53 changes: 13 additions & 40 deletions src/kwja/utils/seq2seq_format.py
@@ -5,30 +5,22 @@

from kwja.utils.constants import (
CANON_TOKEN,
-    FULL_SPACE_TOKEN,
-    HALF_SPACE_TOKEN1,
-    HALF_SPACE_TOKEN2,
+    HALF_SPACE_TOKEN,
LEMMA_TOKEN,
MORPHEME_SPLIT_TOKEN,
NO_CANON_TOKEN,
RARE_TO_SPECIAL,
READING_TOKEN,
SPECIAL_TO_RARE,
SURF_TOKEN,
-    TRIPLE_DOT_TOKEN,
)


class Seq2SeqFormatter:
def __init__(self, tokenizer: PreTrainedTokenizerFast) -> None:
self.tokenizer: PreTrainedTokenizerFast = tokenizer

-        self.word_to_token: Dict[str, str] = {
-            "\u3000": FULL_SPACE_TOKEN,
-            " ": HALF_SPACE_TOKEN1,
-            "␣": HALF_SPACE_TOKEN2,
-            "…": TRIPLE_DOT_TOKEN,
-        }
+        self.word_to_token: Dict[str, str] = {" ": HALF_SPACE_TOKEN}
self.token_to_word: Dict[str, str] = {v: k for k, v in self.word_to_token.items()}

def get_surfs(self, sentence: Sentence) -> List[str]:
@@ -37,7 +29,6 @@ def get_surfs(self, sentence: Sentence) -> List[str]:
surf: str = morpheme.surf
for k, v in self.word_to_token.items():
surf = surf.replace(k, v)
-            surf = surf.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
for k, v in RARE_TO_SPECIAL.items():
surf = surf.replace(k, v)
tokenized_surf: List[str] = [x for x in self.tokenizer.tokenize(surf) if x != "▁"]
@@ -46,45 +37,32 @@ def get_surfs(self, sentence: Sentence) -> List[str]:
decoded = decoded.replace(f"{token} ", token)
for token in SPECIAL_TO_RARE:
decoded = decoded.replace(f"{token} ", token)
surfs.append(decoded.replace(" ", HALF_SPACE_TOKEN1))
surfs.append(decoded.replace(" ", HALF_SPACE_TOKEN))
return surfs

def get_src_tokens(self, sentence: Sentence) -> List[str]:
src_text: str = MORPHEME_SPLIT_TOKEN.join(m.surf for m in sentence.morphemes)
for k, v in self.word_to_token.items():
src_text = src_text.replace(k, v)
-        src_text = src_text.replace(HALF_SPACE_TOKEN2, HALF_SPACE_TOKEN1)
for k, v in RARE_TO_SPECIAL.items():
src_text = src_text.replace(k, v)
return [x for x in self.tokenizer.tokenize(src_text) if x != "▁"]

def get_tgt_tokens(self, sentence: Sentence) -> List[str]:
seq2seq_format: str = ""
for morpheme in sentence.morphemes:
if morpheme.surf == "\u3000":
surf: str = FULL_SPACE_TOKEN
reading: str = FULL_SPACE_TOKEN
lemma: str = FULL_SPACE_TOKEN
canon: str = "/"
elif morpheme.surf == " ":
surf = HALF_SPACE_TOKEN1
reading = HALF_SPACE_TOKEN1
lemma = HALF_SPACE_TOKEN1
if morpheme.surf == " ":
surf = HALF_SPACE_TOKEN
reading = HALF_SPACE_TOKEN
lemma = HALF_SPACE_TOKEN
canon = "/"
elif morpheme.surf == "…":
surf = TRIPLE_DOT_TOKEN
reading = TRIPLE_DOT_TOKEN
lemma = TRIPLE_DOT_TOKEN
canon = f"{TRIPLE_DOT_TOKEN}/{TRIPLE_DOT_TOKEN}"
else:
surf = morpheme.surf
if morpheme.reading == "\u3000":
reading = FULL_SPACE_TOKEN
elif "/" in morpheme.reading and len(morpheme.reading) > 1:
if "/" in morpheme.reading and len(morpheme.reading) > 1:
reading = morpheme.reading.split("/")[0]
else:
reading = morpheme.reading
-                lemma = FULL_SPACE_TOKEN if morpheme.lemma == "\u3000" else morpheme.lemma
+                lemma = morpheme.lemma
if morpheme.canon is not None:
canon = morpheme.canon
canon_list: List[str] = canon.split("/")
@@ -123,18 +101,13 @@ def format_to_sent(self, text: str) -> Sentence:
canon = canon.replace(k, v)
for k, v in SPECIAL_TO_RARE.items():
canon = canon.replace(k, v)
-                canon = (
-                    f"{self.token_to_word[TRIPLE_DOT_TOKEN]}/{self.token_to_word[TRIPLE_DOT_TOKEN]}"
-                    if canon == f"{TRIPLE_DOT_TOKEN}/{TRIPLE_DOT_TOKEN}"
-                    else canon
-                )
canon = f'"代表表記:{canon}"' if canon != NO_CANON_TOKEN else "NIL"

# 例外処理
if surf == " " and reading == "\u3000" and lemma == "\u3000":
surf = "\u3000"
if surf == "°C":
surf, lemma, canon = "℃", "℃", '"代表表記:℃/ど"'
if surf == " ":
reading = " "
lemma = " "
canon = r'"代表表記:\␣/\␣"'

formatted += f"{surf} {reading} {lemma} 未定義語 15 その他 1 * 0 * 0 {canon}\n"
except IndexError:
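The last hunk replaces the old full-width-space special case with a plain half-width-space one. Below is a simplified sketch of that exception branch and the KNP-style morpheme line it feeds; `format_morpheme` is a hypothetical helper, and the parsing and Sentence construction that surround this logic in format_to_sent are omitted.

```python
def format_morpheme(surf: str, reading: str, lemma: str, canon: str) -> str:
    # Mirrors the new exception branch above: a bare half-width space gets a
    # fixed reading/lemma and an escaped canonical form.
    if surf == " ":
        reading = " "
        lemma = " "
        canon = r'"代表表記:\␣/\␣"'
    return f"{surf} {reading} {lemma} 未定義語 15 その他 1 * 0 * 0 {canon}\n"


if __name__ == "__main__":
    # Placeholder inputs for illustration only.
    print(repr(format_morpheme(" ", "<extra_id_5>", "<extra_id_5>", "NIL")))
```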