From 02eef225fe0510cf745a6dc20ccbf7ccb513dd6b Mon Sep 17 00:00:00 2001 From: fonhorst Date: Fri, 7 Jun 2024 21:29:06 +0300 Subject: [PATCH] refactoring --- autotm/clustering.py | 17 +++++++---------- autotm/content_splitter.py | 5 ++--- autotm/main_fitness_worker.py | 2 +- autotm/params.py | 2 -- autotm/utils.py | 2 ++ tests/unit/conftest.py | 2 +- 6 files changed, 13 insertions(+), 17 deletions(-) diff --git a/autotm/clustering.py b/autotm/clustering.py index 8c245ad..3b3bb4e 100644 --- a/autotm/clustering.py +++ b/autotm/clustering.py @@ -1,21 +1,18 @@ -from matplotlib import offsetbox import matplotlib.pyplot as plt -import matplotlib.patheffects as PathEffects -import seaborn as sns -import pandas as pd +import copy +import warnings + +import matplotlib.pyplot as plt import numpy as np -import plotly.graph_objects as go +import pandas as pd from sklearn.cluster import KMeans -from sklearn.preprocessing import StandardScaler from sklearn.manifold import TSNE -import warnings -import copy - +from sklearn.preprocessing import StandardScaler warnings.filterwarnings('ignore') -def cluster_phi(phi_df: pd.dataFrame, n_clusters=10, plot_img=True): +def cluster_phi(phi_df: pd.DataFrame, n_clusters=10, plot_img=True): _phi_df = copy.deepcopy(phi_df) y = _phi_df.index.values x = _phi_df.values diff --git a/autotm/content_splitter.py b/autotm/content_splitter.py index 5df8778..505be19 100644 --- a/autotm/content_splitter.py +++ b/autotm/content_splitter.py @@ -1,6 +1,5 @@ -import re -from abc import ABC, abstractmethod -from autotm.preprocessing.text_preprocessing import process_dataset +from abc import ABC + class BaseTextSplitter(ABC): diff --git a/autotm/main_fitness_worker.py b/autotm/main_fitness_worker.py index 6a7568c..d21492e 100644 --- a/autotm/main_fitness_worker.py +++ b/autotm/main_fitness_worker.py @@ -7,7 +7,7 @@ def main(): os.environ['AUTOTM_COMPONENT'] = 'worker' os.environ['AUTOTM_EXEC_MODE'] = 'cluster' - from autotm.fitness.tasks import make_celery_app + from autotm.fitness.cluster_tasks import make_celery_app from autotm.fitness.tm import TopicModelFactory if "DATASETS_CONFIG" in os.environ: diff --git a/autotm/params.py b/autotm/params.py index a4568b1..9df1151 100644 --- a/autotm/params.py +++ b/autotm/params.py @@ -1,5 +1,4 @@ import copy -import logging import random from typing import List @@ -11,7 +10,6 @@ from autotm.graph_ga import create_pipeline, crossover_pipelines, mutate_pipeline from autotm.pipeline import Pipeline, Stage, StageType, Param, create_stage, IntRangeDistribution, \ FloatRangeDistribution -from autotm.utils import do_suppress_stdout PARAM_NAMES = [ "val_decor", diff --git a/autotm/utils.py b/autotm/utils.py index dea25a5..e2ec7a6 100644 --- a/autotm/utils.py +++ b/autotm/utils.py @@ -102,6 +102,8 @@ def parallelize_dataframe(df: pd.DataFrame, func, n_cores, return_type="df", **k elif isinstance(map_res[0], tuple): zipped_elems = list(zip(*map_res)) res = (merge_dicts(zipped_elems[0]), merge_dicts(zipped_elems[1])) + else: + raise ValueError(f"Unsupported return_type: {return_type}") pool.close() pool.join() return res diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3c4f0c0..d334693 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,6 +1,6 @@ import itertools import os.path -from typing import List, Dict, Tuple +from typing import Dict import pytest