
Adding function calling to AnyscaleLLM #227

Merged · 11 commits · Jan 21, 2024
6 changes: 5 additions & 1 deletion config/anyscale.yaml
@@ -27,7 +27,11 @@ chat_engine:
# Anyscale's LLM endpoint now supports function calling, so we can use the FunctionCallingQueryGenerator
# --------------------------------------------------------------------
query_builder:
type: LastMessageQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
type: FunctionCallingQueryGenerator # Options: [FunctionCallingQueryGenerator, LastMessageQueryGenerator]
llm:
type: AnyscaleLLM # Options: [OpenAILLM, AnyscaleLLM]
params:
model_name: mistralai/Mistral-7B-Instruct-v0.1

# -------------------------------------------------------------------------------------------------------------
# ContextEngine configuration
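With function calling now supported on Anyscale endpoints, the query builder switches from LastMessageQueryGenerator to a FunctionCallingQueryGenerator backed by AnyscaleLLM. A minimal sketch of the equivalent wiring in Python; the module path and the llm constructor argument are assumptions inferred from the nested YAML above:

from canopy.chat_engine.query_generator import FunctionCallingQueryGenerator
from canopy.llm.anyscale import AnyscaleLLM

# Mirrors the YAML config above; the exact constructor signature is an assumption.
llm = AnyscaleLLM(model_name="mistralai/Mistral-7B-Instruct-v0.1")
query_builder = FunctionCallingQueryGenerator(llm=llm)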
25 changes: 0 additions & 25 deletions src/canopy/llm/anyscale.py
@@ -1,14 +1,6 @@
from typing import Union, Iterable, Optional, Any, List
import jsonschema
import json
import os
from tenacity import (
retry,
stop_after_attempt,
retry_if_exception_type,
)
from canopy.llm import OpenAILLM
from canopy.llm.models import Function
from canopy.models.api_models import ChatResponse, StreamingChatChunk
from canopy.models.data_models import Messages, Query

@@ -41,23 +33,6 @@ def __init__(
ae_base_url = base_url
super().__init__(model_name, api_key=ae_api_key, base_url=ae_base_url, **kwargs)

@retry(
reraise=True,
stop=stop_after_attempt(3),
retry=retry_if_exception_type(
(json.decoder.JSONDecodeError, jsonschema.ValidationError)
),
)
def enforced_function_call(
self,
messages: Messages,
function: Function,
*,
max_tokens: Optional[int] = None,
model_params: Optional[dict] = None,
) -> dict:
raise NotImplementedError()

async def achat_completion(
self,
messages: Messages,
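With the NotImplementedError override removed, AnyscaleLLM now inherits OpenAILLM's enforced_function_call implementation. A minimal usage sketch based on the signatures shown above; the model name and message content are illustrative, and it assumes the API key is read from the ANYSCALE_API_KEY environment variable when not passed explicitly:

from canopy.llm.anyscale import AnyscaleLLM
from canopy.llm.models import Function, FunctionParameters, FunctionArrayProperty
from canopy.models.data_models import MessageBase, Role

# Assumption: API key comes from the ANYSCALE_API_KEY env var if not given here.
llm = AnyscaleLLM(model_name="mistralai/Mistral-7B-Instruct-v0.1")

function = Function(
    name="query_knowledgebase",
    description="Query search engine for relevant information",
    parameters=FunctionParameters(
        required_properties=[
            FunctionArrayProperty(
                name="queries",
                items_type="string",
                description="List of queries to send to the search engine.",
            ),
        ]
    ),
)

messages = [MessageBase(role=Role.USER, content="What is Canopy?")]
result = llm.enforced_function_call(messages=messages, function=function)
# result is a dict matching the function's parameter schema, e.g. {"queries": [...]}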
1 change: 1 addition & 0 deletions src/canopy/llm/models.py
@@ -20,6 +20,7 @@ class FunctionArrayProperty(BaseModel):
def dict(self, *args, **kwargs):
super_dict = super().dict(*args, **kwargs)
if "items_type" in super_dict:
super_dict["type"] = "array"
super_dict["items"] = {"type": super_dict.pop("items_type")}
return super_dict

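The one-line fix above adds the "type": "array" key that JSON Schema requires for array properties; previously the serialized property carried "items" but no "type", which strict schema validators reject. A quick illustration of the serialized output:

from canopy.llm.models import FunctionArrayProperty

prop = FunctionArrayProperty(
    name="queries",
    items_type="string",
    description="List of queries to send to the search engine.",
)
print(prop.dict())
# Roughly: {'name': 'queries', 'description': '...', 'type': 'array', 'items': {'type': 'string'}}
# Before this fix, the 'type': 'array' entry was missing.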
71 changes: 64 additions & 7 deletions tests/system/llm/test_anyscale.py
Contributor comment:
@kylehh we've changed the format of test_openai.py so that it parameterizes all LLM classes that inherit from OpenAILLM (such as AnyscaleLLM) and runs the same test cases for each of them.
This way we don't need a dedicated test_anyscale.py file that simply duplicates the entire test suite.
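A rough sketch of that parameterization pattern, with hypothetical fixture names and an assumed OpenAI model string; the actual layout of test_openai.py may differ:

import pytest
from canopy.llm import OpenAILLM
from canopy.llm.anyscale import AnyscaleLLM

# Each (class, model) pair runs the shared test suite once.
@pytest.fixture(params=[
    (OpenAILLM, "gpt-3.5-turbo"),  # assumed OpenAI model name
    (AnyscaleLLM, "mistralai/Mistral-7B-Instruct-v0.1"),
])
def llm(request):
    llm_class, model_name = request.param
    return llm_class(model_name=model_name)

def test_chat_completion(llm, messages):
    response = llm.chat_completion(messages=messages)
    assert_chat_completion(response)  # shared assertion helper from the test module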

@@ -6,6 +6,11 @@
from canopy.models.data_models import Role, MessageBase # noqa
from canopy.models.api_models import ChatResponse, StreamingChatChunk # noqa
from canopy.llm.anyscale import AnyscaleLLM # noqa
from canopy.llm.models import (
Function,
FunctionParameters,
FunctionArrayProperty,
) # noqa
from openai import BadRequestError # noqa


@@ -31,29 +36,44 @@ class TestAnyscaleLLM:
@staticmethod
@pytest.fixture
def model_name():
return "meta-llama/Llama-2-7b-chat-hf"
return "mistralai/Mistral-7B-Instruct-v0.1"

@staticmethod
@pytest.fixture
def messages():
# Create a list of MessageBase objects
return [
MessageBase(role=Role.USER, content="Hello, assistant."),
MessageBase(
role=Role.ASSISTANT, content="Hello, user. How can I assist you?"
),
MessageBase(role=Role.SYSTEM, content="You are a helpful AI assistant."),
MessageBase(role=Role.USER, content="Hello, assistant. "),
]

@staticmethod
@pytest.fixture
def function_query_knowledgebase():
return Function(
name="query_knowledgebase",
description="Query search engine for relevant information",
parameters=FunctionParameters(
required_properties=[
FunctionArrayProperty(
name="queries",
items_type="string",
description="List of queries to send to the search engine.",
),
]
),
)

@staticmethod
@pytest.fixture
def model_params_high_temperature():
# The `n` parameter is not supported yet, so it is always set to 1
return {"temperature": 0.9, "top_p": 0.95, "n": 1}
return {"temperature": 0.9, "n": 1}

@staticmethod
@pytest.fixture
def model_params_low_temperature():
return {"temperature": 0.2, "top_p": 0.5, "n": 1}
return {"temperature": 0.2, "n": 1}

@staticmethod
@pytest.fixture
@@ -81,6 +101,15 @@ def test_chat_completion(anyscale_llm, messages):
response = anyscale_llm.chat_completion(messages=messages)
assert_chat_completion(response)

@staticmethod
def test_enforced_function_call(
anyscale_llm, messages, function_query_knowledgebase
):
result = anyscale_llm.enforced_function_call(
messages=messages, function=function_query_knowledgebase
)
assert_function_call_format(result)

@staticmethod
def test_chat_completion_high_temperature(
anyscale_llm, messages, model_params_high_temperature
@@ -99,6 +128,34 @@ def test_chat_completion_low_temperature(
)
assert_chat_completion(response, num_choices=model_params_low_temperature["n"])

@staticmethod
def test_enforced_function_call_high_temperature(
anyscale_llm,
messages,
function_query_knowledgebase,
model_params_high_temperature,
):
result = anyscale_llm.enforced_function_call(
messages=messages,
function=function_query_knowledgebase,
model_params=model_params_high_temperature,
)
assert_function_call_format(result)

@staticmethod
def test_enforced_function_call_low_temperature(
anyscale_llm,
messages,
function_query_knowledgebase,
model_params_low_temperature,
):
result = anyscale_llm.enforced_function_call(
messages=messages,
function=function_query_knowledgebase,
model_params=model_params_low_temperature,
)
assert_function_call_format(result)

@staticmethod
def test_chat_streaming(anyscale_llm, messages):
stream = True