Skip to content

Commit

Permalink
Merge branch 'code_interpreter' into mi_refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
garylin2099 committed Mar 12, 2024
2 parents b5af9cc + 612e4e1 commit 3244e6c
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 103 deletions.
14 changes: 12 additions & 2 deletions examples/mi/machine_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@

from metagpt.roles.mi.interpreter import Interpreter

WINE_REQ = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."

async def main(auto_run: bool = True):
requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy."
DATA_DIR = "path/to/your/data"
# sales_forecast data from https://www.kaggle.com/datasets/aslanahmedov/walmart-sales-forecast/data
SALES_FORECAST_REQ = f"""Train a model to predict sales for each department in every store (split the last 40 weeks records as validation dataset, the others is train dataset), include plot total sales trends, print metric and plot scatter plots of
groud truth and predictions on validation data. Dataset is {DATA_DIR}/train.csv, the metric is weighted mean absolute error (WMAE) for test data. Notice: *print* key variables to get more information for next task step.
"""

REQUIREMENTS = {"wine": WINE_REQ, "sales_forecast": SALES_FORECAST_REQ}


async def main(auto_run: bool = True, use_case: str = "wine"):
mi = Interpreter(auto_run=auto_run)
requirement = REQUIREMENTS[use_case]
await mi.run(requirement)


Expand Down
82 changes: 37 additions & 45 deletions metagpt/actions/mi/execute_nb_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import asyncio
import base64
import re
import traceback
from typing import Literal, Tuple

import nbformat
Expand Down Expand Up @@ -58,7 +57,8 @@ async def build(self):

async def terminate(self):
"""kill NotebookClient"""
await self.nb_client._async_cleanup_kernel()
if self.nb_client.km is not None:
await self.nb_client._async_cleanup_kernel()

async def reset(self):
"""reset NotebookClient"""
Expand Down Expand Up @@ -91,17 +91,17 @@ def add_output_to_cell(self, cell: NotebookNode, output: str):
else:
cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))

def parse_outputs(self, outputs: list[str]) -> str:
def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]:
"""Parses the outputs received from notebook execution."""
assert isinstance(outputs, list)
parsed_output = ""

parsed_output, is_success = [], True
for i, output in enumerate(outputs):
output_text = ""
if output["output_type"] == "stream" and not any(
tag in output["text"]
for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"]
):
parsed_output += output["text"]
output_text = output["text"]
elif output["output_type"] == "display_data":
if "image/png" in output["data"]:
self.show_bytes_figure(output["data"]["image/png"], self.interaction)
Expand All @@ -110,8 +110,22 @@ def parse_outputs(self, outputs: list[str]) -> str:
f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..."
)
elif output["output_type"] == "execute_result":
parsed_output += output["data"]["text/plain"]
return parsed_output
output_text = output["data"]["text/plain"]
elif output["output_type"] == "error":
output_text, is_success = "\n".join(output["traceback"]), False

# handle coroutines that are not executed asynchronously
if output_text.strip().startswith("<coroutine object"):
output_text = "Executed code failed, you need use key word 'await' to run a async code."
is_success = False

output_text = remove_escape_and_color_codes(output_text)
# The useful information of the exception is at the end,
# the useful information of normal output is at the begining.
output_text = output_text[:keep_len] if is_success else output_text[-keep_len:]

parsed_output.append(output_text)
return is_success, ",".join(parsed_output)

