From 7eeb57e44aa187381d732edfcc2cb8948bd3d592 Mon Sep 17 00:00:00 2001
From: Nicolas van Kempen
Date: Wed, 4 Oct 2023 19:03:58 +0100
Subject: [PATCH] Add fallback behavior

---
 README.md            |  2 +-
 src/cwhy/__main__.py |  4 ++--
 src/cwhy/cwhy.py     | 13 +++++++++++++
 3 files changed, 16 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index d054abe..63b2148 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Explains and suggests fixes for compiler error messages for a wide range of prog
 > 
 > CWhy needs to be connected to an [OpenAI account](https://openai.com/api/). _Your account will need to have a positive balance for this to work_ ([check your balance](https://platform.openai.com/account/usage)). [Get a key here.](https://platform.openai.com/account/api-keys)
 > 
-> CWhy currently uses GPT-3.5 as its default model. If you want to use the newest and best model (GPT-4), you need to have purchased at least $1 in credits (if your API account was created before August 13, 2023) or $0.50 (if you have a newer API account). 
+> CWhy currently uses GPT-3.5 as its default model. If you want to use the newest and best model (GPT-4), you need to have purchased at least $1 in credits (if your API account was created before August 13, 2023) or $0.50 (if you have a newer API account).
 > 
 > Once you have an API key, set it as an environment variable called `OPENAI_API_KEY`.
 > 
diff --git a/src/cwhy/__main__.py b/src/cwhy/__main__.py
index bd2251f..8d1980c 100755
--- a/src/cwhy/__main__.py
+++ b/src/cwhy/__main__.py
@@ -35,8 +35,8 @@ def main():
     parser.add_argument(
         "--llm",
         type=str,
-        default="gpt-3.5-turbo",
-        help="the language model to use, e.g., 'gpt-3.5-turbo' or 'gpt-4' (default: gpt-3.5-turbo)",
+        default="default",
+        help="the language model to use, e.g., 'gpt-3.5-turbo' or 'gpt-4' (default: 'default', which tries gpt-4 and falls back to gpt-3.5-turbo)",
     )
     parser.add_argument(
         "--timeout",
diff --git a/src/cwhy/cwhy.py b/src/cwhy/cwhy.py
index 69e8e33..0a040cd 100755
--- a/src/cwhy/cwhy.py
+++ b/src/cwhy/cwhy.py
@@ -170,8 +170,21 @@ def evaluate_diff(args, stdin):
     return completion
 
 
+def evaluate_with_fallback(args, stdin):
+    DEFAULT_FALLBACK_MODELS = ["gpt-4", "gpt-3.5-turbo"]
+    for i, model in enumerate(DEFAULT_FALLBACK_MODELS):
+        if i != 0:
+            print(f"Falling back to {model}...")
+        args["llm"] = model
+        try:
+            return evaluate(args, stdin)
+        except openai.error.InvalidRequestError as e:
+            print(e)
 def evaluate(args, stdin):
+    if args["llm"] == "default":
+        return evaluate_with_fallback(args, stdin)
+
     if args["subcommand"] == "explain":
         return evaluate_text_prompt(args, explain_prompt(args, stdin))
     elif args["subcommand"] == "fix":