Skip to content

Commit

Permalink
feat: +rfc197 example
Browse files Browse the repository at this point in the history
  • Loading branch information
莘权 马 committed Mar 6, 2024
1 parent 5cae13f commit 3fee7a5
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 47 deletions.
72 changes: 72 additions & 0 deletions examples/reverse_engineering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import asyncio
import shutil
from pathlib import Path

import typer

from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
from metagpt.context import Context
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo

# Typer application object; the CLI command below is registered on it via ``@app.command``.
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)


@app.command("", help="Python project reverse engineering.")
def startup(
    project_root: str = typer.Argument(
        default="",
        help="Specify the root directory of the existing project for reverse engineering.",
    ),
    output_dir: str = typer.Option(default="", help="Specify the output directory path for reverse engineering."),
):
    """Validate the project directory, run the reverse engineering, then clean up.

    Args:
        project_root: Directory of the existing project to analyse.
        output_dir: Where results are written; when empty, a sibling directory
            named ``reverse_engineering_output`` next to ``project_root`` is used.

    Raises:
        FileNotFoundError: If ``project_root`` does not exist, or contains no
            ``*.py`` file directly under it.
    """
    package_root = Path(project_root)
    if not package_root.exists():
        raise FileNotFoundError(f"{project_root} not exists")
    if not _is_python_package_root(package_root):
        raise FileNotFoundError(f'There are no "*.py" files under "{project_root}".')

    # pyreverse needs an ``__init__.py``; create a temporary one only if the
    # project does not already ship its own, and remember that we did so.
    marker = package_root / "__init__.py"
    marker_created = not marker.exists()
    if marker_created:
        marker.touch()

    destination = output_dir or package_root / "../reverse_engineering_output"
    logger.info(f"output dir:{destination}")
    try:
        asyncio.run(reverse_engineering(package_root, Path(destination)))
    finally:
        # Leave the analysed project as we found it, whether or not the run succeeded.
        if marker_created:
            marker.unlink(missing_ok=True)
        scratch = package_root / "__dot__"
        if scratch.exists():
            shutil.rmtree(scratch, ignore_errors=True)


def _is_python_package_root(package_root: Path) -> bool:
    """Return True when at least one regular ``*.py`` file sits directly in ``package_root``.

    Only the immediate children are inspected; sub-directories are not descended into.
    """
    return any(entry.is_file() and entry.suffix == ".py" for entry in package_root.iterdir())


async def reverse_engineering(package_root: Path, output_dir: Path):
    """Run the class-view rebuild followed by the sequence-view rebuild.

    Args:
        package_root: Project directory handed to ``RebuildClassView`` as its ``i_context``.
        output_dir: Directory used to construct the ``GitRepository`` that both actions share
            through a common ``Context``.
    """
    ctx = Context()
    ctx.git_repo = GitRepository(output_dir)
    ctx.repo = ProjectRepo(ctx.git_repo)

    # Each action gets its own LLM instance but the same context; the class
    # view is awaited to completion before the sequence view is constructed.
    await RebuildClassView(name="ReverseEngineering", i_context=str(package_root), llm=LLM(), context=ctx).run()
    await RebuildSequenceView(name="ReverseEngineering", llm=LLM(), context=ctx).run()


if __name__ == "__main__":
    # Hand control to the Typer application when the file is executed as a script.
    app()
2 changes: 1 addition & 1 deletion metagpt/actions/rebuild_class_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def _create_mermaid_class_views(self) -> str:
path = self.context.git_repo.workdir / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / self.context.git_repo.workdir.name
filename = str(pathname.with_suffix(".mmd"))
filename = str(pathname.with_suffix(".class_diagram.mmd"))
async with aiofiles.open(filename, mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
Expand Down
75 changes: 33 additions & 42 deletions metagpt/actions/rebuild_sequence_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import re
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Set

from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_random_exponential
Expand Down Expand Up @@ -125,7 +125,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
if prefix in r.subject:
classes.append(r)
await self._rebuild_use_case(r.subject)
participants = set()
participants = await self._search_participants(split_namespace(entry.subject)[0])
class_details = []
class_views = []
for c in classes:
Expand Down Expand Up @@ -171,7 +171,8 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
sequence_view = rsp.removeprefix("```mermaid").removesuffix("```")
rows = await self.graph_db.select(subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW)
for r in rows:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
if r.predicate == GraphKeyword.HAS_SEQUENCE_VIEW:
await self.graph_db.delete(subject=r.subject, predicate=r.predicate, object_=r.object_)
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
Expand All @@ -184,7 +185,7 @@ async def _rebuild_main_sequence_view(self, entry: SPO):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(c.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)

async def _merge_sequence_view(self, entry: SPO) -> bool:
"""
Expand Down Expand Up @@ -267,38 +268,6 @@ async def _rebuild_use_case(self, ns_class_name: str):
prompt_blocks.append(block)
prompt = "\n---\n".join(prompt_blocks)

# class _UseCase(BaseModel):
# description: str = Field(default="...", description="Describes about what the use case to do")
# inputs: List[str] = Field(default=["input name 1", "input name 2"],
# description="Lists the input names of the use case from external sources")
# outputs: List[str] = Field(default=["output name 1", "output name 2"],
# description="Lists the output names of the use case to external sources")
# actors: List[str] = Field(default=["actor name 1", "actor name 2"],
# description="Lists the participant actors of the use case")
# steps: List[str] = Field(default=["Step 1", "Step 2"],
# description="Lists the steps about how the use case works step by step")
# reason: str = Field(default="Because ...",
# description="Explaining under what circumstances would the external system execute this use case.")
#
#
# class _UseCaseList(BaseModel):
# description: str = Field(default="...",
# description="A summary explains what the whole source code want to do")
# use_cases: List[_UseCase] = Field(default=[
# {
# "description": "Describes about what the use case to do",
# "inputs": ["input name 1", "input name 2"],
# "outputs": ["output name 1", "output name 2"],
# "actors": ["actor name 1", "actor name 2"],
# "steps": ["Step 1", "Step 2"],
# "reason": "Because ..."
# }
# ], description="List all use cases.")
# relationship: List[str] = Field(default=["use case 1 ..."],
# description="Lists all the descriptions of relationship among these use cases")

# rsp = await ActionNode.from_pydantic(_UseCaseList).fill(context=prompt, llm=self.llm)

rsp = await self.llm.aask(
msg=prompt,
system_msgs=[
Expand Down Expand Up @@ -327,7 +296,6 @@ async def _rebuild_use_case(self, ns_class_name: str):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_CLASS_USE_CASE, object_=detail.model_dump_json()
)
await self.graph_db.save()

@retry(
wait=wait_random_exponential(min=1, max=20),
Expand All @@ -347,7 +315,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str):
use_case_markdown = await self._get_class_use_cases(ns_class_name)
if not use_case_markdown: # external class
await self.graph_db.insert(subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_="")
await self.graph_db.save()
return
block = f"## Use Cases\n{use_case_markdown}"
prompts_blocks.append(block)
Expand Down Expand Up @@ -382,7 +349,6 @@ async def _rebuild_sequence_view(self, ns_class_name: str):
await self.graph_db.insert(
subject=ns_class_name, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=sequence_view
)
await self.graph_db.save()

async def _get_participants(self, ns_class_name: str) -> List[str]:
"""
Expand Down Expand Up @@ -574,14 +540,12 @@ async def _merge_participant(self, entry: SPO, class_name: str):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=concat_namespace("?", class_name)
)
await self.graph_db.save()
return
if len(participants) > 1:
for r in participants:
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(r.subject)
)
await self.graph_db.save()
return

