Skip to content

Commit

Permalink
refactor(simplebot)🔧: Refactor JSON response handling to use Pydantic…
Browse files Browse the repository at this point in the history
… models

- Added import for Pydantic BaseModel.
- Implemented checks for pydantic_model attribute and its inheritance from BaseModel.
- Updated the JSON response format to use json_schema from the Pydantic model.
  • Loading branch information
ericmjl committed Dec 20, 2024
1 parent 7e29952 commit a6b2702
Showing 1 changed file with 16 additions and 63 deletions.
79 changes: 16 additions & 63 deletions llamabot/bot/simplebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from llamabot.recorder import autorecord, sqlite_log
from llamabot.config import default_language_model
from pydantic import BaseModel

prompt_recorder_var = contextvars.ContextVar("prompt_recorder")

Expand Down Expand Up @@ -211,69 +212,21 @@ def _make_response(bot: SimpleBot, messages: list[BaseMessage], stream: bool = T
if bot.mock_response:
completion_kwargs["mock_response"] = bot.mock_response
if bot.json_mode:
completion_kwargs["response_format"] = {"type": "json_object"}
# Check if bot has pydantic_model attribute and it's a BaseModel
if not hasattr(bot, "pydantic_model"):
raise ValueError(
"Please set a pydantic_model for this bot to use JSON mode!"
)
if not issubclass(getattr(bot, "pydantic_model"), BaseModel):
raise ValueError("pydantic_model must be a Pydantic BaseModel class")

model = getattr(bot, "pydantic_model")

completion_kwargs["response_format"] = {
"type": "json_schema",
"json_schema": model.model_json_schema(),
"strict": True,
}
if bot.api_key:
completion_kwargs["api_key"] = bot.api_key
return completion(**completion_kwargs)

# Commented out until later.
# def panel(
# self,
# input_text_label="Input",
# output_text_label="Output",
# submit_button_label="Submit",
# site_name="SimpleBot",
# title="SimpleBot",
# show=False,
# ):
# """Create a Panel app that wraps a LlamaBot.

# :param input_text_label: The label for the input text.
# :param output_text_label: The label for the output text.
# :param submit_button_label: The label for the submit button.
# :param site_name: The name of the site.
# :param title: The title of the site.
# :param show: Whether to show the app.
# If False, we return the Panel app directly.
# If True, we call `.show()` on the app.
# :return: The Panel app, either showed or directly.
# """
# input_text = pn.widgets.TextAreaInput(
# name=input_text_label, value="", height=200, width=500
# )
# output_text = pn.pane.Markdown("")
# submit = pn.widgets.Button(name=submit_button_label, button_type="success")

# def b(event):
# """Button click handler.

# :param event: The button click event.
# """
# logger.info(input_text.value)
# output_text.object = ""
# markdown_handler = PanelMarkdownCallbackHandler(output_text)
# self.model.callback_manager.set_handler(markdown_handler)
# response = self(input_text.value)
# logger.info(response)

# submit.on_click(b)

# app = pn.template.FastListTemplate(
# site=site_name,
# title=title,
# main=[
# pn.Column(
# *[
# input_text,
# submit,
# pn.pane.Markdown(output_text_label),
# output_text,
# ]
# )
# ],
# main_max_width="768px",
# )
# app = pn.panel(app)
# if show:
# return app.show()
# return app

0 comments on commit a6b2702

Please sign in to comment.