diff --git a/bin/find-notebooks-to-test.sh b/bin/find-notebooks-to-test.sh index ebaefdc8..cbeca3b2 100755 --- a/bin/find-notebooks-to-test.sh +++ b/bin/find-notebooks-to-test.sh @@ -21,6 +21,7 @@ EXEMPT_NOTEBOOKS=( "notebooks/integrations/gemini/vector-search-gemini-elastic.ipynb" "notebooks/integrations/gemini/qa-langchain-gemini-elasticsearch.ipynb" "notebooks/integrations/openai/openai-KNN-RAG.ipynb" + "notebooks/integrations/openai/function-calling.ipynb" "notebooks/integrations/gemma/rag-gemma-huggingface-elastic.ipynb" "notebooks/integrations/azure-openai/vector-search-azure-openai-elastic.ipynb" "notebooks/enterprise-search/app-search-engine-exporter.ipynb" diff --git a/notebooks/integrations/openai/function-calling.ipynb b/notebooks/integrations/openai/function-calling.ipynb new file mode 100644 index 00000000..0b8b6cd6 --- /dev/null +++ b/notebooks/integrations/openai/function-calling.ipynb @@ -0,0 +1,502 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "eac77545-c31c-4daa-bf5d-1a06f88ba8e8", + "metadata": {}, + "source": [ + "# OpenAI function calling with Elasticsearch" + ] + }, + { + "cell_type": "markdown", + "id": "75f21771-cafe-496f-8aaf-7cc0f14f89ec", + "metadata": {}, + "source": [ + "[Function calling](https://platform.openai.com/docs/guides/function-calling) in OpenAI refers to the capability of AI models to interact with external functions or APIs, allowing them to perform tasks beyond text generation. This feature enables the model to execute code, retrieve information from databases, interact with external services, and more, by calling predefined functions.\n", + "\n", + "In this notebook we’re going to create two function: \n", + "`fetch_from_elasticsearch()` - Fetch data from Elasticsearch using natural language query. \n", + "`weather_report()` - Fetch a weather report for a particular location.\n", + "\n", + "We'll integrate function calling to dynamically determine which function to call based on the user's query and generate the necessary arguments accordingly." + ] + }, + { + "cell_type": "markdown", + "id": "755e60d5-a42a-4001-a3a4-c5fe2da7d51a", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "markdown", + "id": "12d5e15c-5f1d-462a-9b04-83a9d58ad28c", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "### Elastic\n", + "\n", + "Create an [Elastic Cloud deployment](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud) to get all Elastic credentials. \n", + "`ES_API_KEY`: [Create](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key) an API key. \n", + "`ES_ENDPOINT`: [Copy](https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#finding-your-cloud-id) endpoint of Elasticsearch.\n", + "\n", + "### Open AI\n", + "\n", + "`OPENAI_API_KEY`: Setup an [Open AI account and create a secret key](https://platform.openai.com/docs/quickstart). \n", + "`GPT_MODEL`: We’re going to use the `gpt-4o` model but you can check [here](https://platform.openai.com/docs/guides/function-calling) which model is being supported for function calling.\n", + "\n", + "### Weather Union\n", + "\n", + "We will use the [Weather Union API](https://weatherunion.com). Which gives the weather reports of different cities in India. \n", + "`WU_API_KEY` & `WU_ENDPOINT` = [Create an account](https://www.weatherunion.com/login) and generate API Key.\n", + "\n", + "### Sample Data\n", + "After creating Elastic cloud deployment, Let’s [add sample flight data](https://www.elastic.co/guide/en/kibana/8.13/get-started.html#gs-get-data-into-kibana) on kibana. Sample data will be stored into the `kibana_sample_data_flights` index.\n", + "\n", + "### Install depndencies\n", + "\n", + "```sh\n", + "pip install openai\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "7295009d-4eb5-4b9f-8ce1-0e8aa7def226", + "metadata": {}, + "source": [ + "## Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7477e2b3-8af1-42dd-bba3-53827941714a", + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "from getpass import getpass\n", + "import json\n", + "import requests" + ] + }, + { + "cell_type": "markdown", + "id": "578a47e0-f9d7-49e5-a2b6-c417378392c3", + "metadata": {}, + "source": [ + "## Add Credentials" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d052237f-5874-4f25-b290-5a5f87006536", + "metadata": {}, + "outputs": [], + "source": [ + "OPENAI_API_KEY = getpass(\"OpenAI API Key:\")\n", + "client = OpenAI(\n", + " api_key=OPENAI_API_KEY,\n", + ")\n", + "GPT_MODEL = input(\"GPT Model:\")\n", + "\n", + "ES_API_KEY = getpass(\"Elastic API Key:\")\n", + "ES_ENDPOINT = input(\"Elasticsearch Endpoint:\")\n", + "ES_INDEX = \"kibana_sample_data_flights\"\n", + "\n", + "WU_API_KEY = getpass(\"Weather Union API Key:\")\n", + "WU_API_ENDPOINT = input(\"Weather Union API Endpoint:\")" + ] + }, + { + "cell_type": "markdown", + "id": "a255228d-7994-4bbe-ae53-bb574ea320f8", + "metadata": {}, + "source": [ + "## Functions to get data from Elasticsearch" + ] + }, + { + "cell_type": "markdown", + "id": "06d09f3e-38b3-43c5-a156-5b3096f45d99", + "metadata": {}, + "source": [ + "### Get Index mapping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c67afad-c0a0-4bab-98c8-5f8da582336a", + "metadata": {}, + "outputs": [], + "source": [ + "def get_index_mapping():\n", + "\n", + " url = f\"\"\"{ES_ENDPOINT}/{ES_INDEX}/_mappings\"\"\"\n", + "\n", + " headers = {\n", + " \"Content-type\": \"application/json\",\n", + " \"Authorization\": f\"\"\"ApiKey {ES_API_KEY}\"\"\",\n", + " }\n", + "\n", + " resp = requests.request(\"GET\", url, headers=headers)\n", + " resp = json.loads(resp.text)\n", + " mapping = json.dumps(resp, indent=4)\n", + "\n", + " return mapping" + ] + }, + { + "cell_type": "markdown", + "id": "c37bf4b0-fb96-4cff-9ddc-9aef781e30bb", + "metadata": {}, + "source": [ + "### Get reference document" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0db00590-e417-4751-a0c6-8e5a2f924fa1", + "metadata": {}, + "outputs": [], + "source": [ + "def get_ref_document():\n", + "\n", + " url = f\"\"\"{ES_ENDPOINT}/{ES_INDEX}/_search?size=1\"\"\"\n", + "\n", + " headers = {\n", + " \"Content-type\": \"application/json\",\n", + " \"Authorization\": f\"\"\"ApiKey {ES_API_KEY}\"\"\",\n", + " }\n", + "\n", + " resp = requests.request(\"GET\", url, headers=headers)\n", + " resp = json.loads(resp.text)\n", + " json_resp = json.dumps(resp[\"hits\"][\"hits\"][0], indent=4)\n", + "\n", + " return json_resp" + ] + }, + { + "cell_type": "markdown", + "id": "6dc83930-89c0-465b-9002-e6aa0d3f5e6f", + "metadata": {}, + "source": [ + "### Generate Elasticsearch Query DSL based on user query" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dcff892-6ce7-4c4e-9d4e-78777ef5c7fa", + "metadata": {}, + "outputs": [], + "source": [ + "def build_query(nl_query):\n", + "\n", + " index_mapping = get_index_mapping()\n", + " ref_document = get_ref_document()\n", + "\n", + " prompt = f\"\"\"\n", + " Use below index mapping and reference document to build Elasticsearch query:\n", + "\n", + " Index mapping:\n", + " {index_mapping}\n", + "\n", + " Reference elasticsearch document:\n", + " {ref_document}\n", + "\n", + " Return single line Elasticsearch Query DSL according to index mapping for the below search query. Just return Query DSL without REST specification (e.g. GET, POST etc.) and json markdown format (e.g. ```json):\n", + "\n", + " {nl_query}\n", + "\n", + " If any field has a `keyword` type, Just use field name instead of field.keyword.\n", + " \"\"\"\n", + "\n", + " resp = client.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": prompt,\n", + " }\n", + " ],\n", + " temperature=0,\n", + " )\n", + "\n", + " return resp.choices[0].message.content" + ] + }, + { + "cell_type": "markdown", + "id": "9678a30d-f2d9-457b-8a49-b41b409c8a73", + "metadata": {}, + "source": [ + "### Execute Query on Elasticsearch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "107cb8d5-8899-46ad-b352-7d7cb29035e5", + "metadata": {}, + "outputs": [], + "source": [ + "def fetch_from_elasticsearch(nl_query):\n", + "\n", + " query_dsl = build_query(nl_query)\n", + " # print(f\"\"\"Query DSL: ==== \\n\\n {query_dsl}\"\"\")\n", + "\n", + " url = f\"\"\"{ES_ENDPOINT}/{ES_INDEX}/_search\"\"\"\n", + "\n", + " payload = query_dsl\n", + "\n", + " headers = {\n", + " \"Content-type\": \"application/json\",\n", + " \"Authorization\": f\"\"\"ApiKey {ES_API_KEY}\"\"\",\n", + " }\n", + "\n", + " resp = requests.request(\"GET\", url, headers=headers, data=payload)\n", + " resp = json.loads(resp.text)\n", + " json_resp = json.dumps(resp, indent=4)\n", + "\n", + " # print(f\"\"\"\\n\\nElasticsearch response: ==== \\n\\n {json_resp}\"\"\")\n", + " return json_resp" + ] + }, + { + "cell_type": "markdown", + "id": "a7a792b6-f232-43a3-9186-697dfd06354f", + "metadata": {}, + "source": [ + "## Function to get weather report" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1df33f25-a77b-4402-be51-6f94962c3e54", + "metadata": {}, + "outputs": [], + "source": [ + "def weather_report(latitude, longitude):\n", + "\n", + " url = f\"\"\"{WU_API_ENDPOINT}?latitude={latitude}&longitude={longitude}\"\"\"\n", + "\n", + " headers = {\"x-zomato-api-key\": f\"\"\"{WU_API_KEY}\"\"\"}\n", + "\n", + " resp = requests.request(\"GET\", url, headers=headers)\n", + " resp = json.loads(resp.text)\n", + " json_resp = json.dumps(resp, indent=4)\n", + "\n", + " # print(f\"\"\"\\n\\nWeatherUnion response: ==== \\n\\n {json_resp}\"\"\")\n", + " return json_resp" + ] + }, + { + "cell_type": "markdown", + "id": "cfb9b64c-95a1-431e-8bf9-20ca5470c4ee", + "metadata": {}, + "source": [ + "## Function calling" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2816a1b6-b7c7-45ed-8537-efdee7292a8d", + "metadata": {}, + "outputs": [], + "source": [ + "def run_conversation(query):\n", + "\n", + " all_functions = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"fetch_from_elasticsearch\",\n", + " \"description\": \"All flights/airline related data is stored into Elasticsearch. Call this function if receiving any query around airlines/flights.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"query\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"Exact query string which is asked by user.\",\n", + " }\n", + " },\n", + " \"required\": [\"query\"],\n", + " },\n", + " },\n", + " },\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"weather_report\",\n", + " \"description\": \"It will return weather report in json format for given location co-ordinates.\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"latitude\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The latitude of a location with 0.01 degree\",\n", + " },\n", + " \"longitude\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The longitude of a location with 0.01 degree\",\n", + " },\n", + " },\n", + " \"required\": [\"latitude\", \"longitude\"],\n", + " },\n", + " },\n", + " },\n", + " ]\n", + "\n", + " messages = []\n", + " messages.append(\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"If no data received from any function. Just say there is issue fetching details from function(function_name)\",\n", + " }\n", + " )\n", + "\n", + " messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": query,\n", + " }\n", + " )\n", + "\n", + " response = client.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=messages,\n", + " tools=all_functions,\n", + " tool_choice=\"auto\",\n", + " )\n", + "\n", + " response_message = response.choices[0].message\n", + " tool_calls = response_message.tool_calls\n", + "\n", + " # print(tool_calls)\n", + "\n", + " if tool_calls:\n", + "\n", + " available_functions = {\n", + " \"fetch_from_elasticsearch\": fetch_from_elasticsearch,\n", + " \"weather_report\": weather_report,\n", + " }\n", + " messages.append(response_message)\n", + "\n", + " for tool_call in tool_calls:\n", + "\n", + " function_name = tool_call.function.name\n", + " function_to_call = available_functions[function_name]\n", + " function_args = json.loads(tool_call.function.arguments)\n", + "\n", + " if function_name == \"fetch_from_elasticsearch\":\n", + " function_response = function_to_call(\n", + " nl_query=function_args.get(\"query\"),\n", + " )\n", + "\n", + " if function_name == \"weather_report\":\n", + " function_response = function_to_call(\n", + " latitude=function_args.get(\"latitude\"),\n", + " longitude=function_args.get(\"longitude\"),\n", + " )\n", + "\n", + " # print(function_response)\n", + " messages.append(\n", + " {\n", + " \"tool_call_id\": tool_call.id,\n", + " \"role\": \"tool\",\n", + " \"name\": function_name,\n", + " \"content\": function_response,\n", + " }\n", + " )\n", + "\n", + " second_response = client.chat.completions.create(\n", + " model=GPT_MODEL,\n", + " messages=messages,\n", + " )\n", + "\n", + " return second_response.choices[0].message.content" + ] + }, + { + "cell_type": "markdown", + "id": "c5bd9009-b89c-4675-acae-7e5568213863", + "metadata": {}, + "source": [ + "## Ask Query" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "bbf1a2dc-49f8-4bfd-8317-7f28c5ef4715", + "metadata": {}, + "outputs": [ + { + "name": "stdin", + "output_type": "stream", + "text": [ + "Ask: last 10 flight delay to bangalore, show in table\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Here are the last 10 flight delays to Bangalore:\n", + "\n", + "| Flight Number | Origin City | Carrier | Delay (Minutes) | Delay Type | Date & Time |\n", + "|---------------|-------------|-------------------|------------------|------------------------|----------------------|\n", + "| B2JWDRX | Catania | Kibana Airlines | 60 | Security Delay | 2024-03-03 13:32:15 |\n", + "| C9C7VBY | Frankfurt | Logstash Airways | 285 | Security Delay | 2024-03-01 05:20:01 |\n", + "| 09P9K2Z | Paris | Kibana Airlines | 195 | Late Aircraft Delay | 2024-02-29 06:02:38 |\n", + "| 0FXK4HG | Osaka | Logstash Airways | 195 | Carrier Delay | 2024-02-23 03:34:21 |\n", + "| 5EYOHJR | Genova | ES-Air | 360 | NAS Delay | 2024-02-21 15:51:26 |\n", + "| X5HA5YJ | Bangor | Kibana Airlines | 330 | Weather Delay | 2024-02-19 13:50:58 |\n", + "| 4BZUCXP | Bogota | ES-Air | 30 | Late Aircraft Delay | 2024-02-16 05:59:37 |\n", + "| O8I6UU8 | Catania | ES-Air | 135 | Late Aircraft Delay | 2024-02-09 03:12:49 |\n", + "| 56HYVZQ | Denver | Logstash Airways | 60 | NAS Delay | 2024-02-08 15:52:44 |\n", + "| X4025SP | Paris | Kibana Airlines | 30 | Late Aircraft Delay | 2024-02-08 10:57:29 |\n", + "\n", + "If you need more details or further assistance, please let me know.\n" + ] + } + ], + "source": [ + "i = input(\"Ask:\")\n", + "answer = run_conversation(i)\n", + "print(answer)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}