Skip to content

Commit

Permalink
Merge pull request geekan#980 from orange-crow/refine_parse_outputs_i…
Browse files Browse the repository at this point in the history
…n_ExecuteNbCode

refine parse_outputs in ExecuteNbCode.
  • Loading branch information
garylin2099 authored Mar 11, 2024
2 parents 543f519 + 9808511 commit 882e941
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 52 deletions.
58 changes: 23 additions & 35 deletions metagpt/actions/mi/execute_nb_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import asyncio
import base64
import re
import traceback
from typing import Literal, Tuple

import nbformat
Expand Down Expand Up @@ -92,17 +91,17 @@ def add_output_to_cell(self, cell: NotebookNode, output: str):
else:
cell["outputs"].append(new_output(output_type="stream", name="stdout", text=str(output)))

def parse_outputs(self, outputs: list[str]) -> str:
def parse_outputs(self, outputs: list[str], keep_len: int = 2000) -> Tuple[bool, str]:
"""Parses the outputs received from notebook execution."""
assert isinstance(outputs, list)
parsed_output = ""

parsed_output, is_success = [], True
for i, output in enumerate(outputs):
output_text = ""
if output["output_type"] == "stream" and not any(
tag in output["text"]
for tag in ["| INFO | metagpt", "| ERROR | metagpt", "| WARNING | metagpt", "DEBUG"]
):
parsed_output += output["text"]
output_text = output["text"]
elif output["output_type"] == "display_data":
if "image/png" in output["data"]:
self.show_bytes_figure(output["data"]["image/png"], self.interaction)
Expand All @@ -111,8 +110,22 @@ def parse_outputs(self, outputs: list[str]) -> str:
f"{i}th output['data'] from nbclient outputs dont have image/png, continue next output ..."
)
elif output["output_type"] == "execute_result":
parsed_output += output["data"]["text/plain"]
return parsed_output
output_text = output["data"]["text/plain"]
elif output["output_type"] == "error":
output_text, is_success = "\n".join(output["traceback"]), False

# handle coroutines that are not executed asynchronously
if output_text.strip().startswith("<coroutine object"):
output_text = "Executed code failed, you need use key word 'await' to run a async code."
is_success = False

output_text = remove_escape_and_color_codes(output_text)
# The useful information of the exception is at the end,
# the useful information of normal output is at the begining.
output_text = output_text[:keep_len] if is_success else output_text[-keep_len:]

parsed_output.append(output_text)
return is_success, ",".join(parsed_output)

def show_bytes_figure(self, image_base64: str, interaction_type: Literal["ipython", None]):
image_bytes = base64.b64decode(image_base64)
Expand Down Expand Up @@ -146,7 +159,7 @@ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str
"""
try:
await self.nb_client.async_execute_cell(cell, cell_index)
return True, ""
return self.parse_outputs(self.nb.cells[-1].outputs)
except CellTimeoutError:
assert self.nb_client.km is not None
await self.nb_client.km.interrupt_kernel()
Expand All @@ -157,7 +170,7 @@ async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str
await self.reset()
return False, "DeadKernelError"
except Exception:
return False, f"{traceback.format_exc()}"
return self.parse_outputs(self.nb.cells[-1].outputs)

async def run(self, code: str, language: Literal["python", "markdown"] = "python") -> Tuple[str, bool]:
"""
Expand All @@ -174,14 +187,7 @@ async def run(self, code: str, language: Literal["python", "markdown"] = "python

# run code
cell_index = len(self.nb.cells) - 1
success, error_message = await self.run_cell(self.nb.cells[-1], cell_index)

if not success:
return truncate(remove_escape_and_color_codes(error_message), is_success=success)

# code success
outputs = self.parse_outputs(self.nb.cells[-1].outputs)
outputs, success = truncate(remove_escape_and_color_codes(outputs), is_success=success)
success, outputs = await self.run_cell(self.nb.cells[-1], cell_index)

if "!pip" in code:
success = False
Expand All @@ -197,24 +203,6 @@ async def run(self, code: str, language: Literal["python", "markdown"] = "python
raise ValueError(f"Only support for language: python, markdown, but got {language}, ")


def truncate(result: str, keep_len: int = 2000, is_success: bool = True):
"""对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。"""
if is_success:
desc = f"Executed code successfully. Truncated to show only first {keep_len} characters\n"
else:
desc = f"Executed code failed, please reflect the cause of bug and then debug. Truncated to show only last {keep_len} characters\n"

if result.strip().startswith("<coroutine object"):
result = "Executed code failed, you need use key word 'await' to run a async code."
return result, False

if len(result) > keep_len:
result = result[-keep_len:] if not is_success else result[:keep_len]
return desc + result, is_success

return result, is_success


def remove_escape_and_color_codes(input_str: str):
# 使用正则表达式去除jupyter notebook输出结果中的转义字符和颜色代码
# Use regular expressions to get rid of escape characters and color codes in jupyter notebook output.
Expand Down
37 changes: 20 additions & 17 deletions tests/metagpt/actions/mi/test_execute_nb_code.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest

from metagpt.actions.mi.execute_nb_code import ExecuteNbCode, truncate
from metagpt.actions.mi.execute_nb_code import ExecuteNbCode


@pytest.mark.asyncio
Expand Down Expand Up @@ -54,21 +54,6 @@ async def test_plotting_code():
assert is_success


def test_truncate():
# 代码执行成功
output, is_success = truncate("hello world", 5, True)
assert "Truncated to show only first 5 characters\nhello" in output
assert is_success
# 代码执行失败
output, is_success = truncate("hello world", 5, False)
assert "Truncated to show only last 5 characters\nworld" in output
assert not is_success
# 异步
output, is_success = truncate("<coroutine object", 5, True)
assert not is_success
assert "await" in output


@pytest.mark.asyncio
async def test_run_with_timeout():
executor = ExecuteNbCode(timeout=1)
Expand All @@ -83,7 +68,7 @@ async def test_run_code_text():
executor = ExecuteNbCode()
message, success = await executor.run(code='print("This is a code!")', language="python")
assert success
assert message == "This is a code!\n"
assert "This is a code!" in message
message, success = await executor.run(code="# This is a code!", language="markdown")
assert success
assert message == "# This is a code!"
Expand All @@ -100,6 +85,7 @@ async def test_terminate():
is_kernel_alive = await executor.nb_client.km.is_alive()
assert is_kernel_alive
await executor.terminate()

import time

time.sleep(2)
Expand All @@ -123,3 +109,20 @@ async def test_reset():
assert is_kernel_alive
await executor.reset()
assert executor.nb_client.km is None


@pytest.mark.asyncio
async def test_parse_outputs():
executor = ExecuteNbCode()
code = """
import pandas as pd
df = pd.DataFrame({'ID': [1,2,3], 'NAME': ['a', 'b', 'c']})
print(df.columns)
print(f"columns num:{len(df.columns)}")
print(df['DUMMPY_ID'])
"""
output, is_success = await executor.run(code)
assert not is_success
assert "Index(['ID', 'NAME'], dtype='object')" in output
assert "KeyError: 'DUMMPY_ID'" in output
assert "columns num:2" in output

0 comments on commit 882e941

Please sign in to comment.