diff --git a/backend/app/agent.py b/backend/app/agent.py index 2658c9b7..eb3bd830 100644 --- a/backend/app/agent.py +++ b/backend/app/agent.py @@ -21,6 +21,7 @@ get_mixtral_fireworks, get_ollama_llm, get_openai_llm, + get_huggingface_textgen_inference_llm, ) from app.retrieval import get_retrieval_executor from app.tools import ( @@ -69,6 +70,7 @@ class AgentType(str, Enum): BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)" GEMINI = "GEMINI" OLLAMA = "Ollama" + HUGGINGFACE_TEXTGEN_INFERENCE = "HuggingFace TextGenInference" DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant." @@ -117,7 +119,11 @@ def get_agent_executor( return get_tools_agent_executor( tools, llm, system_message, interrupt_before_action, CHECKPOINTER ) - + elif agent == AgentType.HUGGINGFACE_TEXTGEN_INFERENCE: + llm = get_huggingface_textgen_inference_llm() + return get_xml_agent_executor( + tools, llm, system_message, interrupt_before_action, CHECKPOINTER + ) else: raise ValueError("Unexpected agent type") @@ -188,6 +194,7 @@ class LLMType(str, Enum): GEMINI = "GEMINI" MIXTRAL = "Mixtral" OLLAMA = "Ollama" + HUGGINGFACE_TEXTGEN_INFERENCE = "HuggingFace TextGenInference" def get_chatbot( @@ -210,6 +217,8 @@ def get_chatbot( llm = get_mixtral_fireworks() elif llm_type == LLMType.OLLAMA: llm = get_ollama_llm() + elif llm_type == LLMType.HUGGINGFACE_TEXTGEN_INFERENCE: + llm = get_huggingface_textgen_inference_llm() else: raise ValueError("Unexpected llm type") return get_chatbot_executor(llm, system_message, CHECKPOINTER) @@ -290,6 +299,8 @@ def __init__( llm = get_mixtral_fireworks() elif llm_type == LLMType.OLLAMA: llm = get_ollama_llm() + elif llm_type == LLMType.HUGGINGFACE_TEXTGEN_INFERENCE: + llm = get_huggingface_textgen_inference_llm() else: raise ValueError("Unexpected llm type") chatbot = get_retrieval_executor(llm, retriever, system_message, CHECKPOINTER) diff --git a/backend/app/llms.py b/backend/app/llms.py index dc7fbae7..4cbca4ba 100644 --- a/backend/app/llms.py +++ b/backend/app/llms.py 
@@ -10,9 +10,18 @@ from langchain_community.chat_models.ollama import ChatOllama from langchain_google_vertexai import ChatVertexAI from langchain_openai import AzureChatOpenAI, ChatOpenAI +import huggingface_hub +from langchain_community.llms import HuggingFaceTextGenInference +from langchain_community.chat_models.huggingface import ChatHuggingFace logger = logging.getLogger(__name__) +def load_env_var(key: str) -> str: +    """Load environment variable safely.""" +    value = os.getenv(key) +    if value is None: +        raise ValueError(f"Environment variable {key} is required.") +    return value @lru_cache(maxsize=4) def get_openai_llm(gpt_4: bool = False, azure: bool = False): @@ -98,3 +107,31 @@ def get_ollama_llm(): ollama_base_url = "http://localhost:11434" return ChatOllama(model=model_name, base_url=ollama_base_url) + +@lru_cache(maxsize=1) +def get_huggingface_textgen_inference_llm( +    max_new_tokens=2048, +    top_k=10, +    top_p=0.95, +    typical_p=0.95, +    temperature=0.3, +    repetition_penalty=1.1, +    streaming=True, +    model_id="HuggingFaceH4/zephyr-7b-beta" +): +    """Initialize the HuggingFace TextGenInference model with dynamic parameters.""" +    huggingface_hub.login(load_env_var("HUGGINGFACE_TOKEN")) +    llm = HuggingFaceTextGenInference( +        inference_server_url=load_env_var("HUGGINGFACE_INFERENCE_SERVER_URL"), +        max_new_tokens=max_new_tokens, +        top_k=top_k, +        top_p=top_p, +        typical_p=typical_p, +        temperature=temperature, +        repetition_penalty=repetition_penalty, +        streaming=streaming, +        server_kwargs={ +            "headers": {"Content-Type": "application/json"} +        }, +    ) +    return ChatHuggingFace(llm=llm, model_id=model_id) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 3184bd27..2cc3aa37 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -44,6 +44,9 @@ asyncpg = "^0.29.0" langchain-core = "^0.1.44" pyjwt = {extras = ["crypto"], version = "^2.8.0"} langchain-anthropic = "^0.1.8" +text-generation = "^0.6.1" +transformers = "^4.38.1" 
+Jinja2 = "^3.1.3" [tool.poetry.group.dev.dependencies] uvicorn = "^0.23.2"