Commit

remove some redundant code
peteryangms committed Jun 5, 2024
1 parent 2545277 · commit 268e718
Showing 2 changed files with 18 additions and 106 deletions.
@@ -3,8 +3,8 @@
from pathlib import Path

from rdagent.document_process.document_analysis import (
- filter_factor_by_viability,
- deduplicate_factors_several_times,
+ check_factor_viability,
+ deduplicate_factors_by_llm,
extract_factors_from_report_dict,
merge_file_to_factor_dict_to_factor_dict,
)
@@ -21,9 +21,9 @@ def extract_factors_and_implement(report_file_path: str):
file_to_factor_result = extract_factors_from_report_dict(docs_dict, selected_report_dict)
factor_dict = merge_file_to_factor_dict_to_factor_dict(file_to_factor_result)

- factor_dict_viable, factor_viability = filter_factor_by_viability(factor_dict)
+ factor_viability = check_factor_viability(factor_dict)

- factor_dict, duplication_names_list = deduplicate_factors_several_times(factor_dict, factor_viability)
+ factor_dict, duplication_names_list = deduplicate_factors_by_llm(factor_dict, factor_viability)


if __name__ == "__main__":
rdagent/document_process/document_analysis.py (116 changes: 14 additions & 102 deletions)
@@ -29,94 +29,6 @@
document_process_prompts = Prompts(file_path=Path(__file__).parent / "prompts.yaml")


def load_documents_by_langchain(path: Path) -> list:
"""Load documents from the specified path.
Args:
path (str): The path to the directory or file containing the documents.
Returns:
list: A list of loaded documents.
"""
loader = PyPDFDirectoryLoader(str(path), silent_errors=True) if path.is_dir() else PyPDFLoader(str(path))
return loader.load()


def process_documents_by_langchain(docs: list[Document]) -> dict[str, str]:
"""Process a list of documents and group them by document name.
Args:
docs (list): A list of documents.
Returns:
dict: A dictionary where the keys are document names and the values are
the concatenated content of the documents.
"""
content_dict = {}

for doc in docs:
doc_name = str(Path(doc.metadata["source"]).resolve())
doc_content = doc.page_content

if doc_name not in content_dict:
content_dict[str(doc_name)] = doc_content
else:
content_dict[str(doc_name)] += doc_content

return content_dict


def load_and_process_pdfs_by_langchain(path: Path) -> dict[str, str]:
return process_documents_by_langchain(load_documents_by_langchain(path))


def load_and_process_one_pdf_by_azure_document_intelligence(
path: Path,
key: str,
endpoint: str,
) -> str:
pages = len(PyPDFLoader(str(path)).load())
document_analysis_client = DocumentAnalysisClient(
endpoint=endpoint,
credential=AzureKeyCredential(key),
)

with path.open("rb") as file:
result = document_analysis_client.begin_analyze_document(
"prebuilt-document",
file,
pages=f"1-{pages}",
).result()
return result.content


def load_and_process_pdfs_by_azure_document_intelligence(path: Path) -> dict[str, str]:
config = Config()

assert config.azure_document_intelligence_key is not None
assert config.azure_document_intelligence_endpoint is not None

content_dict = {}
ab_path = path.resolve()
if ab_path.is_file():
assert ".pdf" in ab_path.suffixes, "The file must be a PDF file."
proc = load_and_process_one_pdf_by_azure_document_intelligence
content_dict[str(ab_path)] = proc(
ab_path,
config.azure_document_intelligence_key,
config.azure_document_intelligence_endpoint,
)
else:
for file_path in ab_path.rglob("*"):
if file_path.is_file() and ".pdf" in file_path.suffixes:
content_dict[str(file_path)] = load_and_process_one_pdf_by_azure_document_intelligence(
file_path,
config.azure_document_intelligence_key,
config.azure_document_intelligence_endpoint,
)
return content_dict


def classify_report_from_dict(
report_dict: Mapping[str, str],
input_max_token: int = 128000,
@@ -411,7 +323,7 @@ def __check_factor_dict_viability_simulate_json_mode(
return {}


- def filter_factor_by_viability(
+ def check_factor_viability(
factor_dict: dict[str, dict[str, str]],
) -> tuple[dict[str, dict[str, str]], dict[str, dict[str, str]]]:
factor_viability_dict = {}
@@ -443,16 +355,16 @@ def filter_factor_by_viability(

factor_df = factor_df[~factor_df.index.isin(factor_viability_dict)]

- filtered_factor_dict = {
- factor_name: factor_dict[factor_name]
- for factor_name in factor_dict
- if factor_viability_dict[factor_name]["viability"]
- }
+ # filtered_factor_dict = {
+ # factor_name: factor_dict[factor_name]
+ # for factor_name in factor_dict
+ # if factor_viability_dict[factor_name]["viability"]
+ # }

- return filtered_factor_dict, factor_viability_dict
+ return factor_viability_dict


- def check_factor_duplication_simulate_json_mode(
+ def __check_factor_duplication_simulate_json_mode(
factor_df: pd.DataFrame,
) -> list[list[str]]:
session = APIBackend().build_chat_session(
@@ -491,7 +403,7 @@ def check_factor_duplication_simulate_json_mode(
return generated_duplicated_groups


- def kmeans_embeddings(embeddings: np.ndarray, k: int = 20) -> list[list[str]]:
+ def __kmeans_embeddings(embeddings: np.ndarray, k: int = 20) -> list[list[str]]:
x_normalized = normalize(embeddings)

kmeans = KMeans(
@@ -545,7 +457,7 @@ def find_closest_cluster_cosine_similarity(
)


- def deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[list[str]]:
+ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[list[str]]:
factor_df = pd.DataFrame(factor_dict).T
factor_df.index.names = ["factor_name"]

@@ -576,7 +488,7 @@ def deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[list
len(full_str_list) // Config().max_input_duplicate_factor_group,
30,
):
- kmeans_index_group = kmeans_embeddings(embeddings=embeddings, k=k)
+ kmeans_index_group = __kmeans_embeddings(embeddings=embeddings, k=k)
if len(kmeans_index_group[0]) < Config().max_input_duplicate_factor_group:
target_k = k
FinCoLog().info(f"K-means group number: {k}")
@@ -589,7 +501,7 @@ def deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[list
result_list = []
result_list = [
pool.apply_async(
- check_factor_duplication_simulate_json_mode,
+ __check_factor_duplication_simulate_json_mode,
(factor_df.loc[factor_name_group, :],),
)
for factor_name_group in factor_name_groups
@@ -610,14 +522,14 @@ def deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[list
return duplication_names_list


- def deduplicate_factors_several_times(
+ def deduplicate_factors_by_llm(
factor_dict: dict[str, dict[str, str]],
factor_viability_dict: dict[str, dict[str, str]] = None,
) -> list[list[str]]:
final_duplication_names_list = []
current_round_factor_dict = factor_dict
for _ in range(10):
- duplication_names_list = deduplicate_factor_dict(current_round_factor_dict)
+ duplication_names_list = __deduplicate_factor_dict(current_round_factor_dict)

new_round_names = []
for duplication_names in duplication_names_list: