Skip to content

Commit

Permalink
Rename TurnMessage to ChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
musabgultekin committed Aug 5, 2023
1 parent ad284c3 commit d0ea7c1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
28 changes: 14 additions & 14 deletions functionary/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,50 @@
import torch
from transformers import LlamaTokenizer, LlamaForCausalLM

from functionary.openai_types import FunctionCall, Function, TurnMessage
from functionary.openai_types import FunctionCall, Function, ChatMessage
from functionary.schema import generate_schema_from_functions

SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""


def tokenize(message: ChatMessage, tokenizer: LlamaTokenizer):
    """Render *message* to its prompt-string form and tokenize it.

    The message's ``__str__`` produces the model's expected chat format;
    the result is returned as a token-id tensor.

    NOTE(review): the tensor is moved to ``cuda:0`` unconditionally —
    assumes a CUDA device is available; confirm for CPU-only deployments.
    """
    text = str(message)
    return tokenizer(text, add_special_tokens=False, return_tensors="pt").input_ids.to(
        "cuda:0"
    )


def prepare_messages_for_inference(
tokenizer: LlamaTokenizer, messages: List[TurnMessage], functions=None
):
tokenizer: LlamaTokenizer, messages: List[ChatMessage], functions=None
) -> torch.Tensor:
all_messages = []
if functions is not None:
all_messages.append(
TurnMessage(
ChatMessage(
role="system", content=generate_schema_from_functions(functions)
)
)

all_messages.append(TurnMessage(role="system", content=SYSTEM_MESSAGE))
all_messages.append(ChatMessage(role="system", content=SYSTEM_MESSAGE))

for message in messages:
if message.role == "assistant":
if message:
all_messages.append(
TurnMessage(role="assistant", content=message.content)
ChatMessage(role="assistant", content=message.content)
)
if message.function_call:
fc = message.function_call
all_messages.append(
TurnMessage(
ChatMessage(
role="assistant",
_to=f"functions.{fc.name}",
content=fc.arguments,
)
)
elif message.role == "function":
all_messages.append(
TurnMessage(
ChatMessage(
role="function",
name=f"functions.{message.name}",
content=message.content,
Expand All @@ -55,7 +55,7 @@ def prepare_messages_for_inference(
else:
all_messages.append(message)

all_messages.append(TurnMessage(role="assistant", content=None))
all_messages.append(ChatMessage(role="assistant", content=None))

# ! should this be done as concatting strings and then tokenizing?
# ! >>> text = "".join([str(msg) for msg in all_messages]
Expand All @@ -69,11 +69,11 @@ def prepare_messages_for_inference(
def generate_message(
model: LlamaForCausalLM,
tokenizer: LlamaTokenizer,
messages: List[TurnMessage],
messages: List[ChatMessage],
functions: Optional[List[Function]] = None,
temperature: float = 0.7,
max_new_tokens=256,
) -> TurnMessage:
) -> ChatMessage:
inputs = prepare_messages_for_inference(
tokenizer=tokenizer, messages=messages, functions=functions
)
Expand All @@ -90,11 +90,11 @@ def generate_message(
if generated_content.startswith("to=functions."):
function_call_content = generated_content[len("to=functions.") :]
function_name, arguments = function_call_content.split(":\n")
return TurnMessage(
return ChatMessage(
role="assistant",
function_call=FunctionCall(name=function_name, arguments=arguments),
)
return TurnMessage(
return ChatMessage(
role="assistant",
content=generated_content.lstrip("assistant:\n").rstrip("\n user:\n"),
)
8 changes: 4 additions & 4 deletions functionary/openai_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Function(BaseModel):
parameters: dict


class TurnMessage(BaseModel):
class ChatMessage(BaseModel):
role: str
content: Optional[str] = None
name: Optional[str] = None
Expand Down Expand Up @@ -52,18 +52,18 @@ def __str__(self) -> str:


class ChatInput(BaseModel):
    """Request body for a chat completion: conversation history plus
    the optional function schemas the model may call."""

    # Conversation so far, oldest message first.
    messages: List[ChatMessage]
    # Function schemas available to the model; None disables function calling.
    functions: Optional[List[Function]] = None
    # Sampling temperature; default preserved from the original API surface.
    temperature: float = 0.9


class Choice(BaseModel):
    """A single completion choice in an OpenAI-style chat response."""

    # The generated assistant message for this choice.
    message: ChatMessage
    # Why generation stopped; defaults to the normal "stop" reason.
    finish_reason: str = "stop"
    # Position of this choice in the response's choices list.
    index: int = 0

    @classmethod
    def from_message(cls, message: ChatMessage):
        """Build a Choice wrapping *message* with default finish_reason/index."""
        return cls(message=message)


Expand Down
4 changes: 2 additions & 2 deletions modal_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import modal
from fastapi import FastAPI

from functionary.openai_types import ChatCompletion, ChatInput, Choice, Function, TurnMessage
from functionary.openai_types import ChatCompletion, ChatInput, Choice, Function, ChatMessage
from functionary.inference import generate_message

stub = modal.Stub("functionary")
Expand Down Expand Up @@ -62,7 +62,7 @@ def __enter__(self):
@modal.method()
def generate(
self,
messages: List[TurnMessage],
messages: List[ChatMessage],
functions: List[Function],
temperature: float,
):
Expand Down

0 comments on commit d0ea7c1

Please sign in to comment.