From 644e27945690b593ca2a924f41efc15d2da37bdc Mon Sep 17 00:00:00 2001 From: Gabriele Venturi Date: Sat, 17 Jun 2023 11:40:50 +0200 Subject: [PATCH] refactor: extract functions to get environment and the retry mechanism --- pandasai/__init__.py | 96 ++++++++++++-------- tests/test_pandasai.py | 200 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+), 36 deletions(-) diff --git a/pandasai/__init__.py b/pandasai/__init__.py index f1df1bf4a..57115d71a 100644 --- a/pandasai/__init__.py +++ b/pandasai/__init__.py @@ -491,6 +491,58 @@ def _clean_code(self, code: str) -> str: new_tree = ast.Module(body=new_body) return astor.to_source(new_tree).strip() + def _get_environment(self) -> dict: + """ + Returns the environment for the code to be executed. + + Returns (dict): A dictionary of environment variables + """ + + return { + "pd": pd, + **{ + lib["alias"]: getattr(import_dependency(lib["module"]), lib["name"]) + if hasattr(import_dependency(lib["module"]), lib["name"]) + else import_dependency(lib["module"]) + for lib in self._additional_dependencies + }, + "__builtins__": { + **{builtin: __builtins__[builtin] for builtin in WHITELISTED_BUILTINS}, + }, + } + + def _retry_run_code(self, code: str, e: Exception, multiple: bool = False): + """ + A method to retry the code execution with error correction framework. + + Args: + code (str): A python code + e (Exception): An exception + multiple (bool): A boolean to indicate if the code is for multiple + dataframes + + Returns (str): A python code + """ + + if multiple: + error_correcting_instruction = CorrectMultipleDataframesErrorPrompt( + code=code, + error_returned=e, + question=self._original_instructions["question"], + df_head=self._original_instructions["df_head"], + ) + else: + error_correcting_instruction = CorrectErrorPrompt( + code=code, + error_returned=e, + question=self._original_instructions["question"], + df_head=self._original_instructions["df_head"], + num_rows=self._original_instructions["num_rows"], + num_columns=self._original_instructions["num_columns"], + ) + + return self._llm.generate_code(error_correcting_instruction, "") + def run_code( self, code: str, @@ -529,24 +581,12 @@ def run_code( ```""" ) - environment: dict = { - "pd": pd, - **{ - lib["alias"]: getattr(import_dependency(lib["module"]), lib["name"]) - if hasattr(import_dependency(lib["module"]), lib["name"]) - else import_dependency(lib["module"]) - for lib in self._additional_dependencies - }, - "__builtins__": { - **{builtin: __builtins__[builtin] for builtin in WHITELISTED_BUILTINS}, - }, - } + environment: dict = self._get_environment() if multiple: environment.update( {f"df{i}": dataframe for i, dataframe in enumerate(data_frame, start=1)} ) - else: environment["df"] = data_frame @@ -565,29 +605,7 @@ def run_code( count += 1 - if multiple: - error_correcting_instruction = ( - CorrectMultipleDataframesErrorPrompt( - code=code, - error_returned=e, - question=self._original_instructions["question"], - df_head=self._original_instructions["df_head"], - ) - ) - - else: - error_correcting_instruction = CorrectErrorPrompt( - code=code, - error_returned=e, - question=self._original_instructions["question"], - df_head=self._original_instructions["df_head"], - num_rows=self._original_instructions["num_rows"], - num_columns=self._original_instructions["num_columns"], - ) - - code_to_run = self._llm.generate_code( - error_correcting_instruction, "" - ) + code_to_run = self._retry_run_code(code, e, multiple) captured_output = output.getvalue() @@ -617,3 +635,9 @@ def last_prompt_id(self) -> str: if self._prompt_id is None: raise ValueError("Pandas AI has not been run yet.") return self._prompt_id + + @property + def last_prompt(self) -> str: + """Return the last prompt that was executed.""" + if self._llm: + return self._llm.last_prompt diff --git a/tests/test_pandasai.py b/tests/test_pandasai.py index ba00fe353..2afeadd75 100644 --- a/tests/test_pandasai.py +++ b/tests/test_pandasai.py @@ -6,6 +6,8 @@ from uuid import UUID import pandas as pd +import matplotlib.pyplot as plt +import numpy as np import pytest from pandasai import PandasAI @@ -26,6 +28,49 @@ def llm(self, output: Optional[str] = None): def pandasai(self, llm): return PandasAI(llm, enable_cache=False) + @pytest.fixture + def sample_df(self, llm): + return pd.DataFrame( + { + "country": [ + "United States", + "United Kingdom", + "France", + "Germany", + "Italy", + "Spain", + "Canada", + "Australia", + "Japan", + "China", + ], + "gdp": [ + 19294482071552, + 2891615567872, + 2411255037952, + 3435817336832, + 1745433788416, + 1181205135360, + 1607402389504, + 1490967855104, + 4380756541440, + 14631844184064, + ], + "happiness_index": [ + 6.94, + 7.16, + 6.66, + 7.07, + 6.38, + 6.4, + 7.23, + 7.22, + 5.87, + 5.12, + ], + } + ) + @pytest.fixture def test_middleware(self): class TestMiddleware(Middleware): @@ -408,3 +453,158 @@ def test_load_llm_with_langchain_llm(self, pandasai): pandasai._load_llm(langchain_llm) assert pandasai._llm._langchain_llm == langchain_llm + + def test_get_environment(self, pandasai): + pandasai._additional_dependencies = [ + {"name": "pyplot", "alias": "plt", "module": "matplotlib"}, + {"name": "numpy", "alias": "np", "module": "numpy"}, + ] + assert pandasai._get_environment() == { + "pd": pd, + "plt": plt, + "np": np, + "__builtins__": { + "abs": abs, + "all": all, + "any": any, + "ascii": ascii, + "bin": bin, + "bool": bool, + "bytearray": bytearray, + "bytes": bytes, + "callable": callable, + "chr": chr, + "classmethod": classmethod, + "complex": complex, + "delattr": delattr, + "dict": dict, + "dir": dir, + "divmod": divmod, + "enumerate": enumerate, + "filter": filter, + "float": float, + "format": format, + "frozenset": frozenset, + "getattr": getattr, + "hasattr": hasattr, + "hash": hash, + "help": help, + "hex": hex, + "id": id, + "input": input, + "int": int, + "isinstance": isinstance, + "issubclass": issubclass, + "iter": iter, + "len": len, + "list": list, + "locals": locals, + "map": map, + "max": max, + "memoryview": memoryview, + "min": min, + "next": next, + "object": object, + "oct": oct, + "open": open, + "ord": ord, + "pow": pow, + "print": print, + "property": property, + "range": range, + "repr": repr, + "reversed": reversed, + "round": round, + "set": set, + "setattr": setattr, + "slice": slice, + "sorted": sorted, + "staticmethod": staticmethod, + "str": str, + "sum": sum, + "super": super, + "tuple": tuple, + "type": type, + "vars": vars, + "zip": zip, + }, + } + + def test_retry_on_error_with_single_df(self, pandasai, sample_df): + code = 'print("Hello world")' + + pandasai._original_instructions = { + "question": "Print hello world", + "df_head": sample_df.head(), + "num_rows": 10, + "num_columns": 3, + } + pandasai._retry_run_code(code, e=Exception("Test error"), multiple=False) + assert ( + pandasai.last_prompt + == f""" +Today is {date.today()}. +You are provided with a pandas dataframe (df) with 10 rows and 3 columns. +This is the metadata of the dataframe: + country gdp happiness_index +0 United States 19294482071552 6.94 +1 United Kingdom 2891615567872 7.16 +2 France 2411255037952 6.66 +3 Germany 3435817336832 7.07 +4 Italy 1745433788416 6.38. + +The user asked the following question: +Print hello world + +You generated this python code: +print("Hello world") + +It fails with the following error: +Test error + +Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again. +Make sure to prefix the requested python code with exactly and suffix the code with exactly. + + +Code: +""" # noqa: E501 + ) + + def test_retry_on_error_with_multiple_df(self, pandasai, sample_df): + code = 'print("Hello world")' + + pandasai._original_instructions = { + "question": "Print hello world", + "df_head": [sample_df.head()], + "num_rows": 10, + "num_columns": 3, + } + pandasai._retry_run_code(code, e=Exception("Test error"), multiple=True) + assert ( + pandasai.last_prompt + == """ +You are provided with the following pandas dataframes: +Dataframe df1, with 5 rows and 3 columns. +This is the metadata of the dataframe df1: + country gdp happiness_index +0 United States 19294482071552 6.94 +1 United Kingdom 2891615567872 7.16 +2 France 2411255037952 6.66 +3 Germany 3435817336832 7.07 +4 Italy 1745433788416 6.38 +The user asked the following question: +Print hello world + +You generated this python code: +print("Hello world") + +It fails with the following error: +Test error + +Correct the python code and return a new python code (do not import anything) that fixes the above mentioned error. Do not generate the same code again. +Make sure to prefix the requested python code with exactly and suffix the code with exactly. + + +Code: +""" # noqa: E501 + )