diff --git a/examples/reverse_engineering.py b/examples/reverse_engineering.py new file mode 100644 index 0000000000..f80fc09e66 --- /dev/null +++ b/examples/reverse_engineering.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import asyncio +import shutil +from pathlib import Path + +import typer + +from metagpt.actions.rebuild_class_view import RebuildClassView +from metagpt.actions.rebuild_sequence_view import RebuildSequenceView +from metagpt.context import Context +from metagpt.llm import LLM +from metagpt.logs import logger +from metagpt.utils.git_repository import GitRepository +from metagpt.utils.project_repo import ProjectRepo + +app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) + + +@app.command("", help="Python project reverse engineering.") +def startup( + project_root: str = typer.Argument( + default="", + help="Specify the root directory of the existing project for reverse engineering.", + ), + output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."), +): + package_root = Path(project_root) + if not package_root.exists(): + raise FileNotFoundError(f"{project_root} not exists") + if not _is_python_package_root(package_root): + raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".') + init_file = package_root / "__init__.py" # used by pyreverse + init_file_exists = init_file.exists() + if not init_file_exists: + init_file.touch() + + if not output_dir: + output_dir = package_root / "../reverse_engineering_output" + logger.info(f"output dir:{output_dir}") + try: + asyncio.run(reverse_engineering(package_root, Path(output_dir))) + finally: + if not init_file_exists: + init_file.unlink(missing_ok=True) + tmp_dir = package_root / "__dot__" + if tmp_dir.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + +def _is_python_package_root(package_root: Path) -> bool: + for file_path in package_root.iterdir(): + if file_path.is_file(): + if file_path.suffix == ".py": + return True + return False + + +async def reverse_engineering(package_root: Path, output_dir: Path): + ctx = Context() + ctx.git_repo = GitRepository(output_dir) + ctx.repo = ProjectRepo(ctx.git_repo) + action = RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx) + await action.run() + + action = RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx) + await action.run() + + +if __name__ == "__main__": + app() diff --git a/metagpt/actions/rebuild_class_view.py b/metagpt/actions/rebuild_class_view.py index 6dd5690b66..ff030ec878 100644 --- a/metagpt/actions/rebuild_class_view.py +++ b/metagpt/actions/rebuild_class_view.py @@ -76,7 +76,7 @@ async def _create_mermaid_class_views(self) -> str: path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO path.mkdir(parents=True, exist_ok=True) pathname = path / self.context.git_repo.workdir.name - filename = str(pathname.with_suffix(".mmd")) + filename = str(pathname.with_suffix(".class_diagram.mmd")) async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer: content = "classDiagram\n" logger.debug(content) diff --git a/metagpt/actions/rebuild_sequence_view.py b/metagpt/actions/rebuild_sequence_view.py index 227d298720..0e67de9086 100644 --- a/metagpt/actions/rebuild_sequence_view.py +++ b/metagpt/actions/rebuild_sequence_view.py @@ -12,7 +12,7 @@ import re from datetime import datetime from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Set from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_random_exponential @@ -125,7 +125,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO): if prefix in r.subject: classes.append(r) await self._rebuild_use_case(r.subject) - participants = set() + participants = await self._search_participants(split_namespace(entry.subject)[0]) class_details = [] class_views = [] for c in classes: @@ -171,7 +171,8 @@ async def _rebuild_main_sequence_view(self, entry: SPO): sequence_view = rsp.removeprefix("```mermaid").removesuffix("```") rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW) for r in rows: - await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_) + if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW: + await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_) await self.graph_db.insert( subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view ) @@ -184,7 +185,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO): await self.graph_db.insert( subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject) ) - await self.graph_db.save() + await self._save_sequence_view(subject=entry.subject, content=sequence_view) async def _merge_sequence_view(self, entry: SPO) -> bool: """ @@ -267,38 +268,6 @@ async def _rebuild_use_case(self, ns_class_name: str): prompt_blocks.append(block) prompt = "\n---\n".join(prompt_blocks) - # class _UseCase(BaseModel): - # description: str = Field(default="...", description="Describes about what the use case to do") - # inputs: List[str] = Field(default=["input name 1", "input name 2"], - # description="Lists the input names of the use case from external sources") - # outputs: List[str] = Field(default=["output name 1", "output name 2"], - # description="Lists the output names of the use case to external sources") - # actors: List[str] = Field(default=["actor name 1", "actor name 2"], - # description="Lists the participant actors of the use case") - # steps: List[str] = Field(default=["Step 1", "Step 2"], - # description="Lists the steps about how the use case works step by step") - # reason: str = Field(default="Because ...", - # description="Explaining under what circumstances would the external system execute this use case.") - # - # - # class _UseCaseList(BaseModel): - # description: str = Field(default="...", - # description="A summary explains what the whole source code want to do") - # use_cases: List[_UseCase] = Field(default=[ - # { - # "description": "Describes about what the use case to do", - # "inputs": ["input name 1", "input name 2"], - # "outputs": ["output name 1", "output name 2"], - # "actors": ["actor name 1", "actor name 2"], - # "steps": ["Step 1", "Step 2"], - # "reason": "Because ..." - # } - # ], description="List all use cases.") - # relationship: List[str] = Field(default=["use case 1 ..."], - # description="Lists all the descriptions of relationship among these use cases") - - # rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm) - rsp = await self.llm.aask( msg=prompt, system_msgs=[ @@ -327,7 +296,6 @@ async def _rebuild_use_case(self, ns_class_name: str): await self.graph_db.insert( subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json() ) - await self.graph_db.save() @retry( wait=wait_random_exponential(min=1, max=20), @@ -347,7 +315,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str): use_case_markdown = await self._get_class_use_cases(ns_class_name) if not use_case_markdown: # external class await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="") - await self.graph_db.save() return block = f"## Use Cases\n{use_case_markdown}" prompts_blocks.append(block) @@ -382,7 +349,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str): await self.graph_db.insert( subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view ) - await self.graph_db.save() async def _get_participants(self, ns_class_name: str) -> List[str]: """ @@ -574,14 +540,12 @@ async def _merge_participant(self, entry: SPO, class_name: str): await self.graph_db.insert( subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name) ) - await self.graph_db.save() return if len(participants) > 1: for r in participants: await self.graph_db.insert( subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject) ) - await self.graph_db.save() return participant = participants[0] @@ -619,4 +583,31 @@ async def _merge_participant(self, entry: SPO, class_name: str): await self.graph_db.insert( subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject) ) - await self.graph_db.save() + await self._save_sequence_view(subject=entry.subject, content=sequence_view) + + async def _save_sequence_view(self, subject: str, content: str): + pattern = re.compile(r"[^a-zA-Z0-9]") + name = re.sub(pattern, "_", subject) + filename = Path(name).with_suffix(".sequence_diagram.mmd") + await self.context.repo.resources.data_api_design.save(filename=str(filename), content=content) + + async def _search_participants(self, filename: str) -> Set: + content = await self._get_source_code(filename) + + rsp = await self.llm.aask( + msg=content, + system_msgs=[ + "You are a tool for listing all class names used in a source file.", + "Return a markdown JSON object with: " + '- a "class_names" key containing the list of class names used in the file; ' + '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.', + ], + ) + + class _Data(BaseModel): + class_names: List[str] + reasons: List + + json_blocks = parse_json_code_block(rsp) + data = _Data.model_validate_json(json_blocks[0]) + return set(data.class_names) diff --git a/metagpt/repo_parser.py b/metagpt/repo_parser.py index 15842fdfb6..bc3bae6624 100644 --- a/metagpt/repo_parser.py +++ b/metagpt/repo_parser.py @@ -722,14 +722,19 @@ async def rebuild_class_views(self, path: str | Path = None): path = Path(path) if not path.exists(): return + init_file = path / "__init__.py" + if not init_file.exists(): + raise ValueError("Failed to import module __init__ with error:No module named __init__.") command = f"pyreverse {str(path)} -o dot" - result = subprocess.run(command, shell=True, check=True, cwd=str(path)) + output_dir = path / "__dot__" + output_dir.mkdir(parents=True, exist_ok=True) + result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir)) if result.returncode != 0: raise ValueError(f"{result}") - class_view_pathname = path / "classes.dot" + class_view_pathname = output_dir / "classes.dot" class_views = await self._parse_classes(class_view_pathname) relationship_views = await self._parse_class_relationships(class_view_pathname) - packages_pathname = path / "packages.dot" + packages_pathname = output_dir / "packages.dot" class_views, relationship_views, package_root = RepoParser._repair_namespaces( class_views=class_views, relationship_views=relationship_views, path=path ) @@ -975,6 +980,8 @@ def _repair_ns(package: str, mappings: Dict[str, str]) -> str: file_ns = file_ns[0:ix] continue break + if file_ns == "": + return "" internal_ns = package[ix + 1 :] ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":") return ns diff --git a/tests/metagpt/actions/test_rebuild_class_view.py b/tests/metagpt/actions/test_rebuild_class_view.py index 4414c20009..3731cd5981 100644 --- a/tests/metagpt/actions/test_rebuild_class_view.py +++ b/tests/metagpt/actions/test_rebuild_class_view.py @@ -14,7 +14,6 @@ from metagpt.llm import LLM -@pytest.mark.skip @pytest.mark.asyncio async def test_rebuild(context): action = RebuildClassView( diff --git a/tests/metagpt/actions/test_rebuild_sequence_view.py b/tests/metagpt/actions/test_rebuild_sequence_view.py index 1daea22a4b..0e10e37762 100644 --- a/tests/metagpt/actions/test_rebuild_sequence_view.py +++ b/tests/metagpt/actions/test_rebuild_sequence_view.py @@ -47,6 +47,8 @@ async def test_rebuild(context, mocker): context=context, ) await action.run() + rows = await action.graph_db.select() + assert rows assert context.repo.docs.graph_repo.changed_files