Towards multi-image #3510

Open · wants to merge 26 commits into main
71 changes: 52 additions & 19 deletions fastchat/conversation.py
@@ -365,12 +365,16 @@ def to_gradio_chatbot(self):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
image = images[0] # Only one image on gradio at one time
if image.image_format == ImageFormat.URL:
img_str = f'<img src="{image.url}" alt="user upload image" />'
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()
combined_image_str = ""
for image in images:
if image.image_format == ImageFormat.URL:
img_str = (
f'<img src="{image.url}" alt="user upload image" />'
)
elif image.image_format == ImageFormat.BYTES:
img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
combined_image_str += img_str
msg = combined_image_str + msg.replace("<image>\n", "").strip()

ret.append([msg, None])
else:
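
Pulled out of the diff for readability, here is a minimal standalone sketch of the multi-image rendering this hunk introduces: every uploaded image becomes an `<img>` tag and the tags are concatenated in front of the message text. The `Image` dataclass below is a simplified stand-in for FastChat's image object; the field names mirror those used above.

```python
from dataclasses import dataclass
from enum import Enum, auto


class ImageFormat(Enum):
    URL = auto()
    BYTES = auto()


@dataclass
class Image:
    # Simplified stand-in for FastChat's image object.
    image_format: ImageFormat
    url: str = ""
    filetype: str = "png"
    base64_str: str = ""


def render_message(msg: str, images: list[Image]) -> str:
    """Prefix the text with one <img> tag per uploaded image."""
    combined_image_str = ""
    for image in images:
        if image.image_format == ImageFormat.URL:
            img_str = f'<img src="{image.url}" alt="user upload image" />'
        else:  # ImageFormat.BYTES
            img_str = (
                f'<img src="data:image/{image.filetype};'
                f'base64,{image.base64_str}" alt="user upload image" />'
            )
        combined_image_str += img_str
    # Strip the "<image>\n" placeholder the prompt format inserts.
    return combined_image_str + msg.replace("<image>\n", "").strip()
```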
@@ -524,30 +528,55 @@ def to_anthropic_vision_api_messages(self):

def to_reka_api_messages(self):
from fastchat.serve.vision.image import ImageFormat
from reka import ChatMessage, TypedMediaContent, TypedText

ret = []
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) == tuple:
text, images = msg
for image in images:
if image.image_format == ImageFormat.URL:
ret.append(
{"type": "human", "text": text, "media_url": image.url}
)
elif image.image_format == ImageFormat.BYTES:
if image.image_format == ImageFormat.BYTES:
ret.append(
{
"type": "human",
"text": text,
"media_url": f"data:image/{image.filetype};base64,{image.base64_str}",
}
ChatMessage(
content=[
TypedText(
type="text",
text=text,
),
TypedMediaContent(
type="image_url",
image_url=f"data:image/{image.filetype};base64,{image.base64_str}",
),
],
role="user",
)
)
else:
ret.append({"type": "human", "text": msg})
ret.append(
ChatMessage(
content=[
TypedText(
type="text",
text=msg,
)
],
role="user",
)
)
else:
if msg is not None:
ret.append({"type": "model", "text": msg})
ret.append(
ChatMessage(
content=[
TypedText(
type="text",
text=msg,
)
],
role="assistant",
)
)

return ret
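
As a reference for the message shape this now produces, here is a minimal sketch of a single user turn with one base64-encoded image, using only the `reka` SDK types imported above; the text and image payload are hypothetical placeholders.

```python
from reka import ChatMessage, TypedMediaContent, TypedText

# Hypothetical placeholders standing in for a real uploaded image.
filetype, base64_str = "png", "iVBORw0KGgo..."

user_turn = ChatMessage(
    role="user",
    content=[
        TypedText(type="text", text="What is shown in this image?"),
        TypedMediaContent(
            type="image_url",
            image_url=f"data:image/{filetype};base64,{base64_str}",
        ),
    ],
)
```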

@@ -557,7 +586,11 @@ def save_new_images(self, has_csam_images=False, use_remote_storage=False):
from fastchat.utils import load_image, upload_image_file_to_gcs
from PIL import Image

_, last_user_message = self.messages[-2]
last_user_message = None
for role, message in reversed(self.messages):
if role == "user":
last_user_message = message
break

if type(last_user_message) == tuple:
text, images = last_user_message[0], last_user_message[1]
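
A self-contained illustration of one case where the new reversed scan and the old `self.messages[-2]` indexing pick different messages; the conversation below and its role strings are hypothetical and only mirror the format used in the hunk.

```python
# Hypothetical history in which the latest user turn is the last element,
# e.g. because no pending assistant slot has been appended yet.
messages = [
    ("user", ("Describe these photos", ["<image A>", "<image B>"])),
    ("assistant", "Two beach photos."),
    ("user", ("And this one?", ["<image C>"])),
]

# Old approach: fixed index, which here lands on the assistant reply.
assert messages[-2][1] == "Two beach photos."

# New approach: scan backwards for the most recent user message.
last_user_message = None
for role, message in reversed(messages):
    if role == "user":
        last_user_message = message
        break

assert last_user_message == ("And this one?", ["<image C>"])
```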
46 changes: 22 additions & 24 deletions fastchat/serve/api_provider.py
@@ -1076,8 +1076,13 @@ def reka_api_stream_iter(
api_key: Optional[str] = None, # default is env var CO_API_KEY
api_base: Optional[str] = None,
):
from reka.client import Reka
from reka import TypedText

api_key = api_key or os.environ["REKA_API_KEY"]

client = Reka(api_key=api_key)

use_search_engine = False
if "-online" in model_name:
model_name = model_name.replace("-online", "")
@@ -1094,34 +1099,27 @@

# Make requests for logging
text_messages = []
for message in messages:
text_messages.append({"type": message["type"], "text": message["text"]})
for turn in messages:
for message in turn.content:
if isinstance(message, TypedText):
text_messages.append({"type": message.type, "text": message.text})
logged_request = dict(request)
logged_request["conversation_history"] = text_messages

logger.info(f"==== request ====\n{logged_request}")

response = requests.post(
api_base,
stream=True,
json=request,
headers={
"X-Api-Key": api_key,
},
response = client.chat.create_stream(
messages=messages,
max_tokens=max_new_tokens,
top_p=top_p,
model=model_name,
)

if response.status_code != 200:
error_message = response.text
logger.error(f"==== error from reka api: {error_message} ====")
yield {
"text": f"**API REQUEST ERROR** Reason: {error_message}",
"error_code": 1,
}
return

for line in response.iter_lines():
line = line.decode("utf8")
if not line.startswith("data: "):
continue
gen = json.loads(line[6:])
yield {"text": gen["text"], "error_code": 0}
for chunk in response:
try:
yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
except:
yield {
"text": f"**API REQUEST ERROR** ",
"error_code": 1,
}
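
For orientation, a compact sketch of the streaming path introduced above, restricted to the SDK calls that appear in this diff; the model name, prompt, and sampling values are hypothetical.

```python
import os

from reka import ChatMessage, TypedText
from reka.client import Reka

client = Reka(api_key=os.environ["REKA_API_KEY"])

response = client.chat.create_stream(
    messages=[
        ChatMessage(role="user", content=[TypedText(type="text", text="Hello!")])
    ],
    max_tokens=256,
    top_p=0.95,
    model="reka-core",  # hypothetical model name
)

for chunk in response:
    # Surface whatever text the first response carries in this chunk,
    # mirroring the access pattern used in the diff above.
    print(chunk.responses[0].chunk.content)
```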
64 changes: 53 additions & 11 deletions fastchat/serve/gradio_block_arena_anony.py
@@ -32,24 +32,27 @@
acknowledgment_md,
get_ip,
get_model_description_md,
_write_to_json,
)
from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
moderation_filter,
)

logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")

num_sides = 2
enable_moderation = False
use_remote_storage = False
anony_names = ["", ""]
models = []


def set_global_vars_anony(enable_moderation_):
global enable_moderation
def set_global_vars_anony(enable_moderation_, use_remote_storage_):
global enable_moderation, use_remote_storage
enable_moderation = enable_moderation_
use_remote_storage = use_remote_storage_


def load_demo_side_by_side_anony(models_, url_params):
@@ -202,6 +205,9 @@ def get_battle_pair(
if len(models) == 1:
return models[0], models[0]

if len(models) == 0:
raise ValueError("There are no models provided. Cannot get battle pair.")

model_weights = []
for model in models:
weight = get_sample_weight(
@@ -290,7 +296,11 @@ def add_text(
all_conv_text = (
all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
)
flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)

content_moderator = AzureAndOpenAIContentModerator()
flagged = content_moderator.text_moderation_filter(
all_conv_text, model_list, do_moderation=True
)
if flagged:
logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
# overwrite the original text
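
A hedged sketch of how the moderator introduced here could be exercised end to end; `AzureAndOpenAIContentModerator` and `text_moderation_filter` come from the diff, while the wrapper function and `MODERATION_MSG` are hypothetical placeholders for the overwrite step referenced above.

```python
from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator

# Hypothetical placeholder for the canned message shown to flagged users.
MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES."


def moderate_user_text(text: str, all_conv_text: str, model_list: list[str]) -> str:
    """Return the text to forward to the models, overwriting it if flagged."""
    content_moderator = AzureAndOpenAIContentModerator()
    flagged = content_moderator.text_moderation_filter(
        all_conv_text, model_list, do_moderation=True
    )
    if flagged:
        # Overwrite the original text, mirroring the hunk above.
        return MODERATION_MSG
    return text
```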
@@ -343,18 +353,50 @@ def bot_response_multi(
request: gr.Request,
):
logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
states = [state0, state1]

if states[0] is None or states[0].skip_next:
if (
states[0].content_moderator.text_flagged
or states[0].content_moderator.nsfw_flagged
):
for i in range(num_sides):
# This generate call is skipped due to invalid inputs
start_tstamp = time.time()
finish_tstamp = start_tstamp
states[i].conv.save_new_images(
has_csam_images=states[i].has_csam_image,
use_remote_storage=use_remote_storage,
)

filename = get_conv_log_filename(
is_vision=states[i].is_vision,
has_csam_image=states[i].has_csam_image,
)

_write_to_json(
filename,
start_tstamp,
finish_tstamp,
states[i],
temperature,
top_p,
max_new_tokens,
request,
)

# Remove the last message: the user input
states[i].conv.messages.pop()
states[i].content_moderator.update_last_moderation_response(None)

if state0 is None or state0.skip_next:
# This generate call is skipped due to invalid inputs
yield (
state0,
state1,
state0.to_gradio_chatbot(),
state1.to_gradio_chatbot(),
states[0],
states[1],
states[0].to_gradio_chatbot(),
states[1].to_gradio_chatbot(),
) + (no_change_btn,) * 6
return

states = [state0, state1]
gen = []
for i in range(num_sides):
gen.append(