p2l stuff #3660

Merged · 3 commits · Jan 11, 2025
2 changes: 1 addition & 1 deletion fastchat/model/model_adapter.py
@@ -2489,7 +2489,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:

 class NoSystemAdapter(BaseModelAdapter):
     def match(self, model_path: str):
-        keyword_list = ["athene-70b"]
+        keyword_list = ["athene-70b", "p2l"]

         for keyword in keyword_list:
             if keyword == model_path.lower():
76 changes: 76 additions & 0 deletions fastchat/serve/api_provider.py
@@ -246,6 +246,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             conversation_id=state.conv_id,
         )
+    elif model_api_dict["api_type"] == "p2l":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = p2l_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_base=model_api_dict["api_base"],
+            api_key=model_api_dict["api_key"],
+        )
     else:
         raise NotImplementedError()
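
Note: the branch above reads only four fields from the endpoint registration. A hypothetical entry that would route through it might look like the sketch below (the model name and URL are placeholders, not values from this PR):

    # Hypothetical endpoint entry; the name and URL are placeholders.
    model_api_dict = {
        "model_name": "p2l-router",
        "api_type": "p2l",
        "api_base": "http://localhost:9000/v1",
        "api_key": "-",
    }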

@@ -412,6 +423,71 @@ def column_api_stream_iter(
     }


+def p2l_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_base=None,
+    api_key=None,
+):
+    import openai
+
+    client = openai.OpenAI(
+        base_url=api_base,
+        api_key=api_key or "-",
+        timeout=180,
+    )
+
+    # Make requests for logging
+    text_messages = []
+    for message in messages:
+        if type(message["content"]) == str:  # text-only model
+            text_messages.append(message)
+        else:  # vision model
+            filtered_content_list = [
+                content for content in message["content"] if content["type"] == "text"
+            ]
+            text_messages.append(
+                {"role": message["role"], "content": filtered_content_list}
+            )
+
+    gen_params = {
+        "model": model_name,
+        "prompt": text_messages,
+        "temperature": None,
+        "top_p": None,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    res = client.chat.completions.create(
+        model=model_name,
+        messages=messages,
+        max_tokens=max_new_tokens,
+        stream=True,
+    )
+    text = ""
+    for chunk_idx, chunk in enumerate(res):
+        if len(chunk.choices) > 0:
+            text += chunk.choices[0].delta.content or ""
+
+            data = {
+                "text": text,
+                "error_code": 0,
+            }
+
+            if chunk_idx == 0:
+                if hasattr(chunk.choices[0].delta, "model"):
+                    data["ans_model"] = chunk.choices[0].delta.model
+
+                if hasattr(chunk, "router_outputs"):
+                    data["router_outputs"] = chunk.router_outputs
+
+            yield data

 def upload_openai_file_to_gcs(file_id):
     import openai
     from google.cloud import storage
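
A quick way to exercise the new generator by hand, assuming a P2L server that speaks the OpenAI-compatible chat API at a placeholder address (this sketch is not part of the PR):

    # Minimal sketch; the model name and URL are placeholders.
    messages = [{"role": "user", "content": "What is 2 + 2?"}]
    stream = p2l_api_stream_iter(
        "p2l-router",
        messages,
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=64,
        api_base="http://localhost:9000/v1",
        api_key="-",
    )
    last = {}
    for i, data in enumerate(stream):
        if i == 0:
            # Routing metadata, if any, arrives only on the first chunk.
            print(data.get("ans_model"), data.get("router_outputs"))
        last = data
    print(last["text"])  # full accumulated response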
38 changes: 36 additions & 2 deletions fastchat/serve/gradio_web_server.py
@@ -11,7 +11,7 @@
 import random
 import time
 import uuid
-from typing import List
+from typing import List, Dict

 import gradio as gr
 import requests
@@ -119,6 +119,8 @@ def __init__(self, model_name, is_vision=False):
         self.model_name = model_name
         self.oai_thread_id = None
         self.is_vision = is_vision
+        self.ans_models = []
+        self.router_outputs = []

         # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
         self.has_csam_image = False

@@ -128,6 +130,12 @@ def __init__(self, model_name, is_vision=False):
         self.regen_support = False
         self.init_system_prompt(self.conv, is_vision)

+    def update_ans_models(self, ans: str) -> None:
+        self.ans_models.append(ans)
+
+    def update_router_outputs(self, outputs: Dict[str, float]) -> None:
+        self.router_outputs.append(outputs)
+
     def init_system_prompt(self, conv, is_vision):
         system_prompt = conv.get_system_message(is_vision)
         if len(system_prompt) == 0:
@@ -154,6 +162,20 @@ def dict(self):
             }
         )

+        if self.ans_models:
+            base.update(
+                {
+                    "ans_models": self.ans_models,
+                }
+            )
+
+        if self.router_outputs:
+            base.update(
+                {
+                    "router_outputs": self.router_outputs,
+                }
+            )
+
         if self.is_vision:
             base.update({"has_csam_image": self.has_csam_image})
         return base
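
With these two blocks, the logged record for a P2L-backed conversation carries the routing history alongside the usual fields; roughly, and with made-up values:

    # Illustrative State.dict() output; abridged, values are made up.
    {
        "conv_id": "...",
        "model_name": "p2l-router",
        "ans_models": ["gpt-4o-mini"],  # one entry per answered turn
        "router_outputs": [{"gpt-4o-mini": 0.83, "llama-3-70b": 0.17}],
    }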
@@ -420,7 +442,7 @@ def is_limit_reached(model_name, ip):


 def bot_response(
-    state,
+    state: State,
     temperature,
     top_p,
     max_new_tokens,

@@ -532,6 +554,18 @@ def bot_response(
     try:
         data = {"text": ""}
         for i, data in enumerate(stream_iter):
+            # Change for P2L:
+            if i == 0:
+                if "ans_model" in data:
+                    ans_model = data.get("ans_model")
+
+                    state.update_ans_models(ans_model)
+
+                if "router_outputs" in data:
+                    router_outputs = data.get("router_outputs")
+
+                    state.update_router_outputs(router_outputs)
+
             if data["error_code"] == 0:
                 output = data["text"].strip()
                 conv.update_last_message(output + "▌")
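
For reference, a first chunk from p2l_api_stream_iter that would trigger both updates has roughly this shape (values illustrative); for every other provider the two keys are absent and the new block is a no-op:

    # Illustrative first-chunk payload from the P2L stream; values are made up.
    data = {
        "text": "4",
        "error_code": 0,
        "ans_model": "gpt-4o-mini",
        "router_outputs": {"gpt-4o-mini": 0.83, "llama-3-70b": 0.17},
    }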
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -19,8 +19,8 @@ dependencies = [
 ]

 [project.optional-dependencies]
-model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf"]
-webui = ["gradio>=4.10"]
+model_worker = ["accelerate>=0.21", "peft", "sentencepiece", "torch", "transformers>=4.31.0", "protobuf", "openai", "anthropic"]
+webui = ["gradio>=4.10", "plotly", "scipy"]
 train = ["einops", "flash-attn>=2.0", "wandb"]
 llm_judge = ["openai<1", "anthropic>=0.3", "ray"]
 dev = ["black==23.3.0", "pylint==2.8.2"]
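
With this change the openai and anthropic clients ship with the model_worker extra, and plotly/scipy with webui, so a typical editable install such as pip3 install -e ".[model_worker,webui]" pulls them in without a separate step.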