Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PythonCodeAgent #896

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo

from lumen.views.base import ExecPythonView

from ..base import Component
from ..dashboard import Config
from ..pipeline import Pipeline
Expand All @@ -35,7 +37,8 @@
from .llm import Llm, Message
from .memory import _Memory
from .models import (
JoinRequired, RetrySpec, Sql, TableJoins, VegaLiteSpec, make_table_model,
JoinRequired, PythonCodeSpec, RetrySpec, Sql, TableJoins, VegaLiteSpec,
make_table_model,
)
from .tools import DocumentLookup, TableLookup
from .translate import param_to_pydantic
Expand Down Expand Up @@ -900,6 +903,34 @@ async def _extract_spec(self, spec: dict[str, Any]):
return {'spec': vega_spec, "sizing_mode": "stretch_both", "min_height": 500}


class PythonCodeAgent(BaseViewAgent):
"""
Agent that generates Python code, commonly for visualizations.
"""

purpose = param.String(default="""
Generates Python code based on the user's visualization request.""")

prompts = param.Dict(
default={
"main": {
"response_model": PythonCodeSpec,
"template": PROMPTS_DIR / "PythonCodeAgent" / "main.jinja2"
},
}
)

view_type = ExecPythonView

async def _update_spec(self, memory: _Memory, event: param.parameterized.Event):
spec = yaml.load(event.new, Loader=yaml.SafeLoader)
memory['view'] = dict(await self._extract_spec({"code": spec}), type=self.view_type)

async def _extract_spec(self, spec: dict[str, Any]):
python_spec = spec["code"]
return {'spec': python_spec}


class AnalysisAgent(LumenBaseAgent):

analyses = param.List([])
Expand Down
10 changes: 10 additions & 0 deletions lumen/ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ class VegaLiteSpec(BaseModel):
json_spec: str = Field(description="A vega-lite JSON specification. Do not under any circumstances generate the data field.")


class PythonCodeSpec(BaseModel):

chain_of_thought: str = Field(
description="Explain the thought process behind the code snippet."
)

code: str = Field(
description="The code snippet that answers the user query."
)


class RetrySpec(BaseModel):

Expand Down
25 changes: 25 additions & 0 deletions lumen/ai/prompts/PythonCodeAgent/main.jinja2
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{% extends 'BaseViewAgent/main.jinja2' %}

{% block instructions %}
Help write a Python program that will satisfy the user's request.

The following variables are already defined for you:
- `pn`: The HoloViz Panel package.
- `pipeline`: The current pipeline, if it's avaialble.
- `data`: The current data from the pipeline, if it's available.

Be sure to import any necessary libraries used in the code.
If no `return` statement is present, the last line of the code will be returned.
Prefer quotations `"` over apostrophes `'` in Python code.
{% endblock %}

{%- block examples %}
If the user asks to plot a line chart of the pipeline, do not call plt.show():

'''
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(data["x"], data["y"])
fig
'''
{% endblock %}
10 changes: 9 additions & 1 deletion lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,15 @@ async def get_schema(
elif limit and len(spec["enum"]) == 1 and spec["enum"][0] is None:
spec["enum"] = [f"(unknown; truncated to {get_kwargs['limit']} rows)"]
# truncate each enum to 100 characters
spec["enum"] = [enum if enum is None or len(enum) < 100 else f"{enum[:100]} ..." for enum in spec["enum"]]
spec["enum"] = [
enum
if (
enum is None
or not isinstance(enum, str)
or len(enum) < 100
)
else f"{enum[:100]} ..."
for enum in spec["enum"]]

if count and include_count:
schema["count"] = count
Expand Down
187 changes: 186 additions & 1 deletion lumen/tests/views/test_base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import pathlib

from pathlib import Path

import holoviews as hv # type: ignore
import numpy as np
import pandas as pd
import panel as pn
import pytest

from panel.layout import Column
from panel.pane import Markdown
from panel.pane.alert import Alert
from panel.widgets import Checkbox

from lumen.filters.base import ConstantFilter
from lumen.panel import DownloadButton
from lumen.pipeline import Pipeline
from lumen.sources.base import FileSource
from lumen.state import state
from lumen.variables.base import Variables
from lumen.views.base import (
Panel, View, hvOverlayView, hvPlotView,
ExecPythonView, Panel, View, hvOverlayView, hvPlotView,
)


Expand Down Expand Up @@ -442,3 +449,181 @@ def test_panel_cross_reference_rx():
'type': 'panel.layout.base.Column'
}
}