def show_bytes_figure(self, image_base64: str, interaction_type: Literal["ipython", None]):
image_bytes = base64.b64decode(image_base64)
Expand Down Expand Up @@ -145,7 +159,7 @@ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str
"""
try:
await self.nb_client.async_execute_cell(cell, cell_index)
return True, ""
return self.parse_outputs(self.nb.cells[-1].outputs)
except CellTimeoutError:
assert self.nb_client.km is not None
await self.nb_client.km.interrupt_kernel()
Expand All @@ -156,7 +170,7 @@ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str
await self.reset()
return False, "DeadKernelError"
except Exception:
return False, f"{traceback.format_exc()}"
return self.parse_outputs(self.nb.cells[-1].outputs)

async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]:
"""
Expand All @@ -173,16 +187,9 @@ async def run(self, code: str, language: Literal["python", "markdown"] = "python

# run code
cell_index = len(self.nb.cells) - 1
success, error_message = await self.run_cell(self.nb.cells[-1], cell_index)

if not success:
return truncate(remove_escape_and_color_codes(error_message), is_success=success)

# code success
outputs = self.parse_outputs(self.nb.cells[-1].outputs)
outputs, success = truncate(remove_escape_and_color_codes(outputs), is_success=success)
success, outputs = await self.run_cell(self.nb.cells[-1], cell_index)

if "!pip" in outputs:
if "!pip" in code:
success = False

return outputs, success
Expand All @@ -196,54 +203,39 @@ async def run(self, code: str, language: Literal["python", "markdown"] = "python
raise ValueError(f"Only support for language: python, markdown, but got {language}, ")


def truncate(result: str, keep_len: int = 2000, is_success: bool = True):
"""对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。"""
if is_success:
desc = f"Executed code successfully. Truncated to show only first {keep_len} characters\n"
else:
desc = f"Executed code failed, please reflect on the cause of bug and then debug. Truncated to show only last {keep_len} characters\n"

if result.strip().startswith("<coroutine object"):
result = "Executed code failed, you need use key word 'await' to run a async code."
return result, False

if len(result) > keep_len:
result = result[-keep_len:] if not is_success else result[:keep_len]
return desc + result, is_success

return result, is_success


def remove_escape_and_color_codes(input_str: str):
# 使用正则表达式去除转义字符和颜色代码
# 使用正则表达式去除jupyter notebook输出结果中的转义字符和颜色代码
# Use regular expressions to get rid of escape characters and color codes in jupyter notebook output.
pattern = re.compile(r"\x1b\[[0-9;]*[mK]")
result = pattern.sub("", input_str)
return result


def display_markdown(content: str):
# 使用正则表达式逐个匹配代码块
# Use regular expressions to match blocks of code one by one.
matches = re.finditer(r"```(.+?)```", content, re.DOTALL)
start_index = 0
content_panels = []
# 逐个打印匹配到的文本和代码
# Set the text background color and text color.
style = "black on white"
# Print the matching text and code one by one.
for match in matches:
text_content = content[start_index : match.start()].strip()
code_content = match.group(0).strip()[3:-3] # Remove triple backticks

if text_content:
content_panels.append(Panel(Markdown(text_content), box=MINIMAL))
content_panels.append(Panel(Markdown(text_content), style=style, box=MINIMAL))

if code_content:
content_panels.append(Panel(Markdown(f"```{code_content}"), box=MINIMAL))
content_panels.append(Panel(Markdown(f"```{code_content}"), style=style, box=MINIMAL))
start_index = match.end()

# 打印剩余文本(如果有)
# Print remaining text (if any).
remaining_text = content[start_index:].strip()
if remaining_text:
content_panels.append(Panel(Markdown(remaining_text), box=MINIMAL))
content_panels.append(Panel(Markdown(remaining_text), style=style, box=MINIMAL))

# 在Live模式中显示所有Panel
# Display all panels in Live mode.
with Live(auto_refresh=False, console=Console(), vertical_overflow="visible") as live:
live.update(Group(*content_panels))
live.refresh()
4 changes: 4 additions & 0 deletions metagpt/roles/mi/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ async def _act(self) -> Message:
code, _, _ = await self._write_and_exec_code()
return Message(content=code, role="assistant", cause_by=WriteCodeWithTools)

async def _plan_and_act(self) -> Message:
await super()._plan_and_act()
await self.execute_code.terminate()

async def _act_on_task(self, current_task: Task) -> TaskResult:
"""Useful in 'plan_and_act' mode. Wrap the output in a TaskResult for review and confirmation."""
code, result, is_success = await self._write_and_exec_code()
Expand Down
59 changes: 22 additions & 37 deletions metagpt/tools/libs/gpt_v_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
@Author : mannaandpoem
@File : gpt_v_generator.py
"""
import os
import re
from pathlib import Path

from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import encode_image
from metagpt.utils.common import CodeParser, encode_image

ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX designer, please generate layout information for this image:
Expand All @@ -29,7 +30,7 @@

@register_tool(include_functions=["__init__", "generate_webpages", "save_webpages"])
class GPTvGenerator:
"""Class for generating webpages at once.
"""Class for generating webpage code from a given webpage screenshot.
This class provides methods to generate webpages including all code (HTML, CSS, and JavaScript) based on an image.
It utilizes a vision model to analyze the layout from an image and generate webpage codes accordingly.
Expand Down Expand Up @@ -72,50 +73,34 @@ async def generate_webpages(self, image_path: str) -> str:
return await self.llm.aask(msg=prompt, images=[encode_image(image_path)])

@staticmethod
def save_webpages(image_path: str, webpages: str) -> Path:
def save_webpages(webpages: str, save_folder_name: str = "example") -> Path:
"""Save webpages including all code (HTML, CSS, and JavaScript) at once.
Args:
image_path (str): The path of the image file.
webpages (str): The generated webpages content.
save_folder_name (str, optional): The name of the folder to save the webpages. Defaults to 'example'.
Returns:
Path: The path of the saved webpages.
"""
# Create a folder called webpages in the workspace directory to store HTML, CSS, and JavaScript files
webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / Path(image_path).stem
os.makedirs(webpages_path, exist_ok=True)
webpages_path = DEFAULT_WORKSPACE_ROOT / "webpages" / save_folder_name
logger.info(f"code will be saved at {webpages_path}")
webpages_path.mkdir(parents=True, exist_ok=True)

index_path = webpages_path / "index.html"
try:
index = webpages.split("```html")[1].split("```")[0]
style_path = None
if "styles.css" in index:
style_path = webpages_path / "styles.css"
elif "style.css" in index:
style_path = webpages_path / "style.css"
style = webpages.split("```css")[1].split("```")[0] if style_path else ""

js_path = None
if "scripts.js" in index:
js_path = webpages_path / "scripts.js"
elif "script.js" in index:
js_path = webpages_path / "script.js"

js = webpages.split("```javascript")[1].split("```")[0] if js_path else ""
except IndexError:
raise ValueError(f"No html or css or js code found in the result. \nWebpages: {webpages}")

try:
with open(index_path, "w", encoding="utf-8") as f:
f.write(index)
if style_path:
with open(style_path, "w", encoding="utf-8") as f:
f.write(style)
if js_path:
with open(js_path, "w", encoding="utf-8") as f:
f.write(js)
except FileNotFoundError as e:
raise FileNotFoundError(f"Cannot save the webpages to {str(webpages_path)}") from e
index_path.write_text(CodeParser.parse_code(block=None, text=webpages, lang="html"))

extract_and_save_code(folder=webpages_path, text=webpages, pattern="styles?.css", language="css")

extract_and_save_code(folder=webpages_path, text=webpages, pattern="scripts?.js", language="javascript")

return webpages_path


def extract_and_save_code(folder, text, pattern, language):
word = re.search(pattern, text)
if word:
path = folder / word.group(0)
code = CodeParser.parse_code(block=None, text=text, lang=language)
path.write_text(code, encoding="utf-8")
46 changes: 29 additions & 17 deletions tests/metagpt/actions/mi/test_execute_nb_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from metagpt.actions.mi.execute_nb_code import ExecuteNbCode, truncate
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode


@pytest.mark.asyncio
Expand Down Expand Up @@ -54,21 +54,6 @@ async def test_plotting_code():
assert is_success


def test_truncate():
# 代码执行成功
output, is_success = truncate("hello world", 5, True)
assert "Truncated to show only first 5 characters\nhello" in output
assert is_success
# 代码执行失败
output, is_success = truncate("hello world", 5, False)
assert "Truncated to show only last 5 characters\nworld" in output
assert not is_success
# 异步
output, is_success = truncate("<coroutine object", 5, True)
assert not is_success
assert "await" in output


@pytest.mark.asyncio
async def test_run_with_timeout():
executor = ExecuteNbCode(timeout=1)
Expand All @@ -83,7 +68,7 @@ async def test_run_code_text():
executor = ExecuteNbCode()
message, success = await executor.run(code='print("This is a code!")', language="python")
assert success
assert message == "This is a code!\n"
assert "This is a code!" in message
message, success = await executor.run(code="# This is a code!", language="markdown")
assert success
assert message == "# This is a code!"
Expand All @@ -100,10 +85,20 @@ async def test_terminate():
is_kernel_alive = await executor.nb_client.km.is_alive()
assert is_kernel_alive
await executor.terminate()

import time

time.sleep(2)
assert executor.nb_client.km is None
for _ in range(200):
executor = ExecuteNbCode()
await executor.run(code='print("This is a code!")', language="python")
is_kernel_alive = await executor.nb_client.km.is_alive()
assert is_kernel_alive
await executor.terminate()
assert executor.nb_client.km is None
assert executor.nb_client.kc is None
await executor.terminate()


@pytest.mark.asyncio
Expand All @@ -114,3 +109,20 @@ async def test_reset():
assert is_kernel_alive
await executor.reset()
assert executor.nb_client.km is None


@pytest.mark.asyncio
async def test_parse_outputs():
executor = ExecuteNbCode()
code = """
import pandas as pd
df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
print(df.columns)
print(f"columns num:{len(df.columns)}")
print(df['DUMMPY_ID'])
"""
output, is_success = await executor.run(code)
assert not is_success
assert "Index(['ID', 'NAME'], dtype='object')" in output
assert "KeyError: 'DUMMPY_ID'" in output
assert "columns num:2" in output
Loading

0 comments on commit 3244e6c

Please sign in to comment.