Skip to content

Commit

Permalink
Adding groq support (#71)
Browse files Browse the repository at this point in the history
* adding groq support

* made chatgroq subclass chatopenai

* remove breakpoint
  • Loading branch information
vaibhavnayel authored Jul 23, 2024
1 parent 5cd7cc0 commit 4a5063e
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 0 deletions.
4 changes: 4 additions & 0 deletions textgrad/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,9 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
from .vllm import ChatVLLM
engine_name = engine_name.replace("vllm-", "")
return ChatVLLM(model_string=engine_name, **kwargs)
elif "groq" in engine_name:
from .groq import ChatGroq
engine_name = engine_name.replace("groq-", "")
return ChatGroq(model_string=engine_name, **kwargs)
else:
raise ValueError(f"Engine {engine_name} not supported")
48 changes: 48 additions & 0 deletions textgrad/engine/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
try:
from groq import Groq
except ImportError:
raise ImportError("If you'd like to use Groq models, please install the groq package by running `pip install groq`, and add 'GROQ_API_KEY' to your environment variables.")

import os
import json
import base64
import platformdirs
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from typing import List, Union

from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes
from .openai import ChatOpenAI


class ChatGroq(ChatOpenAI):
    """Engine for Groq-hosted chat models.

    Subclasses ChatOpenAI because the Groq SDK exposes an OpenAI-compatible
    chat-completions API; only client construction and the cache location
    differ, so ChatOpenAI.__init__ is deliberately bypassed (it would build
    an OpenAI client and require OPENAI_API_KEY).
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string: str="groq-llama3-70b-8192",
        system_prompt: str=DEFAULT_SYSTEM_PROMPT,
        **kwargs):
        """
        :param model_string: Groq model identifier (callers such as get_engine
            strip any leading "groq-" prefix before passing it here).
        :param system_prompt: System message prepended to every request.
        :param kwargs: Accepted for signature compatibility with ChatOpenAI;
            not used by this engine.
        :raises ValueError: If the GROQ_API_KEY environment variable is unset.
        """
        # Per-model response cache, e.g. ~/.cache/textgrad/cache_groq_<model>.db
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_groq_{model_string}.db")
        CachedEngine.__init__(self, cache_path=cache_path)

        api_key = os.getenv("GROQ_API_KEY")
        if api_key is None:
            raise ValueError("Please set the GROQ_API_KEY environment variable if you'd like to use Groq models.")
        self.client = Groq(api_key=api_key)

        self.model_string = model_string
        self.system_prompt = system_prompt
        assert isinstance(self.system_prompt, str)
        # Groq chat endpoint is text-only; image inputs are not supported.
        self.is_multimodal = False

0 comments on commit 4a5063e

Please sign in to comment.