From f33ecd1e009c9ec30fd21ea26512d98b280c8dcc Mon Sep 17 00:00:00 2001 From: Rishabh Srivastava Date: Thu, 23 Jan 2025 17:04:24 +0800 Subject: [PATCH] add demo ipython file --- demo.ipynb | 557 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 557 insertions(+) create mode 100644 demo.ipynb diff --git a/demo.ipynb b/demo.ipynb new file mode 100644 index 0000000..f5adbdc --- /dev/null +++ b/demo.ipynb @@ -0,0 +1,557 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Demoing `defog_utils`" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Purpose**\n", + "- Be able to switch between different LLM providers without having to rewrite the code\n", + "\n", + "**Assumptions**\n", + "- Every LLM provider will have common features, like structure outputs, and streaming. There are features that are specific to individual LLM providers (like predicted outputs for OpenAI models). But for the vast majority of use-cases, we can use chat_async" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ[\"DEEPSEEK_API_KEY\"] = \"sk-6a95286f3a4243d4ba209d76db36595e\"\n", + "from defog_utils.utils_multi_llm import chat_async\n", + "\n", + "\n", + "# remember to have the following in your environment:\n", + "# - OPENAI_API_KEY\n", + "# - GEMINI_API_KEY\n", + "# - ANTHROPIC_API_KEY\n", + "# - DEEPSEEK_API_KEY\n", + "# (optionally) - TOGETHER_API_KEY" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "from pprint import pprint\n", + "\n", + "def pretty_print_llm_response(resp):\n", + " if resp.output_tokens_details:\n", + " resp.output_tokens_details = resp.output_tokens_details.__dict__\n", + " \n", + " if type(resp.content) != str:\n", + " resp.content = resp.content.__dict__\n", + "\n", + " # format cost_in_cents to 3 significant figures\n", + " # it's okay to use scientific notation for this\n", + " sig_figs = 3\n", + " format_string = f\"{{:.{sig_figs - 1}e}}\"\n", + " resp.cost_in_cents = format_string.format(resp.cost_in_cents)\n", + " pprint(resp.__dict__, width=100, indent=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### First, let's do normal chat messages" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"Your task is to generate SQL given a natural language question and schema of the user's database. Do not use aliases. Return only the SQL without ```.\"\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"\"\"Question: What is the total number of orders?\n", + "Schema:\n", + "```sql\n", + "CREATE TABLE orders (\n", + " order_id int,\n", + " customer_id int,\n", + " employee_id int,\n", + " order_date date\n", + ");\n", + "```\"\"\"\n", + "}]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Let's try gpt-4o-mini first" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders;',\n", + " 'cost_in_cents': '3.72e-04',\n", + " 'input_tokens': 82,\n", + " 'model': 'gpt-4o-mini',\n", + " 'output_tokens': 6,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 0,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 0.639}\n" + ] + } + ], + "source": [ + "resp = await chat_async(\n", + " model=\"gpt-4o-mini\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Now let's try gpt-4o" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders;',\n", + " 'cost_in_cents': '6.21e-03',\n", + " 'input_tokens': 82,\n", + " 'model': 'gpt-4o',\n", + " 'output_tokens': 6,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 0,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 0.652}\n" + ] + } + ], + "source": [ + "resp = await chat_async(\n", + " model=\"gpt-4o\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Now let's try o1-mini\n", + "\n", + "Note that o1 models do not support temperature OR system prompts. defog_utils will automatically take care of this" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders;',\n", + " 'cost_in_cents': '1.79e-01',\n", + " 'input_tokens': 84,\n", + " 'model': 'o1-mini',\n", + " 'output_tokens': 149,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 128,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 2.041}\n" + ] + } + ], + "source": [ + "resp = await chat_async(\n", + " model=\"o1-mini\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Finally, try o1" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders;',\n", + " 'cost_in_cents': '9.02e-01',\n", + " 'input_tokens': 107,\n", + " 'model': 'o1',\n", + " 'output_tokens': 150,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 128,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 2.582}\n" + ] + } + ], + "source": [ + "resp = await chat_async(\n", + " model=\"o1\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### We can use exactly the same thing with non-openai models as well" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CLAUDE 3.5 SONNET\n", + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders',\n", + " 'cost_in_cents': '1.42e-02',\n", + " 'input_tokens': 245,\n", + " 'model': 'claude-3-5-sonnet-latest',\n", + " 'output_tokens': 9,\n", + " 'output_tokens_details': None,\n", + " 'time': 0.644}\n", + "GEMINI 2.0 FLASH EXP\n", + "{ 'content': '```sql\\n'\n", + " 'SELECT count(*) FROM '\n", + " 'orders\\n'\n", + " '```',\n", + " 'cost_in_cents': '3.18e-04',\n", + " 'input_tokens': 246,\n", + " 'model': 'gemini-2.0-flash-exp',\n", + " 'output_tokens': 10,\n", + " 'output_tokens_details': None,\n", + " 'time': 0.614}\n", + "DEEPSEEK REASONER\n", + "{ 'content': 'SELECT COUNT(*) FROM '\n", + " 'orders;',\n", + " 'cost_in_cents': '5.55e-02',\n", + " 'input_tokens': 238,\n", + " 'model': 'deepseek-reasoner',\n", + " 'output_tokens': 253,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': None,\n", + " 'audio_tokens': None,\n", + " 'reasoning_tokens': 245,\n", + " 'rejected_prediction_tokens': None},\n", + " 'time': 6.347}\n" + ] + } + ], + "source": [ + "# claude-3.5-sonnet\n", + "resp = await chat_async(\n", + " model=\"claude-3-5-sonnet-latest\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "print(\"CLAUDE 3.5 SONNET\")\n", + "pretty_print_llm_response(resp)\n", + "\n", + "# gemini-2.0-flash-exp\n", + "resp = await chat_async(\n", + " model=\"gemini-2.0-flash-exp\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "print(\"GEMINI 2.0 FLASH EXP\")\n", + "pretty_print_llm_response(resp)\n", + "\n", + "# deepseek-reasoner\n", + "resp = await chat_async(\n", + " model=\"deepseek-reasoner\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + ")\n", + "print(\"DEEPSEEK REASONER\")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now, let's try the same thing but with structured outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "GPT4o\n", + "{ 'content': { 'reasoning': 'To find the total number of orders, we need to count the number of '\n", + " \"rows in the 'orders' table. The SQL COUNT function is used to count \"\n", + " 'the number of rows in a table. Since we want the total number of '\n", + " \"orders, we can use COUNT(*) to count all rows in the 'orders' table.\",\n", + " 'sql': 'SELECT COUNT(*) FROM orders;'},\n", + " 'cost_in_cents': '7.88e-02',\n", + " 'input_tokens': 310,\n", + " 'model': 'gpt-4o',\n", + " 'output_tokens': 78,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 0,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 2.486}\n" + ] + } + ], + "source": [ + "from pydantic import BaseModel\n", + "\n", + "class ResponseFormat(BaseModel):\n", + " reasoning: str\n", + " sql: str\n", + "\n", + "resp = await chat_async(\n", + " model=\"gpt-4o\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + " response_format=ResponseFormat,\n", + ")\n", + "print(\"GPT4o\")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gemini 2.0 Flash\n", + "{ 'content': { 'reasoning': 'To find the total number of orders, I need to count the number of '\n", + " \"rows in the 'orders' table.\",\n", + " 'sql': 'SELECT count(*) FROM orders'},\n", + " 'cost_in_cents': '1.31e-03',\n", + " 'input_tokens': 278,\n", + " 'model': 'gemini-2.0-flash-exp',\n", + " 'output_tokens': 43,\n", + " 'output_tokens_details': None,\n", + " 'time': 1.918}\n" + ] + } + ], + "source": [ + "from pydantic import BaseModel\n", + "\n", + "class ResponseFormat(BaseModel):\n", + " reasoning: str\n", + " sql: str\n", + "\n", + "resp = await chat_async(\n", + " model=\"gemini-2.0-flash-exp\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + " response_format=ResponseFormat,\n", + ")\n", + "print(\"Gemini 2.0 Flash\")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Gemini 2.0 Flash\n", + "{ 'content': { 'reasoning': 'We need to find the total count of all orders in the orders table, so '\n", + " 'we use COUNT(*) on that table.',\n", + " 'sql': 'SELECT COUNT(*) FROM orders'},\n", + " 'cost_in_cents': '2.58e+00',\n", + " 'input_tokens': 305,\n", + " 'model': 'o1',\n", + " 'output_tokens': 430,\n", + " 'output_tokens_details': { 'accepted_prediction_tokens': 0,\n", + " 'audio_tokens': 0,\n", + " 'reasoning_tokens': 384,\n", + " 'rejected_prediction_tokens': 0},\n", + " 'time': 12.532}\n" + ] + } + ], + "source": [ + "from pydantic import BaseModel\n", + "\n", + "class ResponseFormat(BaseModel):\n", + " reasoning: str\n", + " sql: str\n", + "\n", + "resp = await chat_async(\n", + " model=\"o1\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + " response_format=ResponseFormat,\n", + ")\n", + "print(\"Gemini 2.0 Flash\")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "o1-mini\n", + "{ 'content': 'SELECT COUNT(*) FROM orders;',\n", + " 'cost_in_cents': '2.14e-04',\n", + " 'input_tokens': 332,\n", + " 'model': 'deepseek-chat',\n", + " 'output_tokens': 6,\n", + " 'output_tokens_details': None,\n", + " 'time': 1.644}\n" + ] + } + ], + "source": [ + "# if you try a model that does not support response_format\n", + "# it will just return the content as a string\n", + "\n", + "from pydantic import BaseModel\n", + "\n", + "class ResponseFormat(BaseModel):\n", + " reasoning: str\n", + " sql: str\n", + "\n", + "resp = await chat_async(\n", + " model=\"o1-mini\",\n", + " messages=messages,\n", + " max_completion_tokens=4000,\n", + " temperature=0.0,\n", + " seed=0,\n", + " response_format=ResponseFormat,\n", + ")\n", + "print(\"o1-mini\")\n", + "pretty_print_llm_response(resp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}