participant = participants[0]
Expand Down Expand Up @@ -619,4 +583,31 @@ async def _merge_participant(self, entry: SPO, class_name: str):
await self.graph_db.insert(
subject=entry.subject, predicate=GraphKeyword.HAS_PARTICIPANT, object_=auto_namespace(participant.subject)
)
await self.graph_db.save()
await self._save_sequence_view(subject=entry.subject, content=sequence_view)

async def _save_sequence_view(self, subject: str, content: str):
    """Save a sequence view under a file name derived from ``subject``.

    Args:
        subject: Identifier of the view; every character outside ``[a-zA-Z0-9]``
            is replaced with ``_`` to form the file name stem.
        content: Text written to ``<stem>.sequence_diagram.mmd`` through
            ``self.context.repo.resources.data_api_design``.
    """
    stem = re.compile(r"[^a-zA-Z0-9]").sub("_", subject)
    target = str(Path(stem).with_suffix(".sequence_diagram.mmd"))
    await self.context.repo.resources.data_api_design.save(filename=target, content=content)

async def _search_participants(self, filename: str) -> Set:
    """Ask the LLM which class names are used in a source file.

    Args:
        filename: Forwarded to ``self._get_source_code`` to obtain the text sent to the LLM.

    Returns:
        The distinct entries of the ``class_names`` list found in the first
        JSON block of the LLM reply.
    """

    class _Data(BaseModel):
        # Schema the reply is validated against; ``reasons`` must be present
        # in the reply but is not read afterwards.
        class_names: List[str]
        reasons: List

    source_text = await self._get_source_code(filename)
    system_prompts = [
        "You are a tool for listing all class names used in a source file.",
        "Return a markdown JSON object with: "
        '- a "class_names" key containing the list of class names used in the file; '
        '- a "reasons" key lists all reason objects, each object containing a "class_name" key for class name, a "reference" key explaining the line where the class has been used.',
    ]
    reply = await self.llm.aask(msg=source_text, system_msgs=system_prompts)

    # Only the first JSON block of the reply is considered.
    first_block = parse_json_code_block(reply)[0]
    return set(_Data.model_validate_json(first_block).class_names)
13 changes: 10 additions & 3 deletions metagpt/repo_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,14 +722,19 @@ async def rebuild_class_views(self, path: str | Path = None):
path = Path(path)
if not path.exists():
return
init_file = path / "__init__.py"
if not init_file.exists():
raise ValueError("Failed to import module __init__ with error:No module named __init__.")
command = f"pyreverse {str(path)} -o dot"
result = subprocess.run(command, shell=True, check=True, cwd=str(path))
output_dir = path / "__dot__"
output_dir.mkdir(parents=True, exist_ok=True)
result = subprocess.run(command, shell=True, check=True, cwd=str(output_dir))
if result.returncode != 0:
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_view_pathname = output_dir / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
packages_pathname = output_dir / "packages.dot"
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
Expand Down Expand Up @@ -975,6 +980,8 @@ def _repair_ns(package: str, mappings: Dict[str, str]) -> str:
file_ns = file_ns[0:ix]
continue
break
if file_ns == "":
return ""
internal_ns = package[ix + 1 :]
ns = mappings[file_ns] + ":" + internal_ns.replace(".", ":")
return ns
Expand Down
1 change: 0 additions & 1 deletion tests/metagpt/actions/test_rebuild_class_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from metagpt.llm import LLM


@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context):
action = RebuildClassView(
Expand Down
2 changes: 2 additions & 0 deletions tests/metagpt/actions/test_rebuild_sequence_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ async def test_rebuild(context, mocker):
context=context,
)
await action.run()
rows = await action.graph_db.select()
assert rows
assert context.repo.docs.graph_repo.changed_files


Expand Down

0 comments on commit 3fee7a5

Please sign in to comment.