-
Notifications
You must be signed in to change notification settings - Fork 274
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #873 from pavaris-pm/dev
Add PhayaThaiBERT engine with new features [WIP] by @pavaris-pm and @MpolaarbearM
- Loading branch information
Showing
9 changed files
with
537 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# -*- coding: utf-8 -*- | ||
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import List | ||
import random | ||
import re | ||
|
||
from pythainlp.phayathaibert.core import ThaiTextProcessor | ||
|
||
|
||
_MODEL_NAME = "clicknext/phayathaibert" | ||
|
||
|
||
class ThaiTextAugmenter: | ||
def __init__(self,) -> None: | ||
from transformers import (AutoTokenizer, | ||
AutoModelForMaskedLM, | ||
pipeline,) | ||
self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME) | ||
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME) | ||
self.model = pipeline("fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm) | ||
self.processor = ThaiTextProcessor() | ||
|
||
def generate(self, | ||
sample_text: str, | ||
word_rank: int, | ||
max_length: int = 3, | ||
sample: bool = False | ||
) -> str: | ||
sample_txt = sample_text | ||
final_text = "" | ||
|
||
for j in range(max_length): | ||
input = self.processor.preprocess(sample_txt) | ||
if sample: | ||
random_word_idx = random.randint(0, 4) | ||
output = self.model(input)[random_word_idx]["sequence"] | ||
else: | ||
output = self.model(input)[word_rank]["sequence"] | ||
sample_txt = output + "<mask>" | ||
final_text = sample_txt | ||
|
||
gen_txt = re.sub("<mask>", "", final_text) | ||
|
||
return gen_txt | ||
|
||
def augment(self, | ||
text: str, | ||
num_augs: int = 3, | ||
sample: bool = False | ||
) -> List[str]: | ||
""" | ||
Text augmentation from PhayaThaiBERT | ||
:param str text: Thai text | ||
:param int num_augs: an amount of augmentation text needed as an output | ||
:param bool sample: whether to sample the text as an output or not, \ | ||
true if more word diversity is needed | ||
:return: list of text augment | ||
:rtype: List[str] | ||
:Example: | ||
:: | ||
from pythainlp.augment.lm import ThaiTextAugmenter | ||
aug = ThaiTextAugmenter() | ||
aug.augment("ช้างมีทั้งหมด 50 ตัว บน", num_args=5) | ||
# output = ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้ครับ.', | ||
'ช้างมีทั้งหมด 50 ตัว บนพื้นดินครับ...', | ||
'ช้างมีทั้งหมด 50 ตัว บนท้องฟ้าครับ...', | ||
'ช้างมีทั้งหมด 50 ตัว บนดวงจันทร์.‼', | ||
'ช้างมีทั้งหมด 50 ตัว บนเขาค่ะ😁'] | ||
""" | ||
MAX_NUM_AUGS = 5 | ||
augment_list = [] | ||
|
||
if "<mask>" not in text: | ||
text = text + "<mask>" | ||
|
||
if num_augs <= MAX_NUM_AUGS: | ||
for rank in range(num_augs): | ||
gen_text = self.generate(text, rank, sample=sample) | ||
processed_text = re.sub("<_>", " ", self.processor.preprocess(gen_text)) | ||
augment_list.append(processed_text) | ||
|
||
return augment_list | ||
|
||
raise ValueError( | ||
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# -*- coding: utf-8 -*- | ||
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project | ||
# SPDX-License-Identifier: Apache-2.0 | ||
__all__ = [ | ||
"NamedEntityTagger", | ||
"PartOfSpeechTagger", | ||
"ThaiTextAugmenter", | ||
"ThaiTextProcessor", | ||
"segment", | ||
] | ||
|
||
from pythainlp.phayathaibert.core import ( | ||
NamedEntityTagger, | ||
PartOfSpeechTagger, | ||
ThaiTextAugmenter, | ||
ThaiTextProcessor, | ||
segment, | ||
) |
Oops, something went wrong.