-
Notifications
You must be signed in to change notification settings - Fork 20
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move the translation wrapper to its own file and add tests
- Loading branch information
Showing
5 changed files
with
112 additions
and
67 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from math import floor | ||
|
||
from nltk.tokenize import sent_tokenize | ||
from transformers.pipelines import pipeline | ||
|
||
|
||
class Translate:
    '''Translation class based on HuggingFace's translation pipeline.'''

    def __init__(self, model_name: str) -> None:
        '''
        Initialize the Translate class with the given parameters.

        Args:
            model_name: The name of the model to use for translation
        '''
        self._translation_pipeline = pipeline("translation",
                                              model=model_name,
                                              tokenizer=model_name)
        # Model-specific cap on the number of tokens handled per call.
        self._max_length = self._translation_pipeline.model.config.max_length

    def _translate(self, texts: str) -> str:
        '''Translate the texts using the translation pipeline.

        It splits the texts into blocks and translates each block separately,
        avoiding problems with long texts that exceed the model's max length.

        Args:
            texts: The texts to translate

        Returns:
            The translated texts
        '''
        tokenization = self._translation_pipeline.tokenizer(
            texts, return_tensors="pt")  # type: ignore
        if tokenization.input_ids.shape[1] > (self._max_length / 2):
            # Input is too long for a single pass: split it into sentence
            # groups. "+ 3" adds slack so each block stays under the cap.
            blocks = floor(
                tokenization.input_ids.shape[1] / self._max_length) + 3
            sentences = sent_tokenize(texts)
            # Split sentences into `blocks` groups of roughly equal size,
            # e.g., 2 blocks = 2 groups.
            len_block = floor(len(sentences) / blocks) + 1
            sentences_list = []
            for i in range(blocks):
                sentences_list.append(sentences[i * len_block:(i + 1) *
                                                len_block])
            texts_ = [" ".join(sent) for sent in sentences_list]
        else:
            texts_ = [texts]
        texts_en = []
        for text in texts_:
            # NOTE: a leftover debug `print(text)` was removed here.
            text_en = [
                str(d['translation_text'])  # type: ignore
                for d in self._translation_pipeline(text)  # type: ignore
            ]
            texts_en.append(" ".join(text_en))
        return " ".join(texts_en)

    def __call__(self, text: str) -> str:
        '''Translate the text using the translation pipeline.

        Args:
            text: The text to translate

        Returns:
            The translated text
        '''
        return self._translate(text)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from typing import List | ||
|
||
import pytest | ||
|
||
from langcheck.metrics.de import Translate | ||
|
||
|
||
@pytest.mark.parametrize(
    'de_text,en_text',
    [
        # Tuple pairs (de_text, en_text), matching the style of
        # test_translate_en_de below.
        ('Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.',  # noqa: E501
         'I have no personal opinions, emotions or consciousness.'),
        ('Mein Freund. Willkommen in den Karpaten.',
         'My friend, welcome to the Carpathians.'),
        ('Tokio ist die Hauptstadt von Japan.',
         'Tokyo is the capital of Japan.'),
    ])
def test_translate_de_en(de_text: str, en_text: str) -> None:
    '''Check German-to-English translation against fixed expected outputs.'''
    translation = Translate('Helsinki-NLP/opus-mt-de-en')
    assert translation(de_text) == en_text
|
||
|
||
@pytest.mark.parametrize('en_text,de_text', [
    ('I have no personal opinions, emotions or consciousness.',
     'Ich habe keine persönlichen Meinungen, Emotionen oder Bewusstsein.'),
    ('My Friend. Welcome to the Carpathians. I am anxiously expecting you.',
     'Willkommen bei den Karpaten, ich erwarte Sie.'),
    ('Tokyo is the capital of Japan.', 'Tokio ist die Hauptstadt Japans.'),
])
def test_translate_en_de(en_text: str, de_text: str) -> None:
    '''Check English-to-German translation against fixed expected outputs.

    Note: `de_text` is a single string per parametrized case, so it is
    annotated as `str` (was incorrectly `List[str]`).
    '''
    translation = Translate('Helsinki-NLP/opus-mt-en-de')
    assert translation(en_text) == de_text