Skip to content

Commit

Permalink
Ci fix (#22)
Browse files Browse the repository at this point in the history
* fix replace function of CI tool

* fix ruff errors (ignore some parts)

* add ruff rule ignore comment
  • Loading branch information
XianBW authored Jun 12, 2024
1 parent 0e83ecd commit 59123bf
Show file tree
Hide file tree
Showing 20 changed files with 449 additions and 435 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ ignore = [
"PGH",
"PLR0913",
"S101",
"S301",
"T20",
"TCH003",
"TD",
Expand Down
113 changes: 75 additions & 38 deletions rdagent/app/CI/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import json
import re
import shlex
import subprocess
import time
from collections import defaultdict
Expand Down Expand Up @@ -122,20 +123,23 @@ def load(self) -> None:
self.lineno_width = len(str(self.lineno))
self.code_lines_with_lineno = self.add_line_number(self.code_lines)

def get(self, start: int = 1, end: int | None = None, add_line_number: bool = False, return_list: bool = False) -> list[str] | str:
def get(
self, start: int = 1, end: int | None = None, *,
add_line_number: bool = False, return_list: bool = False,
) -> list[str] | str:
"""
Retrieves a portion of the code lines.
Line numbers start from 1; returns the code in the inclusive range [start, end].
Args:
start (int): The starting line number (inclusive). Defaults to 1.
end (int): The ending line number (inclusive). Defaults to None, which means the last line.
end (int | None): The ending line number (inclusive). Defaults to None, which means the last line.
add_line_number (bool): Whether to include line numbers in the result. Defaults to False.
return_list (bool): Whether to return the result as a list of lines
or as a single string. Defaults to False.
Returns:
Union[List[str], str]: The code lines as a list of strings or as a
list[str] | str: The code lines as a list of strings or as a
single string, depending on the value of `return_list`.
"""
start -= 1
Expand All @@ -161,11 +165,12 @@ def apply_changes(self, changes: list[tuple[int, int, str]]) -> None:
"""
offset = 0
for start, end, code in changes:
# convert the 1-based start line to a 0-based list index
adjusted_start = max(start - 1, 0)

new_code = code.split("\n")
self.code_lines[adjusted_start+offset:end+offset] = new_code
offset += len(new_code) - (end - start)
self.code_lines[adjusted_start + offset : end + offset] = new_code
offset += len(new_code) - (end - adjusted_start)

self.path.write_text("\n".join(self.code_lines), encoding="utf-8")
self.load()
Expand Down Expand Up @@ -207,15 +212,17 @@ def __str__(self) -> str:


class Repo(EvolvableSubjects):
def __init__(self, project_path: Path | str, excludes: list[Path] = [], **kwargs: Any) -> None:
def __init__(self, project_path: Path | str, excludes: list[Path] | None = None, **kwargs: Any) -> None:
if excludes is None:
excludes = []
self.params = kwargs
self.project_path = Path(project_path)

excludes = [self.project_path / path for path in excludes]

git_ignored_output = subprocess.check_output(
["git", "status", "--ignored", "-s"],
cwd=project_path,
["/usr/bin/git", "status", "--ignored", "-s"], # noqa: S603
cwd=str(self.project_path),
stderr=subprocess.STDOUT,
text=True,
)
Expand Down Expand Up @@ -283,7 +290,7 @@ def explain_rule(error_code: str) -> RuffRule:
explain_command = f"ruff rule {error_code} --output-format json"
try:
out = subprocess.check_output(
explain_command,
shlex.split(explain_command), # noqa: S603
stderr=subprocess.STDOUT,
text=True,
)
Expand All @@ -293,11 +300,11 @@ def explain_rule(error_code: str) -> RuffRule:
return RuffRule(**json.loads(out))


def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback:
def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: # noqa: ARG002
"""Simply run ruff to get the feedback."""
try:
out = subprocess.check_output(
self.command.split(),
shlex.split(self.command), # noqa: S603
cwd=evo.project_path,
stderr=subprocess.STDOUT,
text=True,
Expand Down Expand Up @@ -349,10 +356,10 @@ def __init__(self, command: str | None = None) -> None:
else:
self.command = command

def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback:
def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback: # noqa: ARG002
try:
out = subprocess.check_output(
self.command.split(),
shlex.split(self.command), # noqa: S603
cwd=evo.project_path,
stderr=subprocess.STDOUT,
text=True,
Expand All @@ -368,9 +375,11 @@ def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback:
for match in re.findall(pattern, out, re.DOTALL):
raw_str, file_path, line_number, column_number, error_message, error_code, error_hint = match
error_message = error_message.strip().replace("\n", " ")
if re.match(r".*[^\n]*?:\d+:\d+: note:.*", error_hint, re.DOTALL) is not None:
error_hint_position = re.split(r"[^\n]*?:\d+:\d+: note:", error_hint, re.DOTALL)[0]
error_hint_help = re.findall(r"^.*?:\d+:\d+: note: (.*)$", error_hint, re.MULTILINE)
if re.match(r".*[^\n]*?:\d+:\d+: note:.*", error_hint, flags=re.DOTALL) is not None:
error_hint_position = re.split(
pattern=r"[^\n]*?:\d+:\d+: note:", string=error_hint, maxsplit=1, flags=re.DOTALL,
)[0]
error_hint_help = re.findall(r"^.*?:\d+:\d+: note: (.*)$", error_hint, flags=re.MULTILINE)
error_hint_help = "\n".join(error_hint_help)
error_hint = f"{error_hint_position}\nHelp:\n{error_hint_help}"

Expand Down Expand Up @@ -410,12 +419,12 @@ def evaluate(self, evo: Repo, **kwargs: Any) -> CIFeedback:
return CIFeedback(errors=all_errors)

class CIEvoStr(EvolvingStrategy):
def evolve(
def evolve( # noqa: C901, PLR0912, PLR0915
self,
evo: Repo,
evolving_trace: list[EvoStep] | None = None,
knowledge_l: list[Knowledge] | None = None,
**kwargs: Any,
knowledge_l: list[Knowledge] | None = None, # noqa: ARG002
**kwargs: Any, # noqa: ARG002
) -> Repo:

@dataclass
Expand All @@ -433,11 +442,22 @@ class CodeFixGroup:
last_feedback: CIFeedback = evolving_trace[-1].feedback

# print statistics
checker_error_counts = {checker: sum(c_statistics.values()) for checker, c_statistics in last_feedback.statistics().items()}
print(f"Found [red]{sum(checker_error_counts.values())}[/red] errors, including: " +
", ".join(f"[red]{count}[/red] [magenta]{checker}[/magenta] errors" for checker, count in checker_error_counts.items()))
checker_error_counts = {
checker: sum(c_statistics.values())
for checker, c_statistics in last_feedback.statistics().items()
}
print(
f"Found [red]{sum(checker_error_counts.values())}[/red] errors, "
"including: " +
", ".join(
f"[red]{count}[/red] [magenta]{checker}[/magenta] errors"
for checker, count in checker_error_counts.items()
),
)

fix_records: dict[str, FixRecord] = defaultdict(lambda: FixRecord([], [], [], defaultdict(list)))
fix_records: dict[str, FixRecord] = defaultdict(
lambda: FixRecord([], [], [], defaultdict(list)),
)

# Group errors by code blocks
fix_groups: dict[str, list[CodeFixGroup]] = defaultdict(list)
Expand All @@ -450,7 +470,7 @@ class CodeFixGroup:
# TODO @bowen: the current way of handling errors like 'Add import statement' may not be good
for error in errors:
if error.code in ("FA100", "FA102"):
changes[file_path].append((0, 0, "from __future__ import annotations\n"))
changes[file_path].append((1, 1, "from __future__ import annotations\n"))
break

# Group errors by code blocks
Expand Down Expand Up @@ -485,10 +505,12 @@ class CodeFixGroup:
for file_path in fix_groups:
file = evo.files[evo.project_path / Path(file_path)]
for code_fix_g in fix_groups[file_path]:
start_line, end_line, group_errors = code_fix_g.start_line, code_fix_g.end_line, code_fix_g.errors
start_line = code_fix_g.start_line
end_line = code_fix_g.end_line
group_errors = code_fix_g.errors
code_snippet_with_lineno = file.get(
start_line, end_line, add_line_number=True, return_list=False,
)
start_line, end_line, add_line_number=True, return_list=False,
)
errors_str = "\n\n".join(str(e) for e in group_errors)

# ask LLM to repair current code snippet
Expand All @@ -502,14 +524,18 @@ class CodeFixGroup:

session = api.build_chat_session(conversation_id=code_fix_g.session_id)
res = session.build_chat_completion(user_prompt)

code_fix_g.responses.append(res)
progress.update(task_id, description=f"[green]Fixing[/green] [cyan]{file_path}[/cyan]...", advance=1)
progress.update(
task_id,
description=f"[green]Fixing[/green] [cyan]{file_path}[/cyan]...",
advance=1,
)


# Manual inspection and repair
for file_path in last_feedback.errors:
print(Rule(f"[bright_blue]Checking[/bright_blue] [cyan]{file_path}[/cyan]", style="bright_blue", align="left", characters="."))
print(Rule(f"[bright_blue]Checking[/bright_blue] [cyan]{file_path}[/cyan]",
style="bright_blue", align="left", characters="."))

file = evo.files[evo.project_path / Path(file_path)]

Expand All @@ -529,7 +555,8 @@ class CodeFixGroup:

# print errors
printed_errors_str = "\n".join(
[f"[{error.checker}] {error.line: >{file.lineno_width}}:{error.column: <4} {error.code} {error.msg}" for error in group_errors],
[f"[{error.checker}] {error.line: >{file.lineno_width}}:{error.column: <4}"
f" {error.code} {error.msg}" for error in group_errors],
)
print(
Panel.fit(
Expand All @@ -554,13 +581,16 @@ class CodeFixGroup:
while True:
try:
new_code = re.search(r".*```[Pp]ython\n(.*?)\n```.*", res, re.DOTALL).group(1)
except Exception:
print(f"[red]Error when extract codes[/red]:\n {res}")
except (re.error, AttributeError) as exc:
print(f"[red]Error when extract codes[/red]:\n {res}\nException: {exc}")
try:
fixed_errors_info = re.search(r".*```[Jj]son\n(.*?)\n```.*", res, re.DOTALL).group(1)
fixed_errors_info = json.loads(fixed_errors_info)
except Exception:
except AttributeError:
fixed_errors_info = None
except (json.JSONDecodeError, re.error) as exc:
fixed_errors_info = None
print(f"[red]Error when extracting fixed_errors[/red]: {exc}")

new_code = CodeFile.remove_line_number(new_code)

Expand Down Expand Up @@ -600,7 +630,8 @@ class CodeFixGroup:
print(Panel.fit(table, title="Repair Status"))

operation = Prompt.ask("Input your operation [ [red]([bold]s[/bold])kip[/red] / "
"[green]([bold]a[/bold])pply[/green] / [yellow]manual instruction[/yellow] ]")
"[green]([bold]a[/bold])pply[/green] / "
"[yellow]manual instruction[/yellow] ]")
print()
if operation in ("s", "skip"):
fix_records[file_path].skipped_errors.extend(group_errors)
Expand All @@ -620,7 +651,9 @@ class CodeFixGroup:
break

fix_records[file_path].manual_instructions[operation].extend(group_errors)
res = session.build_chat_completion(CI_prompts["session_manual_template"].format(operation=operation))
res = session.build_chat_completion(
CI_prompts["session_manual_template"].format(operation=operation),
)
code_fix_g.responses.append(res)

# apply changes
Expand All @@ -636,14 +669,18 @@ class CodeFixGroup:
DIR = Prompt.ask("Please input the [cyan]project directory[/cyan]")
DIR = Path(DIR)

excludes = Prompt.ask("Input the [dark_orange]excluded directories[/dark_orange] (relative to [cyan]project path[/cyan] and separated by whitespace)").split(" ")
excludes = Prompt.ask(
"Input the [dark_orange]excluded directories[/dark_orange] (relative to "
"[cyan]project path[/cyan] and separated by whitespace)",
).split(" ")
excludes = [Path(exclude.strip()) for exclude in excludes if exclude.strip() != ""]

start_time = time.time()
start_timestamp = datetime.datetime.now(datetime.timezone.utc).strftime("%m%d%H%M")

repo = Repo(DIR, excludes=excludes)
evaluator = MultiEvaluator(MypyEvaluator(), RuffEvaluator())
# evaluator = MultiEvaluator(MypyEvaluator(), RuffEvaluator())
evaluator = RuffEvaluator()
estr = CIEvoStr()
rag = None # RAG is not enabled at first.
ea = EvoAgent(estr, rag=rag)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
# %%
import json
from pathlib import Path

from dotenv import load_dotenv
from rdagent.document_process.document_analysis import (
check_factor_viability,
classify_report_from_dict,
deduplicate_factors_by_llm,
extract_factors_from_report_dict,
merge_file_to_factor_dict_to_factor_dict,
)
from rdagent.document_process.document_reader import load_and_process_pdfs_by_langchain
from rdagent.document_process.document_analysis import classify_report_from_dict
from dotenv import load_dotenv


def extract_factors_and_implement(report_file_path: str):
def extract_factors_and_implement(report_file_path: str) -> None:
assert load_dotenv()
docs_dict = load_and_process_pdfs_by_langchain(Path(report_file_path))

Expand Down
19 changes: 10 additions & 9 deletions rdagent/core/conf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# TODO: use pydantic for other modules in Qlib
# from pydantic_settings import BaseSettings
import os
from typing import Union
from __future__ import annotations

from pathlib import Path

from dotenv import load_dotenv
from pydantic_settings import BaseSettings

# TODO: use pydantic for other modules in Qlib
# from pydantic_settings import BaseSettings

# make sure that env variable is loaded while calling Config()
load_dotenv(verbose=True, override=True)

from pydantic_settings import BaseSettings


class FincoSettings(BaseSettings):
use_azure: bool = True
Expand All @@ -22,8 +23,8 @@ class FincoSettings(BaseSettings):
dump_embedding_cache: bool = False
use_embedding_cache: bool = False
workspace: str = "./finco_workspace"
prompt_cache_path: str = os.getcwd() + "/prompt_cache.db"
session_cache_folder_location: str = os.getcwd() + "/session_cache_folder/"
prompt_cache_path: str = str(Path.cwd() / "prompt_cache.db")
session_cache_folder_location: str = str(Path.cwd() / "session_cache_folder/")
max_past_message_include: int = 10

use_vector_only: bool = False
Expand All @@ -37,7 +38,7 @@ class FincoSettings(BaseSettings):
chat_max_tokens: int = 3000
chat_temperature: float = 0.5
chat_stream: bool = True
chat_seed: Union[int, None] = None
chat_seed: int | None = None
chat_frequency_penalty: float = 0.0
chat_presence_penalty: float = 0.0

Expand Down
14 changes: 1 addition & 13 deletions rdagent/core/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,29 +78,17 @@ class FinCoLog:
def __init__(self) -> None:
self.logger: Logger = logger

def info(self, *args: Sequence, plain: bool = False, title: str = "Info") -> None:
def info(self, *args: Sequence, plain: bool = False) -> None:
if plain:
return self.plain_info(*args)
for arg in args:
# Changes to accommodate ruff checks.
# Original code:
# self.logger.info(f"{LogColors.WHITE}{arg}{LogColors.END}")
# Description of the problem:
# G004 Logging statement uses f-string
# References:
# https://docs.astral.sh/ruff/rules/logging-f-string/
info = f"{LogColors.WHITE}{arg}{LogColors.END}"
self.logger.info(info)
return None

def __getstate__(self) -> dict:
return {}

# Changes to accommodate ruff checks.
# Original code: def __setstate__(self, _: str) -> None:
# Description of the problem:
# PLE0302 The special method `__setstate__` expects 2 parameters, 1 was given
# References: https://docs.astral.sh/ruff/rules/unexpected-special-method-signature/
def __setstate__(self, _: str) -> None:
self.logger = logger

Expand Down
Loading

0 comments on commit 59123bf

Please sign in to comment.