Skip to content

Commit

Permalink
Code clean up
Browse files Browse the repository at this point in the history
- `pythainlp.augment.lm.phayathaibert.ThaiTextAugmenter.augment()`: Raise `ValueError` if `num_augs` exceeds `MAX_NUM_AUGS`; also guarantee a return value when it completes successfully
- `Khavee`: skip McCabe complexity check for now (54 complexity at the moment)
- `pythainlp.parse.core.dependency_parsing()`: model: str can't be None
  • Loading branch information
bact committed Dec 12, 2023
1 parent ff74b39 commit 6e3fe5d
Show file tree
Hide file tree
Showing 13 changed files with 686 additions and 450 deletions.
56 changes: 33 additions & 23 deletions pythainlp/augment/lm/phayathaibert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,31 @@


class ThaiTextAugmenter:
def __init__(self,) -> None:
from transformers import (AutoTokenizer,
AutoModelForMaskedLM,
pipeline,)
def __init__(self) -> None:
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
pipeline,
)

self.tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(_MODEL_NAME)
self.model = pipeline("fill-mask", tokenizer=self.tokenizer, model=self.model_for_masked_lm)
self.model_for_masked_lm = AutoModelForMaskedLM.from_pretrained(
_MODEL_NAME
)
self.model = pipeline(
"fill-mask",
tokenizer=self.tokenizer,
model=self.model_for_masked_lm,
)
self.processor = ThaiTextProcessor()

def generate(self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False
) -> str:
def generate(
self,
sample_text: str,
word_rank: int,
max_length: int = 3,
sample: bool = False,
) -> str:
sample_txt = sample_text
final_text = ""

Expand All @@ -45,11 +55,9 @@ def generate(self,

return gen_txt

def augment(self,
text: str,
num_augs: int = 3,
sample: bool = False
) -> List[str]:
def augment(
self, text: str, num_augs: int = 3, sample: bool = False
) -> List[str]:
"""
Text augmentation from PhayaThaiBERT
Expand Down Expand Up @@ -84,11 +92,13 @@ def augment(self,
if num_augs <= MAX_NUM_AUGS:
for rank in range(num_augs):
gen_text = self.generate(text, rank, sample=sample)
processed_text = re.sub("<_>", " ", self.processor.preprocess(gen_text))
processed_text = re.sub(
"<_>", " ", self.processor.preprocess(gen_text)
)
augment_list.append(processed_text)
else:
raise ValueError(
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
)

return augment_list

raise ValueError(
f"augmentation of more than {num_augs} is exceeded the default limit: {MAX_NUM_AUGS}"
)
return augment_list
8 changes: 5 additions & 3 deletions pythainlp/augment/lm/wangchanberta.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

from typing import List

from transformers import (
CamembertTokenizer,
pipeline,
Expand Down Expand Up @@ -51,9 +53,9 @@ def generate(self, sentence: str, num_replace_tokens: int = 3):

def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
"""
Text Augment from wangchanberta
Text augmentation from WangchanBERTa
:param str sentence: thai sentence
:param str sentence: Thai sentence
:param int num_replace_tokens: number replace tokens
:return: list of text augment
Expand All @@ -64,7 +66,7 @@ def augment(self, sentence: str, num_replace_tokens: int = 3) -> List[str]:
from pythainlp.augment.lm import Thai2transformersAug
aug=Thai2transformersAug()
aug = Thai2transformersAug()
aug.augment("ช้างมีทั้งหมด 50 ตัว บน")
# output: ['ช้างมีทั้งหมด 50 ตัว บนโลกใบนี้',
Expand Down
1 change: 1 addition & 0 deletions pythainlp/augment/wordnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections import OrderedDict
import itertools
from typing import List

from nltk.corpus import wordnet as wn
from pythainlp.corpus import wordnet
from pythainlp.tokenize import word_tokenize
Expand Down
1 change: 1 addition & 0 deletions pythainlp/khavee/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# SPDX-FileCopyrightText: Copyright 2016-2023 PyThaiNLP Project
# SPDX-License-Identifier: Apache-2.0

__all__ = ["KhaveeVerifier"]

from pythainlp.khavee.core import KhaveeVerifier
Loading

0 comments on commit 6e3fe5d

Please sign in to comment.