Commit a160b00
bugfix in fit_predict related to preprocessing
fonhorst committed Nov 26, 2023
1 parent 505a5ce commit a160b00
Showing 3 changed files with 25 additions and 16 deletions.
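
The gist of the fix, as the diff below shows: previously, fit_predict preprocessed the input dataset twice, once inside fit and once again inside predict. With this commit, fit accepts an optional processed_dataset_path for its intermediate artifacts, predict skips preprocessing whenever the input already contains the processed-text column, and fit_predict feeds the dataset preprocessed during fit directly into predict.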
36 changes: 22 additions & 14 deletions autotm/base.py
@@ -21,7 +21,7 @@
 from autotm.fitness.tm import extract_topics, print_topics
 from autotm.infer import TopicsExtractor
 from autotm.preprocessing.dictionaries_preparation import prepare_all_artifacts
-from autotm.preprocessing.text_preprocessing import process_dataset
+from autotm.preprocessing.text_preprocessing import process_dataset, PROCESSED_TEXT_COLUMN

 logger = logging.getLogger(__name__)

@@ -109,7 +109,7 @@ def __init__(self,
         self.exp_dataset_name = exp_dataset_name
         self._model: Optional[artm.ARTM] = None

-    def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
+    def fit(self, dataset: Union[pd.DataFrame, pd.Series], processed_dataset_path: Optional[str] = None) -> 'AutoTM':
         """
         Preprocess texts in the datasets, looks for the best hyperparameters for ARTM model and fits the model
         with these parameters. The instance will contain topics with the most probable words belonging to them.
@@ -124,11 +124,12 @@ def fit(self, dataset: Union[pd.DataFrame, pd.Series]) -> 'AutoTM':
         -------
         self : object
             Fitted Estimator.
+        :param processed_dataset_path: optional path where to write intermediate processed dataset
         """
         self._check_if_already_fitted(fit_is_ok=False)

-        processed_dataset_path = os.path.join(self.working_dir_path, f"{uuid.uuid4()}")
+        processed_dataset_path = processed_dataset_path or os.path.join(self.working_dir_path, f"{uuid.uuid4()}")

         logger.info(f"Stage 0: Create working dir {self.working_dir_path} if not exists")

@@ -196,16 +197,20 @@ def predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
         os.makedirs(self.working_dir_path, exist_ok=True)

         with tempfile.TemporaryDirectory(dir=self.working_dir_path) as extractor_working_dir:
-            processed_dataset_path = os.path.join(extractor_working_dir, "preprocessed_dataset.csv")
-            process_dataset(
-                dataset,
-                self.texts_column_name,
-                processed_dataset_path,
-                **self.preprocessing_params
-            )
-            processed_dataset = pd.read_csv(processed_dataset_path)
+            if PROCESSED_TEXT_COLUMN not in dataset.columns:
+                process_dataset(
+                    dataset,
+                    self.texts_column_name,
+                    extractor_working_dir,
+                    **self.preprocessing_params
+                )
+                preprocessed_dataset = pd.read_csv(os.path.join(extractor_working_dir, "prep_df.csv"))
+            else:
+                preprocessed_dataset = dataset
             topics_extractor = TopicsExtractor(self._model)
-            mixtures = topics_extractor.get_prob_mixture(dataset=processed_dataset, working_dir=extractor_working_dir)
+            mixtures = topics_extractor.get_prob_mixture(
+                dataset=preprocessed_dataset, working_dir=extractor_working_dir
+            )

         return mixtures

@@ -229,8 +234,11 @@ def fit_predict(self, dataset: Union[pd.DataFrame, pd.Series]) -> pd.DataFrame:
         """
         self._check_if_already_fitted(fit_is_ok=False)

-        self.fit(dataset)
-        return self.predict(dataset)
+        processed_dataset_path = os.path.join(self.working_dir_path, f"{uuid.uuid4()}")
+        self.fit(dataset, processed_dataset_path=processed_dataset_path)
+
+        preprocessed_dataset = pd.read_csv(os.path.join(processed_dataset_path, "prep_df.csv"))
+        return self.predict(preprocessed_dataset)

     def save(self, path: str, overwrite: bool = False):
         """
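Taken together, the base.py changes allow a round trip like the sketch below. This is a minimal, hypothetical usage sketch: the import path autotm.base and the AutoTM constructor arguments are assumptions, while working_dir_path, prep_df.csv, and the processed_text column come from the diff itself.

import os
import uuid

import pandas as pd

from autotm.base import AutoTM  # assumed import path; the class is defined in autotm/base.py

# Hypothetical setup; constructor arguments beyond working_dir_path are illustrative.
autotm = AutoTM(working_dir_path="./autotm_workdir")
df = pd.DataFrame({"text": ["first document", "second document"]})

# fit() now takes an optional destination for the intermediate preprocessed dataset.
processed_dataset_path = os.path.join(autotm.working_dir_path, f"{uuid.uuid4()}")
autotm.fit(df, processed_dataset_path=processed_dataset_path)

# predict() now accepts an already-preprocessed frame: it finds the
# "processed_text" column and skips re-running preprocessing.
preprocessed_df = pd.read_csv(os.path.join(processed_dataset_path, "prep_df.csv"))
mixtures = autotm.predict(preprocessed_df)

This mirrors what the new fit_predict does internally.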
3 changes: 2 additions & 1 deletion autotm/preprocessing/text_preprocessing.py
@@ -14,6 +14,7 @@
 from spacy_langdetect import LanguageDetector
 from nltk.stem import WordNetLemmatizer

+PROCESSED_TEXT_COLUMN = "processed_text"

 # TODO: make transformer class and prep function to download all files

@@ -169,7 +170,7 @@ def process_dataset(
     data = parallelize_dataframe(
         data, lemmatize_text, n_cores, lang=lang, col_to_process=col_to_process
     )
-    data["tokens_len"] = data["processed_text"].apply(tokens_num)
+    data["tokens_len"] = data[PROCESSED_TEXT_COLUMN].apply(tokens_num)
     data = data[data["tokens_len"] > min_tokens_count]
     data.to_csv(save_path, index=None)
     print("Saved to {}".format(save_path))
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "autotm"
-version = "0.2.1"
+version = "0.2.1.1"
 description = "Automatic hyperparameters tuning for topic models (ARTM approach) using evolutionary algorithms"
 authors = [
     "Khodorchenko Maria <[email protected]>",