def test_exec_python_basic_execution():
view = ExecPythonView(
spec='''"Hello World"'''
)
result = view.execute_code()
assert result == "Hello World"


def test_exec_python_holoviews_execution():
"""Test execution with HoloViews output"""
view = ExecPythonView(
spec="""
import numpy as np
import holoviews as hv

data = np.array([[1, 2], [3, 4]])
hv.Scatter(data)
"""
)
result = view.execute_code()
assert isinstance(result, hv.Element)
assert np.array_equal(result.array(), np.array([[1, 2], [3, 4]]))


def test_exec_python_inside_function_execution():
"""Test execution with HoloViews output"""
view = ExecPythonView(
spec="""
def plot():
import numpy as np
import holoviews as hv
data = np.array([[1, 2], [3, 4]])
return hv.Scatter(data)

plot()
"""
)
result = view.execute_code()
assert isinstance(result, hv.Element)
assert np.array_equal(result.array(), np.array([[1, 2], [3, 4]]))


def test_exec_python_error_handling():
view = ExecPythonView(
spec="""
this is not valid python
"""
)
result = view.execute_code()
assert isinstance(result, Alert)
assert "Error executing code:" in result.object


def test_exec_python_panel_output():
"""Test that Panel objects are handled correctly"""
view = ExecPythonView(
spec="""
return pn.Column(
pn.pane.Markdown("# Title"),
pn.pane.Markdown("Content")
)
"""
)
result = view.execute_code()
assert isinstance(result, Column)
assert len(result) == 2
assert all(isinstance(obj, Markdown) for obj in result)


def test_exec_python_indentation_handling():
view = ExecPythonView(
spec="""
x = 1
y = 2
# Extra indented comment
return x + y
"""
)
result = view.execute_code()
assert result == 3


def test_exec_python_exec_code_view_roundtrip():
original_code = '''
import holoviews as hv
import numpy as np
data = np.array([[1, 2], [3, 4]])
hv.Scatter(data)
'''

view = ExecPythonView(spec=original_code)
spec = view.to_spec()

assert spec == {
'type': 'exec_code',
'spec': original_code
}

reconstructed = ExecPythonView.from_spec(spec)
assert reconstructed.spec == original_code

original_result = view.execute_code()
reconstructed_result = reconstructed.execute_code()
assert type(original_result) is type(reconstructed_result)
assert np.array_equal(original_result.array(), reconstructed_result.array())


def test_exec_python_return_statement():
view = ExecPythonView(
spec="""
x = 42
return x
"""
)
result = view.execute_code()
assert result == 42


def test_exec_python_within_python_code_fence():
view = ExecPythonView(
spec="""
```python
x = 42
return x
```
"""
)
result = view.execute_code()
assert result == 42


def test_exec_python_multiple_code_fence():
view = ExecPythonView(
spec="""
Here's x
```python
x = 42
```
Here's y
```
y = 42
return x + y
```
"""
)
result = view.execute_code()
assert result == 84


def test_exec_python_pipeline_data_access(make_filesource):
root = pathlib.Path(__file__).parent / ".." / 'sources'
source = make_filesource(str(root))
cfilter = ConstantFilter(field='A', value=(1, 2))
pipeline = Pipeline(source=source, filters=[cfilter], table='test')
view = ExecPythonView(
pipeline=pipeline,
spec="""
return data["A"].sum()
"""
)
result = view.execute_code()
assert result == 3.0


def test_exec_python_multiple_imports():
view = ExecPythonView(
spec="""
import numpy as np
import pandas as pd
arr = np.array([1, 2, 3])
df = pd.DataFrame({'data': arr})
return df['data'].sum()
"""
)
result = view.execute_code()
assert result == 6
Loading
Loading