Skip to content

Commit

Permalink
refactor: extract functions to get environment and the retry mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
gventuri committed Jun 17, 2023
1 parent b096f70 commit 644e279
Show file tree
Hide file tree
Showing 2 changed files with 260 additions and 36 deletions.
96 changes: 60 additions & 36 deletions pandasai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,58 @@ def _clean_code(self, code: str) -> str:
new_tree = ast.Module(body=new_body)
return astor.to_source(new_tree).strip()

def _get_environment(self) -> dict:
"""
Returns the environment for the code to be executed.
Returns (dict): A dictionary of environment variables
"""

return {
"pd": pd,
**{
lib["alias"]: getattr(import_dependency(lib["module"]), lib["name"])
if hasattr(import_dependency(lib["module"]), lib["name"])
else import_dependency(lib["module"])
for lib in self._additional_dependencies
},
"__builtins__": {
**{builtin: __builtins__[builtin] for builtin in WHITELISTED_BUILTINS},
},
}

def _retry_run_code(self, code: str, e: Exception, multiple: bool = False):
"""
A method to retry the code execution with error correction framework.
Args:
code (str): A python code
e (Exception): An exception
multiple (bool): A boolean to indicate if the code is for multiple
dataframes
Returns (str): A python code
"""

if multiple:
error_correcting_instruction = CorrectMultipleDataframesErrorPrompt(
code=code,
error_returned=e,
question=self._original_instructions["question"],
df_head=self._original_instructions["df_head"],
)
else:
error_correcting_instruction = CorrectErrorPrompt(
code=code,
error_returned=e,
question=self._original_instructions["question"],
df_head=self._original_instructions["df_head"],
num_rows=self._original_instructions["num_rows"],
num_columns=self._original_instructions["num_columns"],
)

return self._llm.generate_code(error_correcting_instruction, "")

