diff --git a/mdagent/utils/makellm.py b/mdagent/utils/makellm.py
index e884d58f..b9f1be60 100644
--- a/mdagent/utils/makellm.py
+++ b/mdagent/utils/makellm.py
@@ -1,5 +1,6 @@
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_openai import ChatOpenAI
+from langchain_anthropic import ChatAnthropic
 
 
 def _make_llm(model, temp, verbose):
@@ -11,25 +12,14 @@ def _make_llm(model, temp, verbose):
             streaming=True if verbose else False,
             callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
         )
-    elif model.startswith("llama"):
-        from langchain_fireworks import ChatFireworks
-
-        llm = ChatFireworks(
+    elif model.startswith("claude"):
+        llm = ChatAnthropic(
             temperature=temp,
-            model_name=f"accounts/fireworks/models/{model}",
-            request_timeout=1000,
+            model_name=model,
             streaming=True if verbose else False,
             callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
         )
-    # elif model.startswith("Meta-Llama"):
-    #     from langchain_together import ChatTogether
-    #     llm = ChatTogether(
-    #         temperature=temp,
-    #         model=f"meta-llama/{model}",
-    #         request_timeout=1000,
-    #         streaming=True if verbose else False,
-    #         callbacks=[StreamingStdOutCallbackHandler()] if verbose else None,
-    #     )
+
     else:
         raise ValueError(f"Invalid or Unsupported model name: {model}")
     return llm
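
For reviewers, a minimal usage sketch of the new "claude" branch (not part of the diff; the model name, the langchain-anthropic install, and the ANTHROPIC_API_KEY requirement are assumptions, not anything this change pins):

# Usage sketch only: exercises the new ChatAnthropic branch of _make_llm.
# Assumes the langchain-anthropic package is installed and ANTHROPIC_API_KEY
# is set in the environment; the model name below is illustrative.
from mdagent.utils.makellm import _make_llm

# verbose=True turns on streaming with StreamingStdOutCallbackHandler,
# so tokens print to stdout as they arrive.
llm = _make_llm("claude-3-opus-20240229", temp=0.1, verbose=True)
print(llm.invoke("What is molecular dynamics?").content)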