Skip to content

Commit

Permalink
OpenAI-compatible API configuration per assistant (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
kharvd authored Jan 21, 2025
1 parent d61cbeb commit 194ab0b
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 7 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ The prefix is stripped before sending the request to the API.

Similarly, use the `oai-azure:` model name prefix to use a model deployed via Azure OpenAI. For example, `oai-azure:my-deployment-name`.

In an assistant's configuration, you can override the OpenAI base URL and API key for that specific assistant.

```yaml
# ~/.config/gpt-cli/gpt.yml
assistants:
  llama:
    model: oai-compat:meta-llama/llama-3.3-70b-instruct
    openai_base_url_override: https://openrouter.ai/api/v1
    openai_api_key_override: $OPENROUTER_API_KEY
```

## Other chat bots
### Anthropic Claude
Expand Down
20 changes: 16 additions & 4 deletions gptcli/assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class AssistantConfig(TypedDict, total=False):
messages: List[Message]
model: str
openai_base_url_override: Optional[str]
openai_api_key_override: Optional[str]
temperature: float
top_p: float

Expand Down Expand Up @@ -66,15 +68,21 @@ class AssistantConfig(TypedDict, total=False):
}


def get_completion_provider(model: str) -> CompletionProvider:
def get_completion_provider(
model: str,
openai_base_url_override: Optional[str] = None,
openai_api_key_override: Optional[str] = None,
) -> CompletionProvider:
if (
model.startswith("gpt")
or model.startswith("ft:gpt")
or model.startswith("oai-compat:")
or model.startswith("chatgpt")
or model.startswith("o1")
):
return OpenAICompletionProvider()
return OpenAICompletionProvider(
openai_base_url_override, openai_api_key_override
)
elif model.startswith("oai-azure:"):
return AzureOpenAICompletionProvider()
elif model.startswith("claude"):
Expand Down Expand Up @@ -112,11 +120,15 @@ def init_messages(self) -> List[Message]:
def _param(self, param: str) -> Any:
# Use the value from the config if exists
# Otherwise, use the default value
return self.config.get(param, CONFIG_DEFAULTS[param])
return self.config.get(param, CONFIG_DEFAULTS.get(param, None))

def complete_chat(self, messages, stream: bool = True) -> Iterator[CompletionEvent]:
model = self._param("model")
completion_provider = get_completion_provider(model)
completion_provider = get_completion_provider(
model,
self._param("openai_base_url_override"),
self._param("openai_api_key_override"),
)
return completion_provider.complete(
messages,
{
Expand Down
9 changes: 6 additions & 3 deletions gptcli/providers/openai.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from typing import Iterator, List, Optional, cast
import openai
from openai import AzureOpenAI, OpenAI
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam

from gptcli.completion import (
Expand All @@ -17,8 +17,10 @@


class OpenAICompletionProvider(CompletionProvider):
def __init__(self):
self.client = OpenAI(api_key=openai.api_key, base_url=openai.base_url)
def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
self.client = OpenAI(
api_key=api_key or openai.api_key, base_url=base_url or openai.base_url
)

def complete(
self, messages: List[Message], args: dict, stream: bool = False
Expand Down Expand Up @@ -135,6 +137,7 @@ def complete(
"response": 12.0 / 1_000_000,
}


def gpt_pricing(model: str) -> Optional[Pricing]:
if model.startswith("gpt-3.5-turbo-16k"):
return GPT_3_5_TURBO_16K_PRICE_PER_TOKEN
Expand Down

0 comments on commit 194ab0b

Please sign in to comment.