From 59123bfd94a295921d4ef1a5f03e43e39d3edd40 Mon Sep 17 00:00:00 2001 From: XianBW <36835909+XianBW@users.noreply.github.com> Date: Wed, 12 Jun 2024 15:12:11 +0800 Subject: [PATCH] Ci fix (#22) * fix replace function of CI tool * fix ruff errors (ignore some parts) * add ruff rule ignore comment --- pyproject.toml | 1 + rdagent/app/CI/run.py | 113 ++++--- .../factor_extract_and_implement.py | 7 +- rdagent/core/conf.py | 19 +- rdagent/core/log.py | 14 +- rdagent/core/prompts.py | 16 +- rdagent/core/utils.py | 59 ++-- rdagent/document_process/document_analysis.py | 18 +- rdagent/document_process/document_reader.py | 10 +- .../evolving/evaluators.py | 3 +- .../evolving/evolving_strategy.py | 13 +- .../factor_implementation_evolving_cli.py | 2 +- .../evolving/knowledge_management.py | 9 +- .../share_modules/evaluator.py | 14 +- .../share_modules/factor.py | 12 +- .../factor_implementation_utils.py | 1 - rdagent/knowledge_management/graph.py | 216 ++++++------ rdagent/oai/llm_utils.py | 312 +++++++++--------- test/oai/test_completion.py | 29 +- test/oai/test_embedding_and_similarity.py | 16 +- 20 files changed, 449 insertions(+), 435 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9f7a46b4..40ced362 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ ignore = [ "PGH", "PLR0913", "S101", + "S301", "T20", "TCH003", "TD", diff --git a/rdagent/app/CI/run.py b/rdagent/app/CI/run.py index b3667373..e6e515c5 100644 --- a/rdagent/app/CI/run.py +++ b/rdagent/app/CI/run.py @@ -3,6 +3,7 @@ import datetime import json import re +import shlex import subprocess import time from collections import defaultdict @@ -122,20 +123,23 @@ def load(self) -> None: self.lineno_width = len(str(self.lineno)) self.code_lines_with_lineno = self.add_line_number(self.code_lines) - def get(self, start: int = 1, end: int | None = None, add_line_number: bool = False, return_list: bool = False) -> list[str] | str: + def get( + self, start: int = 1, end: int | None = None, *, + add_line_number: bool = False, return_list: bool = False, + ) -> list[str] | str: """ Retrieves a portion of the code lines. Line numbers start from 1, and the returned range [start, end] is inclusive. Args: start (int): The starting line number (inclusive). Defaults to 1. - end (int): The ending line number (inclusive). Defaults to None, which means the last line. + end (int | None): The ending line number (inclusive). Defaults to None, which means the last line. add_line_number (bool): Whether to include line numbers in the result. Defaults to False. return_list (bool): Whether to return the result as a list of lines or as a single string. Defaults to False. Returns: - Union[List[str], str]: The code lines as a list of strings or as a + list[str] | str: The code lines as a list of strings or as a single string, depending on the value of `return_list`.
""" start -= 1 @@ -161,11 +165,12 @@ def apply_changes(self, changes: list[tuple[int, int, str]]) -> None: """ offset = 0 for start, end, code in changes: + # starts from 1 --> starts from 0 adjusted_start = max(start - 1, 0) new_code = code.split("\n") - self.code_lines[adjusted_start+offset:end+offset] = new_code - offset += len(new_code) - (end - start) + self.code_lines[adjusted_start + offset : end + offset] = new_code + offset += len(new_code) - (end - adjusted_start) self.path.write_text("\n".join(self.code_lines), encoding="utf-8") self.load() @@ -207,15 +212,17 @@ def __str__(self) -> str: class Repo(EvolvableSubjects): - def __init__(self, project_path: Path | str, excludes: list[Path] = [], **kwargs: Any) -> None: + def __init__(self, project_path: Path | str, excludes: list[Path] | None = None, **kwargs: Any) -> None: + if excludes is None: + excludes = [] self.params = kwargs self.project_path = Path(project_path) excludes = [self.project_path / path for path in excludes] git_ignored_output = subprocess.check_output( - ["git", "status", "--ignored", "-s"], - cwd=project_path, + ["/usr/bin/git", "status", "--ignored", "-s"], # noqa: S603 + cwd=str(self.project_path), stderr=subprocess.STDOUT, text=True, ) @@ -283,7 +290,7 @@ def explain_rule(error_code: str) -> RuffRule: explain_command = f"ruff rule {error_code} --output-format json" try: out = subprocess.check_output( - explain_command, + shlex.split(explain_command), # noqa: S603 stderr=subprocess.STDOUT, text=True, ) @@ -293,11 +300,11 @@ def explain_rule(error_code: str) -> RuffRule: return RuffRule(**json.loads(out)) - def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: + def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: # noqa: ARG002 """Simply run ruff to get the feedbacks.""" try: out = subprocess.check_output( - self.command.split(), + shlex.split(self.command), # noqa: S603 cwd=evo.project_path, stderr=subprocess.STDOUT, text=True, @@ -349,10 +356,10 @@ def __init__(self, command: str | None = None) -> None: else: self.command = command - def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: + def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: # noqa: ARG002 try: out = subprocess.check_output( - self.command.split(), + shlex.split(self.command), # noqa: S603 cwd=evo.project_path, stderr=subprocess.STDOUT, text=True, @@ -368,9 +375,11 @@ def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: for match in re.findall(pattern, out, re.DOTALL): raw_str, file_path, line_number, column_number, error_message, error_code, error_hint = match error_message = error_message.strip().replace("\n", " ") - if re.match(r".*[^\n]*?:\d+:\d+: note:.*", error_hint, re.DOTALL) is not None: - error_hint_position = re.split(r"[^\n]*?:\d+:\d+: note:", error_hint, re.DOTALL)[0] - error_hint_help = re.findall(r"^.*?:\d+:\d+: note: (.*)$", error_hint, re.MULTILINE) + if re.match(r".*[^\n]*?:\d+:\d+: note:.*", error_hint, flags=re.DOTALL) is not None: + error_hint_position = re.split( + pattern=r"[^\n]*?:\d+:\d+: note:", string=error_hint, maxsplit=1, flags=re.DOTALL, + )[0] + error_hint_help = re.findall(r"^.*?:\d+:\d+: note: (.*)$", error_hint, flags=re.MULTILINE) error_hint_help = "\n".join(error_hint_help) error_hint = f"{error_hint_position}\nHelp:\n{error_hint_help}" @@ -410,12 +419,12 @@ def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: return CIFeedback(errors=all_errors) class CIEvoStr(EvolvingStrategy): - def evolve( + def evolve( # noqa: C901, PLR0912, PLR0915 self, evo: Repo, 
evolving_trace: list[EvoStep] | None = None, - knowledge_l: list[Knowledge] | None = None, - **kwargs: Any, + knowledge_l: list[Knowledge] | None = None, # noqa: ARG002 + **kwargs: Any, # noqa: ARG002 ) -> Repo: @dataclass @@ -433,11 +442,22 @@ class CodeFixGroup: last_feedback: CIFeedback = evolving_trace[-1].feedback # print statistics - checker_error_counts = {checker: sum(c_statistics.values()) for checker, c_statistics in last_feedback.statistics().items()} - print(f"Found [red]{sum(checker_error_counts.values())}[/red] errors, including: " + - ", ".join(f"[red]{count}[/red] [magenta]{checker}[/magenta] errors" for checker, count in checker_error_counts.items())) + checker_error_counts = { + checker: sum(c_statistics.values()) + for checker, c_statistics in last_feedback.statistics().items() + } + print( + f"Found [red]{sum(checker_error_counts.values())}[/red] errors, " + "including: " + + ", ".join( + f"[red]{count}[/red] [magenta]{checker}[/magenta] errors" + for checker, count in checker_error_counts.items() + ), + ) - fix_records: dict[str, FixRecord] = defaultdict(lambda: FixRecord([], [], [], defaultdict(list))) + fix_records: dict[str, FixRecord] = defaultdict( + lambda: FixRecord([], [], [], defaultdict(list)), + ) # Group errors by code blocks fix_groups: dict[str, list[CodeFixGroup]] = defaultdict(list) @@ -450,7 +470,7 @@ class CodeFixGroup: # TODO @bowen: current way of handling errors like 'Add import statement' may be not good for error in errors: if error.code in ("FA100", "FA102"): - changes[file_path].append((0, 0, "from __future__ import annotations\n")) + changes[file_path].append((1, 1, "from __future__ import annotations\n")) break # Group errors by code blocks @@ -485,10 +505,12 @@ class CodeFixGroup: for file_path in fix_groups: file = evo.files[evo.project_path / Path(file_path)] for code_fix_g in fix_groups[file_path]: - start_line, end_line, group_errors = code_fix_g.start_line, code_fix_g.end_line, code_fix_g.errors + start_line = code_fix_g.start_line + end_line = code_fix_g.end_line + group_errors = code_fix_g.errors code_snippet_with_lineno = file.get( - start_line, end_line, add_line_number=True, return_list=False, - ) + start_line, end_line, add_line_number=True, return_list=False, + ) errors_str = "\n\n".join(str(e) for e in group_errors) # ask LLM to repair current code snippet @@ -502,14 +524,18 @@ class CodeFixGroup: session = api.build_chat_session(conversation_id=code_fix_g.session_id) res = session.build_chat_completion(user_prompt) - code_fix_g.responses.append(res) - progress.update(task_id, description=f"[green]Fixing[/green] [cyan]{file_path}[/cyan]...", advance=1) + progress.update( + task_id, + description=f"[green]Fixing[/green] [cyan]{file_path}[/cyan]...", + advance=1, + ) # Manual inspection and repair for file_path in last_feedback.errors: - print(Rule(f"[bright_blue]Checking[/bright_blue] [cyan]{file_path}[/cyan]", style="bright_blue", align="left", characters=".")) + print(Rule(f"[bright_blue]Checking[/bright_blue] [cyan]{file_path}[/cyan]", + style="bright_blue", align="left", characters=".")) file = evo.files[evo.project_path / Path(file_path)] @@ -529,7 +555,8 @@ class CodeFixGroup: # print errors printed_errors_str = "\n".join( - [f"[{error.checker}] {error.line: >{file.lineno_width}}:{error.column: <4} {error.code} {error.msg}" for error in group_errors], + [f"[{error.checker}] {error.line: >{file.lineno_width}}:{error.column: <4}" + f" {error.code} {error.msg}" for error in group_errors], ) print( Panel.fit( @@ -554,13 +581,16 
@@ class CodeFixGroup: while True: try: new_code = re.search(r".*```[Pp]ython\n(.*?)\n```.*", res, re.DOTALL).group(1) - except Exception: - print(f"[red]Error when extract codes[/red]:\n {res}") + except (re.error, AttributeError) as exc: + print(f"[red]Error when extracting code[/red]:\n {res}\nException: {exc}") try: fixed_errors_info = re.search(r".*```[Jj]son\n(.*?)\n```.*", res, re.DOTALL).group(1) fixed_errors_info = json.loads(fixed_errors_info) - except Exception: + except AttributeError: fixed_errors_info = None + except (json.JSONDecodeError, re.error) as exc: + fixed_errors_info = None + print(f"[red]Error when extracting fixed_errors[/red]: {exc}") new_code = CodeFile.remove_line_number(new_code) @@ -600,7 +630,8 @@ class CodeFixGroup: print(Panel.fit(table, title="Repair Status")) operation = Prompt.ask("Input your operation [ [red]([bold]s[/bold])kip[/red] / " - "[green]([bold]a[/bold])pply[/green] / [yellow]manual instruction[/yellow] ]") + "[green]([bold]a[/bold])pply[/green] / " + "[yellow]manual instruction[/yellow] ]") print() if operation in ("s", "skip"): fix_records[file_path].skipped_errors.extend(group_errors) @@ -620,7 +651,9 @@ class CodeFixGroup: break fix_records[file_path].manual_instructions[operation].extend(group_errors) - res = session.build_chat_completion(CI_prompts["session_manual_template"].format(operation=operation)) + res = session.build_chat_completion( + CI_prompts["session_manual_template"].format(operation=operation), + ) code_fix_g.responses.append(res) # apply changes @@ -636,14 +669,18 @@ class CodeFixGroup: DIR = Prompt.ask("Please input the [cyan]project directory[/cyan]") DIR = Path(DIR) -excludes = Prompt.ask("Input the [dark_orange]excluded directories[/dark_orange] (relative to [cyan]project path[/cyan] and separated by whitespace)").split(" ") +excludes = Prompt.ask( + "Input the [dark_orange]excluded directories[/dark_orange] (relative to " + "[cyan]project path[/cyan] and separated by whitespace)", +).split(" ") excludes = [Path(exclude.strip()) for exclude in excludes if exclude.strip() != ""] start_time = time.time() start_timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%m%d%H%M") repo = Repo(DIR, excludes=excludes) -evaluator = MultiEvaluator(MypyEvaluator(), RuffEvaluator()) +# evaluator = MultiEvaluator(MypyEvaluator(), RuffEvaluator()) +evaluator = RuffEvaluator() estr = CIEvoStr() rag = None # RAG is not enabled initially.
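# --- Illustrative sketch (editor's note, not part of this patch) ---------------
# The corrected CodeFile.apply_changes earlier in this patch relies on offset
# bookkeeping: each replacement shifts all later 1-based line ranges by the
# difference between inserted and removed line counts. apply_changes_sketch is
# a hypothetical standalone rewrite of that loop, assuming changes are sorted
# by start line and given as (start, end, code) with inclusive 1-based bounds.
def apply_changes_sketch(lines: list[str], changes: list[tuple[int, int, str]]) -> list[str]:
    offset = 0
    for start, end, code in changes:
        adjusted_start = max(start - 1, 0)  # 1-based start -> 0-based slice index
        new_code = code.split("\n")
        lines[adjusted_start + offset : end + offset] = new_code
        offset += len(new_code) - (end - adjusted_start)  # net lines added so far
    return lines

# Replacing line 2 with two lines shifts the later change at line 4 by +1.
assert apply_changes_sketch(["a", "b", "c", "d"], [(2, 2, "b1\nb2"), (4, 4, "D")]) == ["a", "b1", "b2", "c", "D"]
# --------------------------------------------------------------------------------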
ea = EvoAgent(estr, rag=rag) diff --git a/rdagent/app/factor_extraction_and_implementation/factor_extract_and_implement.py b/rdagent/app/factor_extraction_and_implementation/factor_extract_and_implement.py index 0de00127..25282b37 100644 --- a/rdagent/app/factor_extraction_and_implementation/factor_extract_and_implement.py +++ b/rdagent/app/factor_extraction_and_implementation/factor_extract_and_implement.py @@ -1,19 +1,18 @@ # %% -import json from pathlib import Path +from dotenv import load_dotenv from rdagent.document_process.document_analysis import ( check_factor_viability, + classify_report_from_dict, deduplicate_factors_by_llm, extract_factors_from_report_dict, merge_file_to_factor_dict_to_factor_dict, ) from rdagent.document_process.document_reader import load_and_process_pdfs_by_langchain -from rdagent.document_process.document_analysis import classify_report_from_dict -from dotenv import load_dotenv -def extract_factors_and_implement(report_file_path: str): +def extract_factors_and_implement(report_file_path: str) -> None: assert load_dotenv() docs_dict = load_and_process_pdfs_by_langchain(Path(report_file_path)) diff --git a/rdagent/core/conf.py b/rdagent/core/conf.py index 12b526b2..be5822be 100644 --- a/rdagent/core/conf.py +++ b/rdagent/core/conf.py @@ -1,15 +1,16 @@ -# TODO: use pydantic for other modules in Qlib -# from pydantic_settings import BaseSettings -import os -from typing import Union +from __future__ import annotations + +from pathlib import Path from dotenv import load_dotenv +from pydantic_settings import BaseSettings + +# TODO: use pydantic for other modules in Qlib +# from pydantic_settings import BaseSettings # make sure that env variable is loaded while calling Config() load_dotenv(verbose=True, override=True) -from pydantic_settings import BaseSettings - class FincoSettings(BaseSettings): use_azure: bool = True @@ -22,8 +23,8 @@ class FincoSettings(BaseSettings): dump_embedding_cache: bool = False use_embedding_cache: bool = False workspace: str = "./finco_workspace" - prompt_cache_path: str = os.getcwd() + "/prompt_cache.db" - session_cache_folder_location: str = os.getcwd() + "/session_cache_folder/" + prompt_cache_path: str = str(Path.cwd() / "prompt_cache.db") + session_cache_folder_location: str = str(Path.cwd() / "session_cache_folder/") max_past_message_include: int = 10 use_vector_only: bool = False @@ -37,7 +38,7 @@ class FincoSettings(BaseSettings): chat_max_tokens: int = 3000 chat_temperature: float = 0.5 chat_stream: bool = True - chat_seed: Union[int, None] = None + chat_seed: int | None = None chat_frequency_penalty: float = 0.0 chat_presence_penalty: float = 0.0 diff --git a/rdagent/core/log.py b/rdagent/core/log.py index 7c421ac1..87630fa9 100644 --- a/rdagent/core/log.py +++ b/rdagent/core/log.py @@ -78,17 +78,10 @@ class FinCoLog: def __init__(self) -> None: self.logger: Logger = logger - def info(self, *args: Sequence, plain: bool = False, title: str = "Info") -> None: + def info(self, *args: Sequence, plain: bool = False) -> None: if plain: return self.plain_info(*args) for arg in args: - # Changes to accommodate ruff checks. 
- # Original code: - # self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}") - # Description of the problem: - # G004 Logging statement uses f-string - # References: - # https://docs.astral.sh/ruff/rules/logging-f-string/ info = f"{LogColors.WHITE}{arg}{LogColors.END}" self.logger.info(info) return None @@ -96,11 +89,6 @@ def info(self, *args: Sequence, plain: bool = False, title: str = "Info") -> Non def __getstate__(self) -> dict: return {} - # Changes to accommodate ruff checks. - # Original code: def __setstate__(self, _: str) -> None: - # Description of the problem: - # PLE0302 The special method `__setstate__` expects 2 parameters, 1 was given - # References: https://docs.astral.sh/ruff/rules/unexpected-special-method-signature/ def __setstate__(self, _: str) -> None: self.logger = logger diff --git a/rdagent/core/prompts.py b/rdagent/core/prompts.py index 4b6b8cef..32ab9046 100644 --- a/rdagent/core/prompts.py +++ b/rdagent/core/prompts.py @@ -5,18 +5,14 @@ from rdagent.core.utils import SingletonBaseClass -class Prompts(Dict[str, str], SingletonBaseClass): - def __init__(self, file_path: Path): - prompt_yaml_dict = yaml.load( - open( - file_path, - encoding="utf8", - ), - Loader=yaml.FullLoader, - ) +class Prompts(SingletonBaseClass, Dict[str, str]): + def __init__(self, file_path: Path) -> None: + with file_path.open(encoding="utf8") as file: + prompt_yaml_dict = yaml.safe_load(file) if prompt_yaml_dict is None: - raise ValueError(f"Failed to load prompts from {file_path}") + error_message = f"Failed to load prompts from {file_path}" + raise ValueError(error_message) for key, value in prompt_yaml_dict.items(): self[key] = value diff --git a/rdagent/core/utils.py b/rdagent/core/utils.py index f0e058f6..69932504 100644 --- a/rdagent/core/utils.py +++ b/rdagent/core/utils.py @@ -8,49 +8,48 @@ import string from collections.abc import Callable from pathlib import Path -from typing import Any +from typing import Any, ClassVar import yaml from fuzzywuzzy import fuzz -class RDAgentException(Exception): +class RDAgentException(Exception): # noqa: N818 pass class SingletonMeta(type): - _instance_dict = {} + _instance_dict: ClassVar[dict] = {} - def __call__(cls, *args, **kwargs): + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # Since it is hard to align calls that mix args and kwargs, we strictly require kwargs in Singleton - if len(args) > 0: - raise RDAgentException("Please only use kwargs in Singleton to avoid misunderstanding.") + if args: + exception_message = "Please only use kwargs in Singleton to avoid misunderstanding." + raise RDAgentException(exception_message) kwargs_hash = hash(tuple(sorted(kwargs.items()))) if kwargs_hash not in cls._instance_dict: - cls._instance_dict[kwargs_hash] = super(SingletonMeta, cls).__call__(*args, **kwargs) + cls._instance_dict[kwargs_hash] = super().__call__(**kwargs) return cls._instance_dict[kwargs_hash] - class SingletonBaseClass(metaclass=SingletonMeta): """ - Because we try to support defining Singleton with `class A(SingletonBaseClass)` instead of `A(metaclass=SingletonMeta)` - This class becomes necessary - + Because we try to support defining Singleton with `class A(SingletonBaseClass)` + instead of `A(metaclass=SingletonMeta)`, this class becomes necessary. """ # TODO: Move this class to Qlib's general utils.
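# --- Usage sketch (editor's note, not part of this patch) ----------------------
# SingletonMeta keys each instance on a hash of the sorted kwargs, which is
# exactly why positional arguments are rejected above. _ExampleConfig is a
# hypothetical subclass used only for illustration.
class _ExampleConfig(SingletonBaseClass):
    def __init__(self, name: str = "default") -> None:
        self.name = name

assert _ExampleConfig(name="a") is _ExampleConfig(name="a")  # same kwargs: cached instance
assert _ExampleConfig(name="a") is not _ExampleConfig(name="b")  # new kwargs: new instance
# _ExampleConfig("a") would raise RDAgentException: positional args are forbidden.
# --------------------------------------------------------------------------------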
-def parse_json(response): +def parse_json(response: str) -> Any: try: return json.loads(response) except json.decoder.JSONDecodeError: pass - - raise Exception(f"Failed to parse response: {response}, please report it or help us to fix it.") + error_message = f"Failed to parse response: {response}, please report it or help us to fix it." + raise ValueError(error_message) -def similarity(text1, text2): +def similarity(text1: str, text2: str) -> int: text1 = text1 if isinstance(text1, str) else "" text2 = text2 if isinstance(text2, str) else "" @@ -58,12 +57,12 @@ def similarity(text1, text2): return fuzz.ratio(text1, text2) -def random_string(length=10): +def random_string(length: int = 10) -> str: letters = string.ascii_letters + string.digits - return "".join(random.choice(letters) for i in range(length)) + return "".join(random.SystemRandom().choice(letters) for _ in range(length)) -def remove_uncommon_keys(new_dict, org_dict): +def remove_uncommon_keys(new_dict: dict, org_dict: dict) -> None: keys_to_remove = [] for key in new_dict: @@ -78,25 +77,25 @@ def remove_uncommon_keys(new_dict, org_dict): del new_dict[key] -def crawl_the_folder(folder_path: Path): +def crawl_the_folder(folder_path: Path) -> list: yaml_files = [] for root, _, files in os.walk(folder_path.as_posix()): for file in files: - if file.endswith(".yaml") or file.endswith(".yml"): - yaml_file_path = Path(os.path.join(root, file)).relative_to(folder_path) - yaml_files.append(yaml_file_path.as_posix()) + if file.endswith((".yaml", ".yml")): + yaml_file_path = Path(root) / file + yaml_files.append(str(yaml_file_path.relative_to(folder_path))) return sorted(yaml_files) -def compare_yaml(file1, file2): - with open(file1) as stream: +def compare_yaml(file1: Path | str, file2: Path | str) -> bool: + with Path(file1).open() as stream: data1 = yaml.safe_load(stream) - with open(file2) as stream: + with Path(file2).open() as stream: data2 = yaml.safe_load(stream) return data1 == data2 -def remove_keys(valid_keys, ori_dict): +def remove_keys(valid_keys: set[Any], ori_dict: dict[Any, Any]) -> dict[Any, Any]: for key in list(ori_dict.keys()): if key not in valid_keys: ori_dict.pop(key) @@ -106,14 +105,14 @@ def remove_keys(valid_keys, ori_dict): class YamlConfigCache(SingletonBaseClass): def __init__(self) -> None: super().__init__() - self.path_to_config = dict() + self.path_to_config = {} - def load(self, path): - with open(path) as stream: + def load(self, path: str) -> None: + with Path(path).open() as stream: data = yaml.safe_load(stream) self.path_to_config[path] = data - def __getitem__(self, path): + def __getitem__(self, path: str) -> Any: if path not in self.path_to_config: self.load(path) return self.path_to_config[path] diff --git a/rdagent/document_process/document_analysis.py b/rdagent/document_process/document_analysis.py index dc390a09..984cfb2e 100644 --- a/rdagent/document_process/document_analysis.py +++ b/rdagent/document_process/document_analysis.py @@ -4,28 +4,20 @@ import multiprocessing as mp import re from pathlib import Path -from typing import TYPE_CHECKING, Mapping +from typing import Mapping import numpy as np import pandas as pd import tiktoken -import yaml -from azure.ai.formrecognizer import DocumentAnalysisClient -from azure.core.credentials import AzureKeyCredential +from jinja2 import Template from rdagent.core.conf import FincoSettings as Config from rdagent.core.log import FinCoLog from rdagent.core.prompts import Prompts -from jinja2 import Template from rdagent.oai.llm_utils import APIBackend, 
create_embedding_with_multiprocessing from sklearn.cluster import KMeans from sklearn.metrics.pairwise import cosine_similarity from sklearn.preprocessing import normalize -if TYPE_CHECKING: - from langchain_core.documents import Document - -from langchain.document_loaders import PyPDFDirectoryLoader, PyPDFLoader - document_process_prompts = Prompts(file_path=Path(__file__).parent / "prompts.yaml") @@ -522,9 +514,9 @@ def __deduplicate_factor_dict(factor_dict: dict[str, dict[str, str]]) -> list[li return duplication_names_list -def deduplicate_factors_by_llm( +def deduplicate_factors_by_llm( # noqa: C901, PLR0912 factor_dict: dict[str, dict[str, str]], - factor_viability_dict: dict[str, dict[str, str]] = None, + factor_viability_dict: dict[str, dict[str, str]] | None = None, ) -> list[list[str]]: final_duplication_names_list = [] current_round_factor_dict = factor_dict @@ -559,7 +551,7 @@ def deduplicate_factors_by_llm( continue to_replace_dict[duplication_factor_name] = target_factor_name - llm_deduplicated_factor_dict = dict() + llm_deduplicated_factor_dict = {} added_lower_name_set = set() for factor_name in factor_dict: if factor_name not in to_replace_dict and factor_name.lower() not in added_lower_name_set: diff --git a/rdagent/document_process/document_reader.py b/rdagent/document_process/document_reader.py index 2e8ad630..33d6d9b4 100644 --- a/rdagent/document_process/document_reader.py +++ b/rdagent/document_process/document_reader.py @@ -1,16 +1,16 @@ from __future__ import annotations from pathlib import Path +from typing import TYPE_CHECKING -import yaml from azure.ai.formrecognizer import DocumentAnalysisClient from azure.core.credentials import AzureKeyCredential -from rdagent.core.conf import FincoSettings as Config -from rdagent.core.prompts import Prompts - -from langchain_core.documents import Document from langchain.document_loaders import PyPDFDirectoryLoader, PyPDFLoader +if TYPE_CHECKING: + from langchain_core.documents import Document +from rdagent.core.conf import FincoSettings as Config + def load_documents_by_langchain(path: Path) -> list: """Load documents from the specified path. 
diff --git a/rdagent/factor_implementation/evolving/evaluators.py b/rdagent/factor_implementation/evolving/evaluators.py index 930a9fee..06fc72d2 100644 --- a/rdagent/factor_implementation/evolving/evaluators.py +++ b/rdagent/factor_implementation/evolving/evaluators.py @@ -3,6 +3,7 @@ import re from typing import List +from pandas.core.api import DataFrame as DataFrame from rdagent.core.evolving_framework import Evaluator as EvolvingEvaluator from rdagent.core.evolving_framework import Feedback, QueriedKnowledge from rdagent.core.log import FinCoLog @@ -20,8 +21,6 @@ FactorImplementation, FactorImplementationTask, ) -from pandas.core.api import DataFrame as DataFrame - from rdagent.factor_implementation.share_modules.factor_implementation_config import ( FactorImplementSettings, ) diff --git a/rdagent/factor_implementation/evolving/evolving_strategy.py b/rdagent/factor_implementation/evolving/evolving_strategy.py index 7a1c19ed..c00e3458 100644 --- a/rdagent/factor_implementation/evolving/evolving_strategy.py +++ b/rdagent/factor_implementation/evolving/evolving_strategy.py @@ -1,29 +1,28 @@ from __future__ import annotations import json -from pathlib import Path import random from abc import abstractmethod from copy import deepcopy +from pathlib import Path from typing import TYPE_CHECKING +from jinja2 import Template from rdagent.core.evolving_framework import EvolvingStrategy, QueriedKnowledge +from rdagent.core.prompts import Prompts from rdagent.core.utils import multiprocessing_wrapper from rdagent.factor_implementation.share_modules.factor import ( FactorImplementation, FactorImplementationTask, FileBasedFactorImplementation, ) -from rdagent.core.prompts import Prompts -from jinja2 import Template -from rdagent.oai.llm_utils import APIBackend - from rdagent.factor_implementation.share_modules.factor_implementation_config import ( FactorImplementSettings, ) from rdagent.factor_implementation.share_modules.factor_implementation_utils import ( get_data_folder_intro, ) +from rdagent.oai.llm_utils import APIBackend if TYPE_CHECKING: from factor_implementation.evolving.evolvable_subjects import ( @@ -234,7 +233,7 @@ def implement_one_factor( Template( Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")[ "evolving_strategy_error_summary_v2_system" - ] + ], ) .render( factor_information_str=target_factor_task_information, @@ -252,7 +251,7 @@ def implement_one_factor( Template( Prompts(file_path=Path(__file__).parent.parent / "prompts.yaml")[ "evolving_strategy_error_summary_v2_user" - ] + ], ) .render( queried_similar_component_knowledge=queried_similar_component_knowledge_to_render, diff --git a/rdagent/factor_implementation/evolving/factor_implementation_evolving_cli.py b/rdagent/factor_implementation/evolving/factor_implementation_evolving_cli.py index ba5eb97f..0cd40fe8 100644 --- a/rdagent/factor_implementation/evolving/factor_implementation_evolving_cli.py +++ b/rdagent/factor_implementation/evolving/factor_implementation_evolving_cli.py @@ -4,6 +4,7 @@ from pathlib import Path import pandas as pd +from fire.core import Fire from rdagent.core.evolving_framework import EvoAgent, KnowledgeBase from rdagent.core.utils import multiprocessing_wrapper from rdagent.factor_implementation.evolving.evaluators import ( @@ -25,7 +26,6 @@ FactorImplementationTask, FileBasedFactorImplementation, ) -from fire.core import Fire from tqdm import tqdm ALPHA101_INIT_COMPONENTS = [ diff --git a/rdagent/factor_implementation/evolving/knowledge_management.py 
b/rdagent/factor_implementation/evolving/knowledge_management.py index d79f5a81..f7182906 100644 --- a/rdagent/factor_implementation/evolving/knowledge_management.py +++ b/rdagent/factor_implementation/evolving/knowledge_management.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import Union +from jinja2 import Template from rdagent.core.evolving_framework import ( EvolvableSubjects, EvoStep, @@ -17,19 +18,17 @@ RAGStrategy, ) from rdagent.core.log import FinCoLog +from rdagent.core.prompts import Prompts from rdagent.factor_implementation.evolving.evaluators import FactorImplementationSingleFeedback from rdagent.factor_implementation.share_modules.factor import ( FactorImplementation, FactorImplementationTask, ) -from rdagent.core.prompts import Prompts -from rdagent.knowledge_management.graph import UndirectedGraph, UndirectedNode -from jinja2 import Template -from rdagent.oai.llm_utils import APIBackend, calculate_embedding_distance_between_str_list - from rdagent.factor_implementation.share_modules.factor_implementation_config import ( FactorImplementSettings, ) +from rdagent.knowledge_management.graph import UndirectedGraph, UndirectedNode +from rdagent.oai.llm_utils import APIBackend, calculate_embedding_distance_between_str_list class FactorImplementationKnowledge(Knowledge): diff --git a/rdagent/factor_implementation/share_modules/evaluator.py b/rdagent/factor_implementation/share_modules/evaluator.py index 3820e0a7..82959e13 100644 --- a/rdagent/factor_implementation/share_modules/evaluator.py +++ b/rdagent/factor_implementation/share_modules/evaluator.py @@ -3,18 +3,18 @@ from typing import Tuple import pandas as pd -from factor_implementation.share_modules.factor import ( - FactorImplementation, - FactorImplementationTask, -) -from factor_implementation.share_modules.prompt import FactorImplementationPrompts from finco.log import FinCoLog from jinja2 import Template -from rdagent.oai.llm_utils import APIBackend - from rdagent.factor_implementation.share_modules.factor_implementation_config import ( FactorImplementSettings, ) +from rdagent.oai.llm_utils import APIBackend + +from factor_implementation.share_modules.factor import ( + FactorImplementation, + FactorImplementationTask, +) +from factor_implementation.share_modules.prompt import FactorImplementationPrompts class Evaluator(ABC): diff --git a/rdagent/factor_implementation/share_modules/factor.py b/rdagent/factor_implementation/share_modules/factor.py index 5d211944..9916cb87 100644 --- a/rdagent/factor_implementation/share_modules/factor.py +++ b/rdagent/factor_implementation/share_modules/factor.py @@ -6,19 +6,19 @@ from typing import Tuple, Union import pandas as pd +from filelock import FileLock +from finco.log import FinCoLog +from rdagent.factor_implementation.share_modules.factor_implementation_config import ( + FactorImplementSettings, +) + from factor_implementation.share_modules.exception import ( CodeFormatException, NoOutputException, RuntimeErrorException, ) -from filelock import FileLock -from finco.log import FinCoLog from oai.llm_utils import md5_hash -from rdagent.factor_implementation.share_modules.factor_implementation_config import ( - FactorImplementSettings, -) - class FactorImplementationTask: # TODO: remove the factor_ prefix may be better diff --git a/rdagent/factor_implementation/share_modules/factor_implementation_utils.py b/rdagent/factor_implementation/share_modules/factor_implementation_utils.py index 4538fb12..56407ba0 100644 --- 
a/rdagent/factor_implementation/share_modules/factor_implementation_utils.py +++ b/rdagent/factor_implementation/share_modules/factor_implementation_utils.py @@ -4,7 +4,6 @@ # render it with jinja from jinja2 import Template - from rdagent.factor_implementation.share_modules.factor_implementation_config import FIS TPL = """ diff --git a/rdagent/knowledge_management/graph.py b/rdagent/knowledge_management/graph.py index b3f4bf19..7459bcb1 100644 --- a/rdagent/knowledge_management/graph.py +++ b/rdagent/knowledge_management/graph.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import pickle import random from collections import deque from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Any, NoReturn from finco.llm import APIBackend from finco.vector_base import KnowledgeMetaData, PDVectorBase, VectorBase, cosine @@ -11,29 +13,29 @@ class UndirectedNode(Node): - def __init__(self, content: str = "", label: str = "", embedding=None): + def __init__(self, content: str = "", label: str = "", embedding: Any = None) -> None: super().__init__(content, label, embedding) - self.neighbors = set() + self.neighbors: set[UndirectedNode] = set() - def add_neighbor(self, node): + def add_neighbor(self, node: UndirectedNode) -> None: self.neighbors.add(node) node.neighbors.add(self) - def remove_neighbor(self, node): + def remove_neighbor(self, node: UndirectedNode) -> None: if node in self.neighbors: self.neighbors.remove(node) node.neighbors.remove(self) - def get_neighbors(self): + def get_neighbors(self) -> set[UndirectedNode]: return self.neighbors - def __str__(self): + def __str__(self) -> str: return ( f"UndirectedNode(id={self.id}, label={self.label}, content={self.content[:100]}, " f"neighbors={self.neighbors})" ) - def __repr__(self): + def __repr__(self) -> str: return ( f"UndirectedNode(id={self.id}, label={self.label}, content={self.content[:100]}, " f"neighbors={self.neighbors})" @@ -45,53 +47,49 @@ class Graph: base Graph class for Knowledge Graph Search """ - def __init__(self, path: Union[str, Path] = None): + def __init__(self, path: str | Path | None = None) -> None: self.path = path self.nodes = {} - def size(self): + def size(self) -> int: return len(self.nodes) - def get_node(self, node_id: str) -> Node: - node = self.nodes.get(node_id) - return node + def get_node(self, node_id: str) -> Node | None: + return self.nodes.get(node_id) - def add_node(self, **kwargs): + def add_node(self, **kwargs: Any) -> NoReturn: raise NotImplementedError - def get_all_nodes(self) -> List: + def get_all_nodes(self) -> list[Node]: return list(self.nodes.values()) - def get_all_nodes_by_label_list(self, label_list: List[str]) -> List: - node_list = [] - for node in self.nodes.values(): - if node.label in label_list: - node_list.append(node) - return node_list + def get_all_nodes_by_label_list(self, label_list: list[str]) -> list[Node]: + return [node for node in self.nodes.values() if node.label in label_list] - def find_node(self, content: str, label: str): + def find_node(self, content: str, label: str) -> Node | None: for node in self.nodes.values(): if node.content == content and node.label == label: return node + return None @classmethod - def load(cls, path: Union[str, Path]): + def load(cls: type[Graph], path: str | Path) -> Graph: """use pickle as the default load method""" path = path if isinstance(path, Path) else Path(path) if not path.exists(): - return Graph(path=path) + return cls(path=path) - with open(path, "rb") as f: + with path.open("rb") 
as f: return pickle.load(f) - def save(self, path: Union[str, Path], **kwargs): + def save(self, path: str | Path) -> None: """use pickle as the default save method""" Path.mkdir(path.parent, exist_ok=True) - with open(path, "wb") as f: + with path.open("wb") as f: pickle.dump(self, f) @staticmethod - def batch_embedding(nodes: List[Node]): + def batch_embedding(nodes: list[Node]) -> list[Node]: contents = [node.content for node in nodes] # openai create embedding API input's max length is 16 size = 16 @@ -101,14 +99,12 @@ def batch_embedding(nodes: List[Node]): APIBackend().create_embedding(input_content=contents[i : i + size]), ) - assert len(nodes) == len( - embeddings, - ), "nodes' length must equals embeddings' length" + assert len(nodes) == len(embeddings), "nodes' length must equals embeddings' length" for node, embedding in zip(nodes, embeddings): node.embedding = embedding return nodes - def __str__(self): + def __str__(self) -> str: return f"Graph(nodes={self.nodes})" @@ -117,19 +113,19 @@ class UndirectedGraph(Graph): Undirected Graph which edges have no relationship """ - def __init__(self, path: Union[str, Path] = None): + def __init__(self, path: str | Path | None = None) -> None: super().__init__(path=path) self.vector_base: VectorBase = PDVectorBase() - def __str__(self): + def __str__(self) -> str: return f"UndirectedGraph(nodes={self.nodes})" def add_node( self, node: UndirectedNode, neighbor: UndirectedNode = None, - same_node_threshold=0.95, - ): + same_node_threshold: float = 0.95, # noqa: ARG002 + ) -> None: """ add node and neighbor to the Graph Parameters @@ -174,27 +170,27 @@ def add_node( node.add_neighbor(neighbor) @classmethod - def load(cls, path: Union[str, Path]): + def load(cls: type[UndirectedGraph], path: str | Path) -> UndirectedGraph: """use pickle as the default load method""" path = path if isinstance(path, Path) else Path(path) if not path.exists(): - return UndirectedGraph(path=path) + return cls(path=path) - with open(path, "rb") as f: + with path.open("rb") as f: return pickle.load(f) - def add_nodes(self, node: UndirectedNode, neighbors: List[UndirectedNode]): - if not len(neighbors): + def add_nodes(self, node: UndirectedNode, neighbors: list[UndirectedNode]) -> None: + if not neighbors: self.add_node(node) else: for neighbor in neighbors: self.add_node(node, neighbor=neighbor) def get_node(self, node_id: str) -> UndirectedNode: - node = self.nodes.get(node_id) - return node + return self.nodes.get(node_id) + - def get_node_by_content(self, content: str) -> Union[UndirectedNode, None]: + def get_node_by_content(self, content: str) -> UndirectedNode | None: """ Get node by semantic distance Parameters @@ -208,18 +204,18 @@ def get_node_by_content(self, content: str) -> Union[UndirectedNode, None]: if content == "Model": pass match = self.semantic_search(node=content, similarity_threshold=0.999) - if len(match): + if match: return match[0] - else: - return None + return None def get_nodes_within_steps( self, start_node: UndirectedNode, steps: int = 1, - constraint_labels: List[str] = None, + constraint_labels: list[str] | None = None, + *, block: bool = False, - ) -> List[UndirectedNode]: + ) -> list[UndirectedNode]: """ Returns the nodes in the graph whose distance from node is less than or equal to step """ @@ -238,24 +234,23 @@ def get_nodes_within_steps( result.append(node) for neighbor in sorted( - list(self.get_node(node.id).neighbors), key=lambda x: x.content, + self.get_node(node.id).neighbors, key=lambda x: x.content, ): # to make sure the 
result is deterministic - if neighbor not in visited: - if not (block and neighbor.label not in constraint_labels): - queue.append((neighbor, current_steps + 1)) + if neighbor not in visited and not (block and neighbor.label not in constraint_labels): + queue.append((neighbor, current_steps + 1)) if constraint_labels: result = [node for node in result if node.label in constraint_labels] if start_node in result: - result.pop(result.index(start_node)) + result.remove(start_node) return result def get_nodes_intersection( self, - nodes: List[UndirectedNode], + nodes: list[UndirectedNode], steps: int = 1, - constraint_labels: List[str] = None, - ) -> List[UndirectedNode]: + constraint_labels: list[str] | None = None, + ) -> list[UndirectedNode]: """ Get the intersection with nodes connected within n steps of nodes Parameters ---------- nodes steps constraint_labels Returns ------- """ - assert len(nodes) >= 2, "nodes length must >=2" + min_nodes_count = 2 + assert len(nodes) >= min_nodes_count, "nodes length must >=2" intersection = None for node in nodes: @@ -283,15 +279,14 @@ def get_nodes_intersection( node, steps=steps, constraint_labels=constraint_labels, ), ) - return intersection def semantic_search( self, - node: Union[UndirectedNode, str], + node: UndirectedNode | str, similarity_threshold: float = 0.0, topk_k: int = 5, - ) -> List[UndirectedNode]: + ) -> list[UndirectedNode]: """ semantic search by node's embedding Parameters ---------- topk_k node - similarity_threshold: Returns nodes whose distance score from the input node is greater than similarity_threshold + similarity_threshold: Returns nodes whose distance score from the input + node is greater than similarity_threshold Returns ------- @@ -312,10 +308,9 @@ topk_k=topk_k, similarity_threshold=similarity_threshold, ) - nodes = [self.get_node(doc.id) for doc in docs] - return nodes + return [self.get_node(doc.id) for doc in docs] - def clear(self): + def clear(self) -> None: self.nodes.clear() self.vector_base: VectorBase = PDVectorBase() @@ -323,11 +318,12 @@ def query_by_node( self, node: UndirectedNode, step: int = 1, - constraint_labels: List[str] = None, - constraint_node: UndirectedNode = None, + constraint_labels: list[str] | None = None, + constraint_node: UndirectedNode | None = None, constraint_distance: float = 0, + *, block: bool = False, - ) -> List[UndirectedNode]: + ) -> list[UndirectedNode]: """ Search the graph by connection; return an empty list if the chain of connected nodes contains no node near constraint_node Parameters ---------- @@ -358,29 +354,39 @@ def query_by_node( def query_by_content( self, - content: Union[str, List[str]], + content: str | list[str], topk_k: int = 5, step: int = 1, - constraint_labels: List[str] = None, - constraint_node: UndirectedNode = None, + constraint_labels: list[str] | None = None, + constraint_node: UndirectedNode | None = None, similarity_threshold: float = 0.0, constraint_distance: float = 0, + *, block: bool = False, - ) -> List[UndirectedNode]: + ) -> list[UndirectedNode]: """ - search graph by content similarity and connection relationship, return empty list if nodes' chain without node - near to constraint_node + Search the graph by content similarity and connection relationship; return an + empty list if the chain of connected nodes contains no node near constraint_node. Parameters ---------- - constraint_distance : float the distance between the node and the constraint_node + constraint_distance : float + The distance between the node and the constraint_node.
content : Union[str, List[str]] - topk_k: the upper number of output for each query, if the number of fit nodes is less than topk_k, return all fit nodes's content - step : the maximum distance between the start node and the result node - constraint_labels : the type of nodes that the search can only flow through - constraint_node : the node that the search can only flow through - similarity_threshold : the similarity threshold of the content - block: despite the start node, the search can only flow through the constraint_label type nodes + Content to search for. + topk_k: int + The upper number of output for each query. If the number of fit nodes is + less than topk_k, returns all fit nodes' content. + step : int + The maximum distance between the start node and the result node. + constraint_labels : List[str] + The type of nodes that the search can only flow through. + constraint_node : UndirectedNode, optional + The node that the search can only flow through. + similarity_threshold : float + The similarity threshold of the content. + block: bool + Despite the start node, the search can only flow through the constraint_label type nodes. Returns ------- @@ -422,54 +428,50 @@ def query_by_content( return res_list @staticmethod - def intersection(nodes1: List[UndirectedNode], nodes2: List[UndirectedNode]): - intersection = [node for node in nodes1 if node in nodes2] - return intersection + def intersection(nodes1: list[UndirectedNode], nodes2: list[UndirectedNode]) -> list[UndirectedNode]: + return [node for node in nodes1 if node in nodes2] @staticmethod - def different(nodes1: List[UndirectedNode], nodes2: List[UndirectedNode]): - difference = list(set(nodes1).symmetric_difference(set(nodes2))) - return difference + def different(nodes1: list[UndirectedNode], nodes2: list[UndirectedNode]) -> list[UndirectedNode]: + return list(set(nodes1).symmetric_difference(set(nodes2))) @staticmethod - def cal_distance(node1: UndirectedNode, node2: UndirectedNode): - distance = cosine(node1.embedding, node2.embedding) - return distance + def cal_distance(node1: UndirectedNode, node2: UndirectedNode) -> float: + return cosine(node1.embedding, node2.embedding) @staticmethod - def filter_label(nodes: List[UndirectedNode], labels: List[str]): - nodes = [node for node in nodes if node.label in labels] - return nodes + def filter_label(nodes: list[UndirectedNode], labels: list[str]) -> list[UndirectedNode]: + return [node for node in nodes if node.label in labels] -def graph_to_edges(graph: Dict[str, List[str]]): + +def graph_to_edges(graph: dict[str, list[str]]) -> list[tuple[str, str]]: edges = [] for node, neighbors in graph.items(): for neighbor in neighbors: - if [node, neighbor] in edges or [neighbor, node] in edges: + if (node, neighbor) in edges or (neighbor, node) in edges: continue - edges.append([node, neighbor]) + edges.append((node, neighbor)) return edges def assign_random_coordinate_to_node( - nodes: List, scope: float = 1.0, origin: Tuple = (0.0, 0.0), -) -> Dict: + nodes: list[str], scope: float = 1.0, origin: tuple[float, float] = (0.0, 0.0), +) -> dict[str, tuple[float, float]]: coordinates = {} - for node in nodes: - x = random.uniform(0, scope) + origin[0] - y = random.uniform(0, scope) + origin[1] + x = random.SystemRandom().uniform(0, scope) + origin[0] + y = random.SystemRandom().uniform(0, scope) + origin[1] coordinates[node] = (x, y) return coordinates def assign_isometric_coordinate_to_node( - nodes: List, x_step: float = 1.0, x_origin: float = 0.0, y_origin: float = 0.0, -) -> 
Dict: + nodes: list, x_step: float = 1.0, x_origin: float = 0.0, y_origin: float = 0.0, +) -> dict: coordinates = {} for i, node in enumerate(nodes): @@ -481,10 +483,10 @@ def assign_isometric_coordinate_to_node( def curly_node_coordinate( - coordinates: Dict, center_y: float = 1.0, r: float = 1.0, -) -> Dict: + coordinates: dict, center_y: float = 1.0, r: float = 1.0, +) -> dict: # note: this method can only curl < 90 degrees, and the curl line is a circle. - # the original funtion is: x**2 + (y-m)**2 = r**2 + # the original function is: x**2 + (y-m)**2 = r**2 for node, coordinate in coordinates.items(): - coordinate[1] = center_y + (r**2 - coordinate[0] ** 2) ** 0.5 + coordinates[node] = (coordinate[0], center_y + (r**2 - coordinate[0] ** 2) ** 0.5) return coordinates diff --git a/rdagent/oai/llm_utils.py b/rdagent/oai/llm_utils.py index f20b990b..cd4bd30c 100644 --- a/rdagent/oai/llm_utils.py +++ b/rdagent/oai/llm_utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime import hashlib import json @@ -11,7 +13,7 @@ import uuid from copy import deepcopy from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Any import numpy as np import tiktoken @@ -22,12 +24,11 @@ DEFAULT_QLIB_DOT_PATH = Path("./") -def md5_hash(input_string): - md5 = hashlib.md5() +def md5_hash(input_string: str) -> str: + hash_md5 = hashlib.md5(usedforsecurity=False) input_bytes = input_string.encode("utf-8") - md5.update(input_bytes) - hashed_string = md5.hexdigest() - return hashed_string + hash_md5.update(input_bytes) + return hash_md5.hexdigest() try: @@ -54,14 +55,14 @@ class ConvManager: def __init__( self, - path: Union[Path, str] = DEFAULT_QLIB_DOT_PATH / "llm_conv", + path: Path | str = DEFAULT_QLIB_DOT_PATH / "llm_conv", recent_n: int = 10, ) -> None: self.path = Path(path) self.path.mkdir(parents=True, exist_ok=True) self.recent_n = recent_n - def _rotate_files(self): + def _rotate_files(self) -> None: pairs = [] for f in self.path.glob("*.json"): m = re.match(r"(\d+).json", f.name) @@ -70,21 +71,22 @@ def _rotate_files(self): pairs.append((n, f)) pairs.sort(key=lambda x: x[0]) for n, f in pairs[: self.recent_n][::-1]: - if Path(self.path / f"{n+1}.json").exists(): - os.remove(self.path / f"{n+1}.json") + if (self.path / f"{n+1}.json").exists(): + (self.path / f"{n+1}.json").unlink() f.rename(self.path / f"{n+1}.json") - def append(self, conv: Tuple[list, str]): + def append(self, conv: tuple[list, str]) -> None: self._rotate_files() - json.dump(conv, open(self.path / "0.json", "w")) + with (self.path / "0.json").open("w") as file: + json.dump(conv, file) # TODO: preserve line breaks to make it more convenient to edit the file directly.
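# --- Quick illustration (editor's note, not part of this patch) ----------------
# md5_hash above only generates deterministic cache keys for the chat/embedding
# caches below, so usedforsecurity=False (Python 3.9+) documents that intent and
# keeps the call usable on FIPS-restricted builds.
assert md5_hash("hello") == "5d41402abc4b2a76b9719d911017c592"  # stable cache key
# --------------------------------------------------------------------------------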
class SQliteLazyCache(SingletonBaseClass): - def __init__(self, cache_location) -> None: + def __init__(self, cache_location: str) -> None: super().__init__() self.cache_location = cache_location - db_file_exist = os.path.exists(cache_location) + db_file_exist = Path(cache_location).exists() self.conn = sqlite3.connect(cache_location) self.c = self.conn.cursor() if not db_file_exist: @@ -106,25 +108,23 @@ def __init__(self, cache_location) -> None: ) self.conn.commit() - def chat_get(self, key): + def chat_get(self, key: str) -> str | None: md5_key = md5_hash(key) self.c.execute("SELECT chat FROM chat_cache WHERE md5_key=?", (md5_key,)) result = self.c.fetchone() if result is None: return None - else: - return result[0] + return result[0] - def embedding_get(self, key): + def embedding_get(self, key: str) -> list | dict | str | None: md5_key = md5_hash(key) self.c.execute("SELECT embedding FROM embedding_cache WHERE md5_key=?", (md5_key,)) result = self.c.fetchone() if result is None: return None - else: - return json.loads(result[0]) + return json.loads(result[0]) - def chat_set(self, key, value): + def chat_set(self, key: str, value: str) -> None: md5_key = md5_hash(key) self.c.execute( "INSERT OR REPLACE INTO chat_cache (md5_key, chat) VALUES (?, ?)", @@ -132,7 +132,8 @@ def chat_set(self, key, value): ) self.conn.commit() - def embedding_set(self, content_to_embedding_dict): + + def embedding_set(self, content_to_embedding_dict: dict) -> None: for key, value in content_to_embedding_dict.items(): md5_key = md5_hash(key) self.c.execute( @@ -160,26 +161,26 @@ def __init__(self) -> None: conversation_content = json.load(f) self.cache[conversation_id] = conversation_content["content"] - def message_get(self, conversation_id: str): + def message_get(self, conversation_id: str) -> list[str]: return self.cache.get(conversation_id, []) - def message_set(self, conversation_id, message_value): + def message_set(self, conversation_id: str, message_value: list[str]) -> None: self.cache[conversation_id] = message_value conversation_path = self.session_cache_location / conversation_id conversation_path = conversation_path.with_suffix(".json") - current_time = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S") - with open(conversation_path, "w") as f: + current_time = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d-%H-%M-%S") + with conversation_path.open("w") as f: json.dump({"content": message_value, "last_modified_time": current_time}, f) class ChatSession: - def __init__(self, api_backend, conversation_id=None, system_prompt=None): + def __init__(self, api_backend: Any, conversation_id: str | None = None, system_prompt: str | None = None) -> None: self.conversation_id = str(uuid.uuid4()) if conversation_id is None else conversation_id self.cfg = Config() self.system_prompt = system_prompt if system_prompt is not None else self.cfg.default_system_prompt self.api_backend = api_backend - def build_chat_completion_message(self, user_prompt, **kwargs): + def build_chat_completion_message(self, user_prompt: str) -> list[dict[str, Any]]: history_message = SessionChatHistoryCache().message_get(self.conversation_id) messages = history_message if not messages: @@ -192,18 +193,18 @@ def build_chat_completion_message(self, user_prompt, **kwargs): ) return messages - def build_chat_completion_message_and_calculate_token(self, user_prompt, **kwargs): - messages = self.build_chat_completion_message(user_prompt, **kwargs) + def build_chat_completion_message_and_calculate_token(self, 
user_prompt: str) -> Any: + messages = self.build_chat_completion_message(user_prompt) return self.api_backend.calculate_token_from_messages(messages) - def build_chat_completion(self, user_prompt, **kwargs): + def build_chat_completion(self, user_prompt: str, **kwargs: Any) -> str: """ this function is to build the session messages user prompt should always be provided """ messages = self.build_chat_completion_message(user_prompt, **kwargs) - response = self.api_backend._try_create_chat_completion_or_embedding( + response = self.api_backend._try_create_chat_completion_or_embedding( # noqa: SLF001 messages=messages, chat_completion=True, **kwargs, ) messages.append( @@ -215,30 +216,30 @@ def build_chat_completion(self, user_prompt, **kwargs): SessionChatHistoryCache().message_set(self.conversation_id, messages) return response - def get_conversation_id(self): + def get_conversation_id(self) -> str: return self.conversation_id - def dispaly_history(): + def display_history(self) -> None: # TODO: Realize a beautiful presentation format for history messages pass class APIBackend: - def __init__( + def __init__( # noqa: C901, PLR0912, PLR0915 self, *, - chat_api_key=None, - chat_model=None, - chat_api_base=None, - chat_api_version=None, - embedding_api_key=None, - embedding_model=None, - embedding_api_base=None, - embedding_api_version=None, - use_chat_cache=None, - dump_chat_cache=None, - use_embedding_cache=None, - dump_embedding_cache=None, + chat_api_key: str | None = None, + chat_model: str | None = None, + chat_api_base: str | None = None, + chat_api_version: str | None = None, + embedding_api_key: str | None = None, + embedding_model: str | None = None, + embedding_api_base: str | None = None, + embedding_api_version: str | None = None, + use_chat_cache: bool | None = None, + dump_chat_cache: bool | None = None, + use_embedding_cache: bool | None = None, + dump_embedding_cache: bool | None = None, ) -> None: self.cfg = Config() if self.cfg.use_llama2: @@ -250,40 +251,41 @@ def __init__( ) self.encoder = None elif self.cfg.use_gcr_endpoint: - if self.cfg.gcr_endpoint_type == "llama2_70b": - self.gcr_endpoinpt_key = self.cfg.llama2_70b_endpoint_key + gcr_endpoint_type = self.cfg.gcr_endpoint_type + if gcr_endpoint_type == "llama2_70b": + self.gcr_endpoint_key = self.cfg.llama2_70b_endpoint_key self.gcr_endpoint_deployment = self.cfg.llama2_70b_endpoint_deployment self.gcr_endpoint = self.cfg.llama2_70b_endpoint - elif self.cfg.gcr_endpoint_type == "llama3_70b": - self.gcr_endpoinpt_key = self.cfg.llama3_70b_endpoint_key + elif gcr_endpoint_type == "llama3_70b": + self.gcr_endpoint_key = self.cfg.llama3_70b_endpoint_key self.gcr_endpoint_deployment = self.cfg.llama3_70b_endpoint_deployment self.gcr_endpoint = self.cfg.llama3_70b_endpoint - elif self.cfg.gcr_endpoint_type == "phi2": - self.gcr_endpoinpt_key = self.cfg.phi2_endpoint_key + elif gcr_endpoint_type == "phi2": + self.gcr_endpoint_key = self.cfg.phi2_endpoint_key self.gcr_endpoint_deployment = self.cfg.phi2_endpoint_deployment self.gcr_endpoint = self.cfg.phi2_endpoint - elif self.cfg.gcr_endpoint_type == "phi3_4k": - self.gcr_endpoinpt_key = self.cfg.phi3_4k_endpoint_key + elif gcr_endpoint_type == "phi3_4k": + self.gcr_endpoint_key = self.cfg.phi3_4k_endpoint_key self.gcr_endpoint_deployment = self.cfg.phi3_4k_endpoint_deployment self.gcr_endpoint = self.cfg.phi3_4k_endpoint - elif self.cfg.gcr_endpoint_type == "phi3_128k": - self.gcr_endpoinpt_key = self.cfg.phi3_128k_endpoint_key + elif gcr_endpoint_type == "phi3_128k": + 
self.gcr_endpoint_key = self.cfg.phi3_128k_endpoint_key self.gcr_endpoint_deployment = self.cfg.phi3_128k_endpoint_deployment self.gcr_endpoint = self.cfg.phi3_128k_endpoint else: - raise ValueError(f"Invalid gcr_endpoint_type: {self.cfg.gcr_endpoint_type}") + error_message = f"Invalid gcr_endpoint_type: {gcr_endpoint_type}" + raise ValueError(error_message) self.headers = { "Content-Type": "application/json", - "Authorization": ("Bearer " + self.gcr_endpoinpt_key), + "Authorization": ("Bearer " + self.gcr_endpoint_key), "azureml-model-deployment": self.gcr_endpoint_deployment, } - # self.gcr_endpoint = self.cfg.llama2_endpoint self.gcr_endpoint_temperature = self.cfg.gcr_endpoint_temperature self.gcr_endpoint_top_p = self.cfg.gcr_endpoint_top_p self.gcr_endpoint_do_sample = self.cfg.gcr_endpoint_do_sample self.gcr_endpoint_max_token = self.cfg.gcr_endpoint_max_token - if not os.environ.get("PYTHONHTTPSVERIFY", "") and getattr(ssl, "_create_unverified_context", None): - ssl._create_default_https_context = ssl._create_unverified_context + if not os.environ.get("PYTHONHTTPSVERIFY", "") and hasattr(ssl, "_create_unverified_context"): + ssl._create_default_https_context = ssl._create_unverified_context # noqa: SLF001 self.encoder = None else: self.use_azure = self.cfg.use_azure @@ -312,7 +314,7 @@ def __init__( if self.use_azure_token_provider: credential = DefaultAzureCredential() token_provider = get_bearer_token_provider( - credential, "https://cognitiveservices.azure.com/.default" + credential, "https://cognitiveservices.azure.com/.default", ) self.chat_client = openai.AzureOpenAI( azure_ad_token_provider=token_provider, @@ -354,28 +356,35 @@ def __init__( self.use_gcr_endpoint = self.cfg.use_gcr_endpoint self.retry_wait_seconds = self.cfg.retry_wait_seconds - def build_chat_session(self, conversation_id=None, session_system_prompt=None): + def build_chat_session( + self, + conversation_id: str | None = None, + session_system_prompt: str | None = None, + ) -> ChatSession: """ conversation_id is a 256-bit string created by uuid.uuid4() and is also the file name under session_cache_folder/ for each conversation """ - session = ChatSession(self, conversation_id, session_system_prompt) - return session + return ChatSession(self, conversation_id, session_system_prompt) def build_messages( self, - user_prompt, - system_prompt=None, - former_messages=[], - shrink_multiple_break=False, - ): + user_prompt: str, + system_prompt: str | None = None, + former_messages: list[dict] | None = None, + *, + shrink_multiple_break: bool = False, + ) -> list[dict]: """build the messages to avoid implementing several redundant lines of code""" + if former_messages is None: + former_messages = [] # shrink multiple break will recursively remove multiple breaks(more than 2) if shrink_multiple_break: while "\n\n\n" in user_prompt: user_prompt = user_prompt.replace("\n\n\n", "\n\n") - while "\n\n\n" in system_prompt: - system_prompt = system_prompt.replace("\n\n\n", "\n\n") + if system_prompt is not None: + while "\n\n\n" in system_prompt: + system_prompt = system_prompt.replace("\n\n\n", "\n\n") system_prompt = self.cfg.default_system_prompt if system_prompt is None else system_prompt messages = [ { @@ -394,42 +403,39 @@ def build_messages( def build_messages_and_create_chat_completion( self, - user_prompt, - system_prompt=None, - former_messages=[], - shrink_multiple_break=False, - chat_cache_prefix="", - **kwargs, - ): + user_prompt: str, + system_prompt: str | None = None, + former_messages: list | None = None, + 
@@ -394,42 +403,39 @@ def build_messages(
     def build_messages_and_create_chat_completion(
         self,
-        user_prompt,
-        system_prompt=None,
-        former_messages=[],
-        shrink_multiple_break=False,
-        chat_cache_prefix="",
-        **kwargs,
-    ):
-        messages = self.build_messages(user_prompt, system_prompt, former_messages, shrink_multiple_break)
-        response = self._try_create_chat_completion_or_embedding(
+        user_prompt: str,
+        system_prompt: str | None = None,
+        former_messages: list | None = None,
+        chat_cache_prefix: str = "",
+        *,
+        shrink_multiple_break: bool = False,
+        **kwargs: Any,
+    ) -> str:
+        if former_messages is None:
+            former_messages = []
+        messages = self.build_messages(
+            user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break,
+        )
+        return self._try_create_chat_completion_or_embedding(
             messages=messages,
             chat_completion=True,
             chat_cache_prefix=chat_cache_prefix,
             **kwargs,
         )
-        # if self.debug_mode:
-        #     ConvManager().append((messages, response))
-        return response
 
-    def create_embedding(self, input_content, **kwargs):
-        if isinstance(input_content, str):
-            input_content_list = [input_content]
-        elif isinstance(input_content, list):
-            input_content_list = input_content
+    def create_embedding(self, input_content: str | list[str], **kwargs: Any) -> list[Any] | Any:
+        input_content_list = [input_content] if isinstance(input_content, str) else input_content
         resp = self._try_create_chat_completion_or_embedding(
             input_content_list=input_content_list,
             embedding=True,
             **kwargs,
         )
         if isinstance(input_content, str):
             return resp[0]
-        elif isinstance(input_content, list):
-            return resp
+        return resp
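Aside (illustrative, not part of the patch): the `create_embedding` rewrite is a common
normalize-then-unwrap pattern: coerce a scalar to a one-element list, do the batched work,
and unwrap on the way out. A self-contained sketch, where `embed_many` is a hypothetical
stand-in for the real backend call:

from typing import Any


def embed(input_content: str | list[str]) -> Any:
    """Hypothetical wrapper showing the normalize-then-unwrap pattern."""
    def embed_many(texts: list[str]) -> list[list[float]]:
        # Stand-in for a real embedding backend: one vector per input text.
        return [[float(len(t)), float(sum(map(ord, t)) % 97)] for t in texts]

    batch = [input_content] if isinstance(input_content, str) else input_content
    vectors = embed_many(batch)
    # A str input returns one vector; a list input returns a list of vectors.
    return vectors[0] if isinstance(input_content, str) else vectors


print(embed("hello"))          # one vector
print(embed(["hello", "hi"]))  # list of vectors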
""" response, finish_reason = self._create_chat_completion_inner_function(messages=messages, **kwargs) @@ -444,39 +450,36 @@ def _create_chat_completion_auto_continue(self, messages, **kwargs): ) new_response, finish_reason = self._create_chat_completion_inner_function(messages=new_message, **kwargs) return response + new_response - else: - return response + return response - def _try_create_chat_completion_or_embedding(self, max_retry=10, chat_completion=False, embedding=False, **kwargs): + def _try_create_chat_completion_or_embedding( + self, max_retry: int = 10, *, chat_completion: bool = False, embedding: bool = False, **kwargs: Any, + ) -> Any: assert not (chat_completion and embedding), "chat_completion and embedding cannot be True at the same time" max_retry = self.cfg.max_retry if self.cfg.max_retry is not None else max_retry for i in range(max_retry): try: if embedding: - response = self._create_embedding_inner_function(**kwargs) - return response - elif chat_completion: - response = self._create_chat_completion_auto_continue(**kwargs) - return response - except Exception as e: + return self._create_embedding_inner_function(**kwargs) + if chat_completion: + return self._create_chat_completion_auto_continue(**kwargs) + except openai.BadRequestError as e: # noqa: PERF203 print(e) print(f"Retrying {i+1}th time...") - if ( - isinstance(e, openai.BadRequestError) - and r"'messages' must contain the word 'json' in some form" in e.message - ): + if "'messages' must contain the word 'json' in some form" in e.message: kwargs["add_json_in_prompt"] = True - elif isinstance(e, openai.BadRequestError) and embedding and "maximum context length" in e.message: - for index in range(len(kwargs["input_content_list"])): - kwargs["input_content_list"][index] = kwargs["input_content_list"][index][ - : len(kwargs["input_content_list"][index]) // 2 - ] - else: - time.sleep(self.retry_wait_seconds) - continue - raise Exception(f"Failed to create chat completion after {max_retry} retries.") + elif embedding and "maximum context length" in e.message: + kwargs["input_content_list"] = [ + content[: len(content) // 2] for content in kwargs.get("input_content_list", []) + ] + except Exception as e: # noqa: BLE001 + print(e) + print(f"Retrying {i+1}th time...") + time.sleep(self.retry_wait_seconds) + error_message = f"Failed to create chat completion after {max_retry} retries." 
+        raise RuntimeError(error_message)
 
-    def _create_embedding_inner_function(self, input_content_list, **kwargs):
+    def _create_embedding_inner_function(self, input_content_list: list[str], **kwargs: Any) -> list[Any]:  # noqa: ARG002
         content_to_embedding_dict = {}
         filtered_input_content_list = []
         if self.use_embedding_cache:
@@ -505,41 +508,42 @@ def _create_embedding_inner_function(self, input_content_list, **kwargs):
         if self.dump_embedding_cache:
             self.cache.embedding_set(content_to_embedding_dict)
-        resp = [content_to_embedding_dict[content] for content in input_content_list]
-        return resp
+        return [content_to_embedding_dict[content] for content in input_content_list]
 
-    def _build_messages(self, messages):
+    def _build_messages(self, messages: list[dict]) -> str:
         log_messages = ""
         for m in messages:
             log_messages += (
                 f"\n{LogColors.MAGENTA}{LogColors.BOLD}Role:{LogColors.END}"
-                + f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
-                + f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
-                + f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
+                f"{LogColors.CYAN}{m['role']}{LogColors.END}\n"
+                f"{LogColors.MAGENTA}{LogColors.BOLD}Content:{LogColors.END} "
+                f"{LogColors.CYAN}{m['content']}{LogColors.END}\n"
             )
         return log_messages
 
-    def log_messages(self, messages):
+    def log_messages(self, messages: list[dict]) -> None:
         if self.cfg.log_llm_chat_content:
             FinCoLog().info(self._build_messages(messages))
 
-    def log_response(self, response=None, stream=False):
+    def log_response(self, response: str | None = None, *, stream: bool = False) -> None:
         if self.cfg.log_llm_chat_content:
             if stream:
                 FinCoLog().info(f"\n{LogColors.CYAN}Response:{LogColors.END}")
             else:
                 FinCoLog().info(f"\n{LogColors.CYAN}Response:{response}{LogColors.END}")
 
-    def _create_chat_completion_inner_function(
+    def _create_chat_completion_inner_function(  # noqa: C901, PLR0912, PLR0915
         self,
-        messages,
-        temperature: float = None,
-        max_tokens: Optional[int] = None,
-        chat_cache_prefix="",
-        json_mode=False,
-        add_json_in_prompt=False,
-        frequency_penalty=None,
-        presence_penalty=None,
-    ) -> str:
+        messages: list[dict],
+        temperature: float | None = None,
+        max_tokens: int | None = None,
+        chat_cache_prefix: str = "",
+        frequency_penalty: float | None = None,
+        presence_penalty: float | None = None,
+        *,
+        json_mode: bool = False,
+        add_json_in_prompt: bool = False,
+    ) -> tuple[str, str | None]:
         self.log_messages(messages)
         # TODO: fail to use loguru adaptor due to stream response
@@ -587,8 +591,8 @@ def _create_chat_completion_inner_function(
                 ),
             )
 
-            req = urllib.request.Request(self.gcr_endpoint, body, self.headers)
-            response = urllib.request.urlopen(req)
+            req = urllib.request.Request(self.gcr_endpoint, body, self.headers)  # noqa: S310
+            response = urllib.request.urlopen(req)  # noqa: S310
             resp = json.loads(response.read().decode())["output"]
             self.log_response(resp)
         else:
@@ -655,7 +659,7 @@ def _create_chat_completion_inner_function(
         # TODO: fail to use loguru adaptor due to stream response
         return resp, finish_reason
 
-    def calculate_token_from_messages(self, messages):
+    def calculate_token_from_messages(self, messages: list[dict]) -> int:
         if self.use_llama2 or self.use_gcr_endpoint:
             FinCoLog().warning("num_tokens_from_messages() is not implemented for model llama2.")
             return 0  # TODO implement this function for llama2
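Aside (illustrative, not part of the patch): the recurring bare `*` in these signatures
makes every following parameter keyword-only, which is how ruff's boolean-trap (FBT)
rules are usually satisfied. The catch, fixed at the call sites above, is that positional
callers stop working and must pass keywords. A tiny sketch with hypothetical names:

def render(text: str, *, uppercase: bool = False) -> str:
    # "*" makes every parameter after it keyword-only.
    return text.upper() if uppercase else text


print(render("hi", uppercase=True))  # OK
# render("hi", True) would raise TypeError: takes 1 positional argument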
@@ -678,27 +682,28 @@ def calculate_token_from_messages(
     def build_messages_and_calculate_token(
         self,
-        user_prompt,
-        system_prompt,
-        former_messages=[],
-        shrink_multiple_break=False,
-    ):
-        messages = self.build_messages(user_prompt, system_prompt, former_messages, shrink_multiple_break)
+        user_prompt: str,
+        system_prompt: str | None,
+        former_messages: list[dict] | None = None,
+        *,
+        shrink_multiple_break: bool = False,
+    ) -> int:
+        if former_messages is None:
+            former_messages = []
+        messages = self.build_messages(
+            user_prompt, system_prompt, former_messages, shrink_multiple_break=shrink_multiple_break,
+        )
         return self.calculate_token_from_messages(messages)
 
 
-def calculate_embedding_process(str_list):
+def calculate_embedding_process(str_list: list) -> list:
     return APIBackend().create_embedding(str_list)
 
 
-def create_embedding_with_multiprocessing(str_list, slice_count=50, nproc=8):
+def create_embedding_with_multiprocessing(str_list: list, slice_count: int = 50, nproc: int = 8) -> list:
     embeddings = []
     pool = multiprocessing.Pool(nproc)
-    result_list = []
-    for index in range(0, len(str_list), slice_count):
-        result_list.append(pool.apply_async(calculate_embedding_process, (str_list[index : index + slice_count],)))
-
+    result_list = [
+        pool.apply_async(calculate_embedding_process, (str_list[index : index + slice_count],))
+        for index in range(0, len(str_list), slice_count)
+    ]
     pool.close()
     pool.join()
@@ -707,13 +712,16 @@ def create_embedding_with_multiprocessing(str_list, slice_count=50, nproc=8):
     return embeddings
 
 
-def calculate_embedding_distance_between_str_list(source_str_list: List, target_str_list: List):
-    if len(source_str_list) == 0 or len(target_str_list) == 0:
+def calculate_embedding_distance_between_str_list(
+    source_str_list: list[str], target_str_list: list[str],
+) -> list[list[float]]:
+    if not source_str_list or not target_str_list:
         return [[]]
 
     embeddings = create_embedding_with_multiprocessing(source_str_list + target_str_list, slice_count=50, nproc=8)
-    source_embeddings = embeddings[: len(source_str_list)]
-    target_embeddings = embeddings[len(source_str_list) :]
+    source_embeddings = embeddings[:len(source_str_list)]
+    target_embeddings = embeddings[len(source_str_list):]
 
     source_embeddings_np = np.array(source_embeddings)
     target_embeddings_np = np.array(target_embeddings)
diff --git a/test/oai/test_completion.py b/test/oai/test_completion.py
index 8b518244..79297d0b 100644
--- a/test/oai/test_completion.py
+++ b/test/oai/test_completion.py
@@ -1,47 +1,46 @@
-import pickle
-import unittest
-from pathlib import Path
 import json
 import random
+import unittest
 
 from rdagent.oai.llm_utils import APIBackend
 
 
 class TestChatCompletion(unittest.TestCase):
-    def test_chat_completion(self):
+    def test_chat_completion(self) -> None:
         system_prompt = "You are a helpful assistant."
         user_prompt = "What is your name?"
         response = APIBackend().build_messages_and_create_chat_completion(
-            system_prompt=system_prompt, user_prompt=user_prompt
+            system_prompt=system_prompt, user_prompt=user_prompt,
         )
         assert response is not None
-        assert type(response) == str
+        assert isinstance(response, str)
 
-    def test_chat_completion_json_mode(self):
+    def test_chat_completion_json_mode(self) -> None:
         system_prompt = "You are a helpful assistant. answer in Json format."
         user_prompt = "What is your name?"
         response = APIBackend().build_messages_and_create_chat_completion(
-            system_prompt=system_prompt, user_prompt=user_prompt, json_mode=True
+            system_prompt=system_prompt, user_prompt=user_prompt, json_mode=True,
         )
         assert response is not None
-        assert type(response) == str
+        assert isinstance(response, str)
         json.loads(response)
 
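Aside (illustrative, not part of the patch): to make the batching and similarity math
above concrete: slice the inputs into fixed-size batches, embed each batch, then compare
source and target vectors. The distance function presumably computes cosine similarity
(the test below expects values near 1 for similar strings); treat that, and `toy_embed`,
as assumptions in this numpy-only sketch:

import numpy as np


def toy_embed(texts: list[str]) -> list[list[float]]:
    # Hypothetical stand-in for an embedding API call on one batch.
    return [[len(t), t.count("a"), t.count("e")] for t in texts]


def embed_in_batches(texts: list[str], slice_count: int = 50) -> list[list[float]]:
    vectors: list[list[float]] = []
    for i in range(0, len(texts), slice_count):
        vectors.extend(toy_embed(texts[i : i + slice_count]))
    return vectors


source, target = ["apple", "banana"], ["grape"]
emb = np.array(embed_in_batches(source + target), dtype=float)
src, tgt = emb[: len(source)], emb[len(source) :]
# Cosine similarity: normalize rows, then take the dot-product matrix.
src /= np.linalg.norm(src, axis=1, keepdims=True)
tgt /= np.linalg.norm(tgt, axis=1, keepdims=True)
print(src @ tgt.T)  # shape (len(source), len(target))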
-    def test_chat_multi_round(self):
+    def test_chat_multi_round(self) -> None:
         system_prompt = "You are a helpful assistant."
-        fruit_name = ["apple", "banana", "orange", "grape", "watermelon"][random.randint(0, 4)]
-        user_prompt_1 = f"I will tell you a name of fruit, please remember them and tell me later. The name is {fruit_name}. Once you remembeer it, please answer OK."
-        user_prompt_2 = f"What is the name of the fruit I told you before?"
+        fruit_name = random.SystemRandom().choice(["apple", "banana", "orange", "grape", "watermelon"])
+        user_prompt_1 = (
+            f"I will tell you the name of a fruit; please remember it and tell me later. "
+            f"The name is {fruit_name}. Once you remember it, please answer OK."
+        )
+        user_prompt_2 = "What is the name of the fruit I told you before?"
         session = APIBackend().build_chat_session(session_system_prompt=system_prompt)
         response_1 = session.build_chat_completion(user_prompt=user_prompt_1)
         assert response_1 is not None
         assert "ok" in response_1.lower()
-
         response2 = session.build_chat_completion(user_prompt=user_prompt_2)
         assert response2 is not None
-
         assert fruit_name in response2.lower()
 
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/oai/test_embedding_and_similarity.py b/test/oai/test_embedding_and_similarity.py
index 8e426d4d..a577c4a5 100644
--- a/test/oai/test_embedding_and_similarity.py
+++ b/test/oai/test_embedding_and_similarity.py
@@ -1,25 +1,21 @@
-import pickle
 import unittest
-from pathlib import Path
-import json
-import random
 
 from rdagent.oai.llm_utils import APIBackend, calculate_embedding_distance_between_str_list
 
 
 class TestEmbedding(unittest.TestCase):
-    def test_embedding(self):
+    def test_embedding(self) -> None:
         emb = APIBackend().create_embedding("hello")
         assert emb is not None
-        assert type(emb) == list
+        assert isinstance(emb, list)
         assert len(emb) > 0
 
-    def test_embedding_similarity(self):
+    def test_embedding_similarity(self) -> None:
         similarity = calculate_embedding_distance_between_str_list(["Hello"], ["Hi"])[0][0]
         assert similarity is not None
-        assert type(similarity) == float
-        assert similarity >= 0.8
-
+        assert isinstance(similarity, float)
+        min_similarity_threshold = 0.8
+        assert similarity >= min_similarity_threshold
 
 
 if __name__ == "__main__":
     unittest.main()
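Aside (illustrative, not part of the patch): two ruff rules drive the test changes above.
`type(x) == str` is flagged (E721) because it rejects subclasses where `isinstance` would
not, and a bare `0.8` in an assert is a magic-value comparison (PLR2004), fixed by naming
the constant. A minimal sketch of the preferred pattern:

import unittest

MIN_SIMILARITY_THRESHOLD = 0.8  # named constant instead of a magic number


class ExampleStyle(unittest.TestCase):
    def test_isinstance_over_type_compare(self) -> None:
        similarity = 0.93  # stand-in for a computed score
        # isinstance also accepts subclasses; type(similarity) == float would not.
        assert isinstance(similarity, float)
        assert similarity >= MIN_SIMILARITY_THRESHOLD


if __name__ == "__main__":
    unittest.main()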