Skip to content

Commit

Permalink
update llm initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
allenanie committed Dec 10, 2024
1 parent 24099f8 commit cf39af6
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 1 deletion.
108 changes: 108 additions & 0 deletions docs/examples/basic/greeting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,114 @@
],
"id": "72b76d44a5423795"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Add API keys for LLM calls. Run the code below:",
"id": "88243c6b69d0c2ad"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-12-10T00:10:08.564966Z",
"start_time": "2024-12-10T00:10:08.520705Z"
}
},
"cell_type": "code",
"source": [
"import os\n",
"import ipywidgets as widgets\n",
"from IPython.display import display\n",
"\n",
"# Function to save the environment variable and API key\n",
"def save_env_variable(env_name, api_key):\n",
" # Validate inputs\n",
" if not env_name.strip():\n",
" print(\"⚠️ Environment variable name cannot be empty.\")\n",
" return\n",
" if not api_key.strip():\n",
" print(\"⚠️ API key cannot be empty.\")\n",
" return\n",
" \n",
" # Store the API key as an environment variable\n",
" os.environ[env_name] = api_key\n",
" globals()[env_name] = api_key # Set it as a global variable\n",
" print(f\"✅ API key has been set for environment variable: {env_name}\")\n",
"\n",
"# Create the input widgets\n",
"env_name_input = widgets.Text(\n",
" value=\"OPENAI_API_KEY\", # Default value\n",
" description=\"Env Name:\",\n",
" placeholder=\"Enter env variable name (e.g., MY_API_KEY)\",\n",
")\n",
"\n",
"api_key_input = widgets.Password(\n",
" description=\"API Key:\",\n",
" placeholder=\"Enter your API key\",\n",
")\n",
"\n",
"# Create the button to submit the inputs\n",
"submit_button = widgets.Button(description=\"Set API Key\")\n",
"\n",
"# Display the widgets\n",
"display(env_name_input, api_key_input, submit_button)\n",
"\n",
"# Callback function for the button click\n",
"def on_button_click(b):\n",
" env_name = env_name_input.value\n",
" api_key = api_key_input.value\n",
" save_env_variable(env_name, api_key)\n",
"\n",
"# Attach the callback to the button\n",
"submit_button.on_click(on_button_click)"
],
"id": "3242fb533b7cb3f4",
"outputs": [
{
"data": {
"text/plain": [
"Text(value='OPENAI_API_KEY', description='Env Name:', placeholder='Enter env variable name (e.g., MY_API_KEY)'…"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "1bd6aa77089941b6bf1387d59df773d2"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Password(description='API Key:', placeholder='Enter your API key')"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "2c985d3f3ddd439bb6366c58833af31c"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Button(description='Set API Key', style=ButtonStyle())"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "29026f7b286643a7bd31f4b2ac0533ff"
}
},
"metadata": {},
"output_type": "display_data"
}
],
"execution_count": 1
},
{
"metadata": {},
"cell_type": "markdown",
Expand Down
24 changes: 23 additions & 1 deletion opto/utils/llm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import List, Tuple, Dict, Any, Callable, Union
import os
import time
import json
import autogen # We import autogen here to avoid the need of installing autogen

class AbstractModel:
Expand Down Expand Up @@ -47,7 +49,12 @@ class AutoGenLLM(AbstractModel):

def __init__(self, config_list: List = None, filter_dict: Dict = None, reset_freq: Union[int, None] = None) -> None:
if config_list is None:
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
try:
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
except:
config_list = auto_construct_oai_config_list_from_env()
os.environ.update({"OAI_CONFIG_LIST": json.dumps(config_list)})
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")
if filter_dict is not None:
config_list = autogen.filter_config_list(config_list, filter_dict)

Expand Down Expand Up @@ -101,3 +108,18 @@ def yes_or_no_filter(context, response):
- APIError: If any model client create call raises an APIError
"""
return self._model.create(**config)

def auto_construct_oai_config_list_from_env() -> List:
"""
Collect various API keys saved in the environment and return a format like:
[{"model": "gpt-4", "api_key": xxx}, {"model": "claude-3.5-sonnet", "api_key": xxx}]
Note this is a lazy function that defaults to gpt-40 and claude-3.5-sonnet.
If you want to specify your own model, please provide an OAI_CONFIG_LIST in the environment or as a file
"""
config_list = []
if os.environ.get("OPENAI_API_KEY") is not None:
config_list.append({"model": "gpt-4o", "api_key": os.environ.get("OPENAI_API_KEY")})
if os.environ.get("ANTHROPIC_API_KEY") is not None:
config_list.append({"model": "claude-3-5-sonnet-latest", "api_key": os.environ.get("ANTHROPIC_API_KEY")})
return config_list

0 comments on commit cf39af6

Please sign in to comment.