Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/plasma-umass/cwhy into main
Browse files Browse the repository at this point in the history
  • Loading branch information
emeryberger committed Feb 5, 2024
2 parents 72c8cfa + 74a5503 commit 37e7d51
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions src/cwhy/cwhy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,19 @@ def print_key_info():
print(" export AWS_REGION_NAME=us-west-2")


# If keys are defined in the environment, we use the appropriate service.
service = None
_DEFAULT_FALLBACK_MODELS = []

with contextlib.suppress(KeyError):
if os.environ["OPENAI_API_KEY"]:
service = "OpenAI"
_DEFAULT_FALLBACK_MODELS = ["openai/gpt-4", "openai/gpt-3.5-turbo"]
with contextlib.suppress(KeyError):
if not _DEFAULT_FALLBACK_MODELS:
if {
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
} <= os.environ.keys():
service = "Bedrock"
_DEFAULT_FALLBACK_MODELS = ["bedrock/anthropic.claude-v2:1"]
if "OPENAI_API_KEY" in os.environ:
_DEFAULT_FALLBACK_MODELS = ["openai/gpt-4", "openai/gpt-3.5-turbo"]
elif {
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
} <= os.environ.keys():
_DEFAULT_FALLBACK_MODELS = ["bedrock/anthropic.claude-v2:1"]
else:
print_key_info()
sys.exit(1)


def complete(args, user_prompt, **kwargs):
Expand Down Expand Up @@ -159,10 +155,6 @@ def evaluate(args, stdin):


def main(args: argparse.Namespace) -> None:
if not service:
print_key_info()
sys.exit(1)

process = subprocess.run(
args.command,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -203,7 +195,7 @@ def main(args: argparse.Namespace) -> None:
def evaluate_text_prompt(args, prompt, wrap=True, **kwargs):
completion = complete(args, prompt, **kwargs)

msg = f"Analysis from {service}:"
msg = f"Analysis from {args.llm}:"
print(msg)
print("-" * len(msg))
text = completion.choices[0].message.content
Expand Down

0 comments on commit 37e7d51

Please sign in to comment.