Skip to content

Commit

Permalink
Merge pull request #873 from pavaris-pm/dev
Browse files Browse the repository at this point in the history
Add PhayaThaiBERT engine with new features [WIP] by @pavaris-pm  and @MpolaarbearM
  • Loading branch information
bact authored Dec 11, 2023
2 parents de4f206 + e7ef6ce commit ff74b39
Show file tree
Hide file tree
Showing 9 changed files with 537 additions and 7 deletions.
4 changes: 3 additions & 1 deletion pythainlp/augment/lm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
"""
LM
Language Models
"""

# Public names exported by the ``pythainlp.augment.lm`` package.
__all__ = [
"FastTextAug",
"Thai2transformersAug",
"ThaiTextAugmenter",
]

from pythainlp.augment.lm.fasttext import FastTextAug
from pythainlp.augment.lm.phayathaibert import ThaiTextAugmenter
from pythainlp.augment.lm.wangchanberta import Thai2transformersAug
94 changes: 94 additions & 0 deletions pythainlp/augment/lm/phayathaibert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

from typing import List
import random
import re

from pythainlp.phayathaibert.core import ThaiTextProcessor


# Model identifier passed to ``from_pretrained`` when loading both the
# tokenizer and the masked-language model used by ThaiTextAugmenter.
_MODEL_NAME = "clicknext/phayathaibert"


class ThaiTextAugmenter:
    """Text augmenter built on the PhayaThaiBERT masked language model.

    The tokenizer and model named by ``_MODEL_NAME`` are wrapped in a
    ``fill-mask`` pipeline, which is used to append predicted tokens to
    the end of an input text.
    """

    def __init__(self) -> None:
        # Imported inside the constructor rather than at module level.
        from transformers import (
            AutoModelForMaskedLM,
            AutoTokenizer,
            pipeline,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
        self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
            _MODEL_NAME
        )
        self.model = pipeline(
            "fill-mask",
            tokenizer=self.tokenizer,
            model=self.model_for_masked_lm,
        )
        self.processor = ThaiTextProcessor()

    def generate(
        self,
        sample_text: str,
        word_rank: int,
        max_length: int = 3,
        sample: bool = False,
    ) -> str:
        """
        Extend a text by repeatedly filling a trailing ``<mask>`` token.

        On every step the text is preprocessed, passed to the fill-mask
        pipeline, one candidate is picked, and a new ``<mask>`` is appended
        to the chosen sequence for the next step.

        :param str sample_text: text containing a ``<mask>`` token
        :param int word_rank: index into the candidate list returned by the
            pipeline; ignored when ``sample`` is true
        :param int max_length: number of fill-mask steps to perform
        :param bool sample: pick a random candidate on every step instead
            of the one at ``word_rank``
        :return: the extended text with every ``<mask>`` removed; when
            ``max_length`` is 0 or less this is ``sample_text`` without
            its mask
        :rtype: str
        """
        current_text = sample_text

        for _ in range(max_length):
            prepared = self.processor.preprocess(current_text)
            # NOTE: the indexing below expects a flat list of candidate
            # dicts, i.e. a text that holds a single mask token.
            candidates = self.model(prepared)
            if sample:
                # Draw from the candidates actually returned rather than
                # assuming there are always exactly five of them.
                chosen = random.randrange(len(candidates))
            else:
                chosen = word_rank
            current_text = candidates[chosen]["sequence"] + "<mask>"

        return current_text.replace("<mask>", "")

    def augment(
        self,
        text: str,
        num_augs: int = 3,
        sample: bool = False,
    ) -> List[str]:
        """
        Text augmentation from PhayaThaiBERT

        :param str text: Thai text
        :param int num_augs: number of augmented texts to return (at most 5)
        :param bool sample: whether to sample the text as an output or not,
            true if more word diversity is needed

        :return: list of text augment
        :rtype: List[str]
        :raises ValueError: if ``num_augs`` is greater than 5

        :Example:
        ::

            from pythainlp.augment.lm import ThaiTextAugmenter

            aug = ThaiTextAugmenter()
            aug.augment("ช้างมีทั้งหมด 50 ตัว บน", num_augs=5)
            # output = ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้ครับ.',
            #           'ช้างมีทั้งหมด 50 ตัว บนพื้นดินครับ...',
            #           'ช้างมีทั้งหมด 50 ตัว บนท้องฟ้าครับ...',
            #           'ช้างมีทั้งหมด 50 ตัว บนดวงจันทร์.‼',
            #           'ช้างมีทั้งหมด 50 ตัว บนเขาค่ะ😁']
        """
        MAX_NUM_AUGS = 5

        # Reject oversized requests before doing any model work.
        if num_augs > MAX_NUM_AUGS:
            raise ValueError(
                f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
            )

        # The fill-mask pipeline needs a mask to predict; add one at the
        # end when the caller did not supply it.
        if "<mask>" not in text:
            text = text + "<mask>"

        augment_list = []
        for rank in range(num_augs):
            gen_text = self.generate(text, rank, sample=sample)
            # "<_>" markers left in the preprocessed text are turned back
            # into plain spaces.
            processed_text = self.processor.preprocess(gen_text).replace(
                "<_>", " "
            )
            augment_list.append(processed_text)

        return augment_list
18 changes: 18 additions & 0 deletions pythainlp/phayathaibert/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0
# Public names exported by the ``pythainlp.phayathaibert`` package.
__all__ = [
"NamedEntityTagger",
"PartOfSpeechTagger",
"ThaiTextAugmenter",
"ThaiTextProcessor",
"segment",
]

from pythainlp.phayathaibert.core import (
NamedEntityTagger,
PartOfSpeechTagger,
ThaiTextAugmenter,
ThaiTextProcessor,
segment,
)
Loading

0 comments on commit ff74b39

Please sign in to comment.