diff --git a/ammico/__init__.py b/ammico/__init__.py index 7f1465ff..3a780334 100644 --- a/ammico/__init__.py +++ b/ammico/__init__.py @@ -8,7 +8,7 @@ from ammico.faces import EmotionDetector from ammico.multimodal_search import MultimodalSearch from ammico.summary import SummaryDetector -from ammico.text import TextDetector, PostprocessText +from ammico.text import TextDetector, TextAnalyzer, PostprocessText from ammico.utils import find_files, get_dataframe # Export the version defined in project metadata @@ -23,6 +23,7 @@ "MultimodalSearch", "SummaryDetector", "TextDetector", + "TextAnalyzer", "PostprocessText", "find_files", "get_dataframe", diff --git a/ammico/data/ref/test.csv b/ammico/data/ref/test.csv new file mode 100644 index 00000000..f73b9da1 --- /dev/null +++ b/ammico/data/ref/test.csv @@ -0,0 +1,8 @@ +text, date +this is a test, 05/31/24 +bu bir denemedir, 05/31/24 +dies ist ein Test, 05/31/24 +c'est un test, 05/31/24 +esto es una prueba, 05/31/24 +detta är ett test, 05/31/24 + diff --git a/ammico/notebooks/DemoNotebook_ammico.ipynb b/ammico/notebooks/DemoNotebook_ammico.ipynb index f3767c3c..12bd6ea7 100644 --- a/ammico/notebooks/DemoNotebook_ammico.ipynb +++ b/ammico/notebooks/DemoNotebook_ammico.ipynb @@ -366,6 +366,94 @@ "image_df.to_csv(\"/content/drive/MyDrive/misinformation-data/data_out.csv\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read in a csv file containing text and translating/analysing the text\n", + "\n", + "Instead of extracting text from an image, or to re-process text that was already extracted, it is also possible to provide a `csv` file containing text in its rows.\n", + "Provide the path and name of the csv file with the keyword `csv_path`. The keyword `column_key` tells the Analyzer which column key in the csv file holds the text that should be analyzed. This defaults to \"text\"." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ta = ammico.TextAnalyzer(csv_path=\"../data/ref/test.csv\", column_key=\"text\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# read the csv file\n",
+ "ta.read_csv()\n",
+ "# set up the dict containing all text entries\n",
+ "text_dict = ta.mydict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set the dump file\n",
+ "# dump file name\n",
+ "dump_file = \"dump_file.csv\"\n",
+ "# dump every N images \n",
+ "dump_every = 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# analyze the csv file\n",
+ "for num, key in tqdm(enumerate(text_dict.keys()), total=len(text_dict)): # loop through all text entries\n",
+ " ammico.TextDetector(text_dict[key], analyse_text=True, skip_extraction=True).analyse_image() # analyse text with TextDetector and update dict\n",
+ " if num % dump_every == 0 or num == len(text_dict) - 1: # save results every dump_every to dump_file\n",
+ " image_df = ammico.get_dataframe(text_dict)\n",
+ " image_df.to_csv(dump_file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save the results to a csv file\n",
+ "text_df = ammico.get_dataframe(text_dict)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# inspect\n",
+ "text_df.head(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# write to csv\n",
+ "text_df.to_csv(\"data_out.csv\")"
+ ]
+ },
 {
 "cell_type": "markdown",
 "metadata": {},
diff --git a/ammico/test/data/test-utf16.csv b/ammico/test/data/test-utf16.csv
new file mode 100644
index 00000000..d88dec96
Binary files /dev/null
and b/ammico/test/data/test-utf16.csv differ diff --git a/ammico/test/data/test.csv b/ammico/test/data/test.csv new file mode 100644 index 00000000..f73b9da1 --- /dev/null +++ b/ammico/test/data/test.csv @@ -0,0 +1,8 @@ +text, date +this is a test, 05/31/24 +bu bir denemedir, 05/31/24 +dies ist ein Test, 05/31/24 +c'est un test, 05/31/24 +esto es una prueba, 05/31/24 +detta är ett test, 05/31/24 + diff --git a/ammico/test/data/test_read_csv_ref.json b/ammico/test/data/test_read_csv_ref.json new file mode 100644 index 00000000..fd5b8957 --- /dev/null +++ b/ammico/test/data/test_read_csv_ref.json @@ -0,0 +1,32 @@ +{ + "test.csvrow-1": + { + "filename": "test.csv", + "text": "this is a test" + }, + "test.csvrow-2": + { + "filename": "test.csv", + "text": "bu bir denemedir" + }, + "test.csvrow-3": + { + "filename": "test.csv", + "text": "dies ist ein Test" + }, + "test.csvrow-4": + { + "filename": "test.csv", + "text": "c'est un test" + }, + "test.csvrow-5": + { + "filename": "test.csv", + "text": "esto es una prueba" + }, + "test.csvrow-6": + { + "filename": "test.csv", + "text": "detta är ett test" + } +} \ No newline at end of file diff --git a/ammico/test/test_text.py b/ammico/test/test_text.py index bee570ac..70ccc285 100644 --- a/ammico/test/test_text.py +++ b/ammico/test/test_text.py @@ -1,6 +1,8 @@ import pytest import ammico.text as tt import spacy +import json +import sys @pytest.fixture @@ -25,10 +27,25 @@ def set_testdict(get_path): def test_TextDetector(set_testdict): for item in set_testdict: test_obj = tt.TextDetector(set_testdict[item]) - assert test_obj.subdict["text"] is None - assert test_obj.subdict["text_language"] is None - assert test_obj.subdict["text_english"] is None assert not test_obj.analyse_text + assert not test_obj.skip_extraction + assert test_obj.subdict["filename"] == set_testdict[item]["filename"] + assert test_obj.model_summary == "sshleifer/distilbart-cnn-12-6" + assert ( + test_obj.model_sentiment + == 
"distilbert-base-uncased-finetuned-sst-2-english" + ) + assert test_obj.model_ner == "dbmdz/bert-large-cased-finetuned-conll03-english" + assert test_obj.revision_summary == "a4f8f3e" + assert test_obj.revision_sentiment == "af0f99b" + assert test_obj.revision_ner == "f2482bf" + test_obj = tt.TextDetector({}, analyse_text=True, skip_extraction=True) + assert test_obj.analyse_text + assert test_obj.skip_extraction + with pytest.raises(ValueError): + tt.TextDetector({}, analyse_text=1.0) + with pytest.raises(ValueError): + tt.TextDetector({}, skip_extraction=1.0) def test_run_spacy(set_testdict, get_path): @@ -140,7 +157,6 @@ def test_remove_linebreaks(): assert test_obj.subdict["text_english"] == "This is another test." -@pytest.mark.win_skip def test_text_summary(get_path): mydict = {} test_obj = tt.TextDetector(mydict, analyse_text=True) @@ -162,7 +178,6 @@ def test_text_sentiment_transformers(): assert mydict["sentiment_score"] == pytest.approx(0.99, 0.02) -@pytest.mark.win_skip def test_text_ner(): mydict = {} test_obj = tt.TextDetector(mydict, analyse_text=True) @@ -172,7 +187,51 @@ def test_text_ner(): assert mydict["entity_type"] == ["PER", "LOC"] -@pytest.mark.win_skip +def test_init_csv_option(get_path): + test_obj = tt.TextAnalyzer(csv_path=get_path + "test.csv") + assert test_obj.csv_path == get_path + "test.csv" + assert test_obj.column_key == "text" + assert test_obj.csv_encoding == "utf-8" + test_obj = tt.TextAnalyzer( + csv_path=get_path + "test.csv", column_key="mytext", csv_encoding="utf-16" + ) + assert test_obj.column_key == "mytext" + assert test_obj.csv_encoding == "utf-16" + with pytest.raises(ValueError): + tt.TextAnalyzer(csv_path=1.0) + with pytest.raises(ValueError): + tt.TextAnalyzer(csv_path="something") + with pytest.raises(FileNotFoundError): + tt.TextAnalyzer(csv_path=get_path + "test_no.csv") + with pytest.raises(ValueError): + tt.TextAnalyzer(csv_path=get_path + "test.csv", column_key=1.0) + with pytest.raises(ValueError): + 
tt.TextAnalyzer(csv_path=get_path + "test.csv", csv_encoding=1.0)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="Encoding different on Windows")
+def test_read_csv(get_path):
+ test_obj = tt.TextAnalyzer(csv_path=get_path + "test.csv")
+ test_obj.read_csv()
+ with open(get_path + "test_read_csv_ref.json", "r") as file:
+ ref_dict = json.load(file)
+ # we are assuming the order did not get jumbled up
+ for (_, value_test), (_, value_ref) in zip(
+ test_obj.mydict.items(), ref_dict.items()
+ ):
+ assert value_test["text"] == value_ref["text"]
+ # test with different encoding
+ test_obj = tt.TextAnalyzer(
+ csv_path=get_path + "test-utf16.csv", csv_encoding="utf-16"
+ )
+ test_obj.read_csv()
+ # we are assuming the order did not get jumbled up
+ for (_, value_test), (_, value_ref) in zip(
+ test_obj.mydict.items(), ref_dict.items()
+ ):
+ assert value_test["text"] == value_ref["text"]
+
+
def test_PostprocessText(set_testdict, get_path):
reference_dict = "THE\nALGEBRAIC\nEIGENVALUE\nPROBLEM\nDOM\nNVS TIO\nMINA\nMonographs\non Numerical Analysis\nJ.. H. WILKINSON"
reference_df = "Mathematische Formelsammlung\nfür Ingenieure und Naturwissenschaftler\nMit zahlreichen Abbildungen und Rechenbeispielen\nund einer ausführlichen Integraltafel\n3., verbesserte Auflage"
diff --git a/ammico/text.py b/ammico/text.py
index d917184d..5893a566 100644
--- a/ammico/text.py
+++ b/ammico/text.py
@@ -15,6 +15,7 @@
def __init__(
self,
subdict: dict,
analyse_text: bool = False,
+ skip_extraction: bool = False,
model_names: list = None,
revision_numbers: list = None,
) -> None:
@@ -25,6 +26,8 @@
analysis results from other modules.
analyse_text (bool, optional): Decide if extracted text will be further
subject to analysis. Defaults to False.
+ skip_extraction (bool, optional): Decide if text will be extracted from images or
+ is already provided via a csv. Defaults to False.
model_names (list, optional): Provide model names for summary, sentiment and ner
analysis.
Defaults to None, in which case the default model from transformers are used (as of 03/2023): "sshleifer/distilbart-cnn-12-6" (summary), @@ -40,11 +43,21 @@ def __init__( "f2482bf" (NER, bert). """ super().__init__(subdict) - self.subdict.update(self.set_keys()) + # disable this for now + # maybe it would be better to initialize the keys differently + # the reason is that they are inconsistent depending on the selected + # options, and also this may not be really necessary and rather restrictive + # self.subdict.update(self.set_keys()) self.translator = Translator() if not isinstance(analyse_text, bool): raise ValueError("analyse_text needs to be set to true or false") self.analyse_text = analyse_text + self.skip_extraction = skip_extraction + if not isinstance(skip_extraction, bool): + raise ValueError("skip_extraction needs to be set to true or false") + if self.skip_extraction: + print("Skipping text extraction from image.") + print("Reading text directly from provided dictionary.") if self.analyse_text: self._initialize_spacy() if model_names: @@ -155,7 +168,8 @@ def analyse_image(self) -> dict: Returns: dict: The updated dictionary with text analysis results. """ - self.get_text_from_image() + if not self.skip_extraction: + self.get_text_from_image() self.translate_text() self.remove_linebreaks() if self.analyse_text: @@ -287,18 +301,32 @@ def text_ner(self): class TextAnalyzer: """Used to get text from a csv and then run the TextDetector on it.""" - def __init__(self, csv_path: str, column_key: str = None) -> None: + def __init__( + self, csv_path: str, column_key: str = None, csv_encoding: str = "utf-8" + ) -> None: """Init the TextTranslator class. Args: csv_path (str): Path to the CSV file containing the text entries. column_key (str): Key for the column containing the text entries. Defaults to None. + csv_encoding (str): Encoding of the CSV file. Defaults to "utf-8". 
""" self.csv_path = csv_path self.column_key = column_key + self.csv_encoding = csv_encoding self._check_valid_csv_path() self._check_file_exists() + if not self.column_key: + print("No column key provided - using 'text' as default.") + self.column_key = "text" + if not self.csv_encoding: + print("No encoding provided - using 'utf-8' as default.") + self.csv_encoding = "utf-8" + if not isinstance(self.column_key, str): + raise ValueError("The provided column key is not a string.") + if not isinstance(self.csv_encoding, str): + raise ValueError("The provided encoding is not a string.") def _check_valid_csv_path(self): if not isinstance(self.csv_path, str): @@ -319,9 +347,7 @@ def read_csv(self) -> dict: Returns: dict: The dictionary with the text entries. """ - df = pd.read_csv(self.csv_path, encoding="utf8") - if not self.column_key: - self.column_key = "text" + df = pd.read_csv(self.csv_path, encoding=self.csv_encoding) if self.column_key not in df: raise ValueError( diff --git a/docs/source/notebooks/DemoNotebook_ammico.ipynb b/docs/source/notebooks/DemoNotebook_ammico.ipynb index 9a0b06da..292a93d5 100644 --- a/docs/source/notebooks/DemoNotebook_ammico.ipynb +++ b/docs/source/notebooks/DemoNotebook_ammico.ipynb @@ -94,7 +94,10 @@ "import os\n", "import ammico\n", "# for displaying a progress bar\n", - "from tqdm import tqdm" + "from tqdm import tqdm\n", + "# to get the reference data for text_dict\n", + "import importlib_resources\n", + "pkg = importlib_resources.files(\"ammico\")" ] }, { @@ -363,6 +366,95 @@ "image_df.to_csv(\"/content/drive/MyDrive/misinformation-data/data_out.csv\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read in a csv file containing text and translating/analysing the text\n", + "\n", + "Instead of extracting text from an image, or to re-process text that was already extracted, it is also possible to provide a `csv` file containing text in its rows.\n", + "Provide the path and name of the csv file with the 
keyword `csv_path`. The keyword `column_key` tells the Analyzer which column key in the csv file holds the text that should be analyzed. This defaults to \"text\"."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "csv_path = pkg / \"data\" / \"ref\" / \"test.csv\"\n",
+ "ta = ammico.TextAnalyzer(csv_path=str(csv_path), column_key=\"text\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# read the csv file\n",
+ "ta.read_csv()\n",
+ "# set up the dict containing all text entries\n",
+ "text_dict = ta.mydict"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# set the dump file\n",
+ "# dump file name\n",
+ "dump_file = \"dump_file.csv\"\n",
+ "# dump every N images \n",
+ "dump_every = 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# analyze the csv file\n",
+ "for num, key in tqdm(enumerate(text_dict.keys()), total=len(text_dict)): # loop through all text entries\n",
+ " ammico.TextDetector(text_dict[key], analyse_text=True, skip_extraction=True).analyse_image() # analyse text with TextDetector and update dict\n",
+ " if num % dump_every == 0 or num == len(text_dict) - 1: # save results every dump_every to dump_file\n",
+ " image_df = ammico.get_dataframe(text_dict)\n",
+ " image_df.to_csv(dump_file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# save the results to a csv file\n",
+ "text_df = ammico.get_dataframe(text_dict)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# inspect\n",
+ "text_df.head(3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# write to csv\n",
+
"text_df.to_csv(\"data_out.csv\")" + ] + }, { "cell_type": "markdown", "metadata": {},