def run_code(
self,
code: str,
Expand Down Expand Up @@ -529,24 +581,12 @@ def run_code(
```"""
)

environment: dict = {
"pd": pd,
**{
lib["alias"]: getattr(import_dependency(lib["module"]), lib["name"])
if hasattr(import_dependency(lib["module"]), lib["name"])
else import_dependency(lib["module"])
for lib in self._additional_dependencies
},
"__builtins__": {
**{builtin: __builtins__[builtin] for builtin in WHITELISTED_BUILTINS},
},
}
environment: dict = self._get_environment()

if multiple:
environment.update(
{f"df{i}": dataframe for i, dataframe in enumerate(data_frame, start=1)}
)

else:
environment["df"] = data_frame

Expand All @@ -565,29 +605,7 @@ def run_code(

count += 1

if multiple:
error_correcting_instruction = (
CorrectMultipleDataframesErrorPrompt(
code=code,
error_returned=e,
question=self._original_instructions["question"],
df_head=self._original_instructions["df_head"],
)
)

else:
error_correcting_instruction = CorrectErrorPrompt(
code=code,
error_returned=e,
question=self._original_instructions["question"],
df_head=self._original_instructions["df_head"],
num_rows=self._original_instructions["num_rows"],
num_columns=self._original_instructions["num_columns"],
)

code_to_run = self._llm.generate_code(
error_correcting_instruction, ""
)
code_to_run = self._retry_run_code(code, e, multiple)

captured_output = output.getvalue()

Expand Down Expand Up @@ -617,3 +635,9 @@ def last_prompt_id(self) -> str:
if self._prompt_id is None:
raise ValueError("Pandas AI has not been run yet.")
return self._prompt_id

@property
def last_prompt(self) -> str:
"""Return the last prompt that was executed."""
if self._llm:
return self._llm.last_prompt
200 changes: 200 additions & 0 deletions tests/test_pandasai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from uuid import UUID

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pytest

from pandasai import PandasAI
Expand All @@ -26,6 +28,49 @@ def llm(self, output: Optional[str] = None):
def pandasai(self, llm):
return PandasAI(llm, enable_cache=False)

@pytest.fixture
def sample_df(self, llm):
return pd.DataFrame(
{
"country": [
"United States",
"United Kingdom",
"France",
"Germany",
"Italy",
"Spain",
"Canada",
"Australia",
"Japan",
"China",
],
"gdp": [
19294482071552,
2891615567872,
2411255037952,
3435817336832,
1745433788416,
1181205135360,
1607402389504,
1490967855104,
4380756541440,
14631844184064,
],
"happiness_index": [
6.94,
7.16,
6.66,
7.07,
6.38,
6.4,
7.23,
7.22,
5.87,
5.12,
],
}
)

@pytest.fixture
def test_middleware(self):
class TestMiddleware(Middleware):
Expand Down Expand Up @@ -408,3 +453,158 @@ def test_load_llm_with_langchain_llm(self, pandasai):

pandasai._load_llm(langchain_llm)
assert pandasai._llm._langchain_llm == langchain_llm

def test_get_environment(self, pandasai):
pandasai._additional_dependencies = [
{"name": "pyplot", "alias": "plt", "module": "matplotlib"},
{"name": "numpy", "alias": "np", "module": "numpy"},
]
assert pandasai._get_environment() == {
"pd": pd,
"plt": plt,
"np": np,
"__builtins__": {
"abs": abs,
"all": all,
"any": any,
"ascii": ascii,
"bin": bin,
"bool": bool,
"bytearray": bytearray,
"bytes": bytes,
"callable": callable,
"chr": chr,
"classmethod": classmethod,
"complex": complex,
"delattr": delattr,
"dict": dict,
"dir": dir,
"divmod": divmod,
"enumerate": enumerate,
"filter": filter,
"float": float,
"format": format,
"frozenset": frozenset,
"getattr": getattr,
"hasattr": hasattr,
"hash": hash,
"help": help,
"hex": hex,
"id": id,
"input": input,
"int": int,
"isinstance": isinstance,
"issubclass": issubclass,
"iter": iter,
"len": len,
"list": list,
"locals": locals,
"map": map,
"max": max,
"memoryview": memoryview,
"min": min,
"next": next,
"object": object,
"oct": oct,
"open": open,
"ord": ord,
"pow": pow,
"print": print,
"property": property,
"range": range,
"repr": repr,
"reversed": reversed,
"round": round,
"set": set,
"setattr": setattr,
"slice": slice,
"sorted": sorted,
"staticmethod": staticmethod,
"str": str,
"sum": sum,
"super": super,
"tuple": tuple,
"type": type,
"vars": vars,
"zip": zip,
},
}

def test_retry_on_error_with_single_df(self, pandasai, sample_df):
code = 'print("Hello world")'

pandasai._original_instructions = {
"question": "Print hello world",
"df_head": sample_df.head(),
"num_rows": 10,
"num_columns": 3,
}
pandasai._retry_run_code(code, e=Exception("Test error"), multiple=False)
assert (
pandasai.last_prompt
== f"""
Today is {date.today()}.
You are provided with a pandas dataframe (df) with 10 rows and 3 columns.
This is the metadata of the dataframe:
country gdp happiness_index
0 United States 19294482071552 6.94
1 United Kingdom 2891615567872 7.16
2 France 2411255037952 6.66
3 Germany 3435817336832 7.07
4 Italy 1745433788416 6.38.
The user asked the following question:
Print hello world
You generated this python code:
print("Hello world")
It fails with the following error:
Test error
Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
Make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly.
Code:
""" # noqa: E501
)

def test_retry_on_error_with_multiple_df(self, pandasai, sample_df):
code = 'print("Hello world")'

pandasai._original_instructions = {
"question": "Print hello world",
"df_head": [sample_df.head()],
"num_rows": 10,
"num_columns": 3,
}
pandasai._retry_run_code(code, e=Exception("Test error"), multiple=True)
assert (
pandasai.last_prompt
== """
You are provided with the following pandas dataframes:
Dataframe df1, with 5 rows and 3 columns.
This is the metadata of the dataframe df1:
country gdp happiness_index
0 United States 19294482071552 6.94
1 United Kingdom 2891615567872 7.16
2 France 2411255037952 6.66
3 Germany 3435817336832 7.07
4 Italy 1745433788416 6.38
The user asked the following question:
Print hello world
You generated this python code:
print("Hello world")
It fails with the following error:
Test error
Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again.
Make sure to prefix the requested python code with <startCode> exactly and suffix the code with <endCode> exactly.
Code:
""" # noqa: E501
)

0 comments on commit 644e279

Please sign in to comment.