diff --git a/fastchat/conversation.py b/fastchat/conversation.py
index af916c97f..21078938e 100644
--- a/fastchat/conversation.py
+++ b/fastchat/conversation.py
@@ -365,12 +365,16 @@ def to_gradio_chatbot(self):
if i % 2 == 0:
if type(msg) is tuple:
msg, images = msg
- image = images[0] # Only one image on gradio at one time
- if image.image_format == ImageFormat.URL:
- img_str = f'<img src="{image.url}" alt="user upload image" />'
- elif image.image_format == ImageFormat.BYTES:
- img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
- msg = img_str + msg.replace("\n", "").strip()
+ combined_image_str = ""
+ for image in images:
+ if image.image_format == ImageFormat.URL:
+ img_str = (
+ f'<img src="{image.url}" alt="user upload image" />'
+ )
+ elif image.image_format == ImageFormat.BYTES:
+ img_str = f'<img src="data:image/{image.filetype};base64,{image.base64_str}" alt="user upload image" />'
+ combined_image_str += img_str
+ msg = combined_image_str + msg.replace("\n", "").strip()
ret.append([msg, None])
else:
@@ -524,6 +528,7 @@ def to_anthropic_vision_api_messages(self):
def to_reka_api_messages(self):
from fastchat.serve.vision.image import ImageFormat
+ from reka import ChatMessage, TypedMediaContent, TypedText
ret = []
for i, (_, msg) in enumerate(self.messages[self.offset :]):
@@ -531,23 +536,47 @@ def to_reka_api_messages(self):
if type(msg) == tuple:
text, images = msg
for image in images:
- if image.image_format == ImageFormat.URL:
- ret.append(
- {"type": "human", "text": text, "media_url": image.url}
- )
- elif image.image_format == ImageFormat.BYTES:
+ if image.image_format == ImageFormat.BYTES:
ret.append(
- {
- "type": "human",
- "text": text,
- "media_url": f"data:image/{image.filetype};base64,{image.base64_str}",
- }
+ ChatMessage(
+ content=[
+ TypedText(
+ type="text",
+ text=text,
+ ),
+ TypedMediaContent(
+ type="image_url",
+ image_url=f"data:image/{image.filetype};base64,{image.base64_str}",
+ ),
+ ],
+ role="user",
+ )
)
else:
- ret.append({"type": "human", "text": msg})
+ ret.append(
+ ChatMessage(
+ content=[
+ TypedText(
+ type="text",
+ text=msg,
+ )
+ ],
+ role="user",
+ )
+ )
else:
if msg is not None:
- ret.append({"type": "model", "text": msg})
+ ret.append(
+ ChatMessage(
+ content=[
+ TypedText(
+ type="text",
+ text=msg,
+ )
+ ],
+ role="assistant",
+ )
+ )
return ret
@@ -557,7 +586,11 @@ def save_new_images(self, has_csam_images=False, use_remote_storage=False):
from fastchat.utils import load_image, upload_image_file_to_gcs
from PIL import Image
- _, last_user_message = self.messages[-2]
+ last_user_message = None
+ for role, message in reversed(self.messages):
+ if role == "user":
+ last_user_message = message
+ break
if type(last_user_message) == tuple:
text, images = last_user_message[0], last_user_message[1]
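The hunk above converts to_reka_api_messages from plain dicts to the Reka SDK's typed message classes. A minimal sketch of the single user turn this now produces for one base64-encoded image, using only the types imported above (the prompt text and image payload are placeholders):

from reka import ChatMessage, TypedMediaContent, TypedText

# One converted user turn: the prompt text plus a data-URI image, mirroring the hunk above.
turn = ChatMessage(
    role="user",
    content=[
        TypedText(type="text", text="What is in this picture?"),
        TypedMediaContent(
            type="image_url",
            image_url="data:image/png;base64,<BASE64_DATA>",  # placeholder, not real image data
        ),
    ],
)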
diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py
index a2b7979be..c3c64173b 100644
--- a/fastchat/serve/api_provider.py
+++ b/fastchat/serve/api_provider.py
@@ -1076,8 +1076,13 @@ def reka_api_stream_iter(
api_key: Optional[str] = None, # default is env var CO_API_KEY
api_base: Optional[str] = None,
):
+ from reka.client import Reka
+ from reka import TypedText
+
api_key = api_key or os.environ["REKA_API_KEY"]
+ client = Reka(api_key=api_key)
+
use_search_engine = False
if "-online" in model_name:
model_name = model_name.replace("-online", "")
@@ -1094,34 +1099,27 @@ def reka_api_stream_iter(
# Make requests for logging
text_messages = []
- for message in messages:
- text_messages.append({"type": message["type"], "text": message["text"]})
+ for turn in messages:
+ for message in turn.content:
+ if isinstance(message, TypedText):
+ text_messages.append({"type": message.type, "text": message.text})
logged_request = dict(request)
logged_request["conversation_history"] = text_messages
logger.info(f"==== request ====\n{logged_request}")
- response = requests.post(
- api_base,
- stream=True,
- json=request,
- headers={
- "X-Api-Key": api_key,
- },
+ response = client.chat.create_stream(
+ messages=messages,
+ max_tokens=max_new_tokens,
+ top_p=top_p,
+ model=model_name,
)
- if response.status_code != 200:
- error_message = response.text
- logger.error(f"==== error from reka api: {error_message} ====")
- yield {
- "text": f"**API REQUEST ERROR** Reason: {error_message}",
- "error_code": 1,
- }
- return
-
- for line in response.iter_lines():
- line = line.decode("utf8")
- if not line.startswith("data: "):
- continue
- gen = json.loads(line[6:])
- yield {"text": gen["text"], "error_code": 0}
+ for chunk in response:
+ try:
+ yield {"text": chunk.responses[0].chunk.content, "error_code": 0}
+ except Exception as e:
+ logger.error(f"==== error from reka api: {e} ====")
+ yield {
+ "text": f"**API REQUEST ERROR** Reason: {e}",
+ "error_code": 1,
+ }
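The reka_api_stream_iter change above drops the hand-rolled HTTP streaming in favor of the official client. A standalone sketch of that flow, assuming REKA_API_KEY is set and that the model name used here (reka-flash) is available to your key; the chunk shape mirrors the handler in the hunk:

import os

from reka import ChatMessage, TypedText
from reka.client import Reka

client = Reka(api_key=os.environ["REKA_API_KEY"])
stream = client.chat.create_stream(
    messages=[ChatMessage(role="user", content=[TypedText(type="text", text="Hello!")])],
    max_tokens=256,
    top_p=0.9,
    model="reka-flash",  # assumed model name; pick one enabled on your account
)
for chunk in stream:
    # Each streamed chunk exposes the partial text at chunk.responses[0].chunk.content.
    print(chunk.responses[0].chunk.content, end="", flush=True)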
diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py
index dc9b89a0c..bfba8485e 100644
--- a/fastchat/serve/gradio_block_arena_anony.py
+++ b/fastchat/serve/gradio_block_arena_anony.py
@@ -32,24 +32,27 @@
acknowledgment_md,
get_ip,
get_model_description_md,
+ _write_to_json,
)
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
- moderation_filter,
)
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
num_sides = 2
enable_moderation = False
+use_remote_storage = False
anony_names = ["", ""]
models = []
-def set_global_vars_anony(enable_moderation_):
- global enable_moderation
+def set_global_vars_anony(enable_moderation_, use_remote_storage_):
+ global enable_moderation, use_remote_storage
enable_moderation = enable_moderation_
+ use_remote_storage = use_remote_storage_
def load_demo_side_by_side_anony(models_, url_params):
@@ -202,6 +205,9 @@ def get_battle_pair(
if len(models) == 1:
return models[0], models[0]
+ if len(models) == 0:
+ raise ValueError("There are no models provided. Cannot get battle pair.")
+
model_weights = []
for model in models:
weight = get_sample_weight(
@@ -290,7 +296,11 @@ def add_text(
all_conv_text = (
all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
)
- flagged = moderation_filter(all_conv_text, model_list, do_moderation=True)
+
+ content_moderator = AzureAndOpenAIContentModerator()
+ flagged = content_moderator.text_moderation_filter(
+ all_conv_text, model_list, do_moderation=True
+ )
if flagged:
logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
# overwrite the original text
@@ -343,18 +353,50 @@ def bot_response_multi(
request: gr.Request,
):
logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
+ states = [state0, state1]
+
+ if states[0] is None or states[0].skip_next:
+ if (
+ states[0].content_moderator.text_flagged
+ or states[0].content_moderator.nsfw_flagged
+ ):
+ for i in range(num_sides):
+ # This generate call is skipped due to invalid inputs
+ start_tstamp = time.time()
+ finish_tstamp = start_tstamp
+ states[i].conv.save_new_images(
+ has_csam_images=states[i].has_csam_image,
+ use_remote_storage=use_remote_storage,
+ )
+
+ filename = get_conv_log_filename(
+ is_vision=states[i].is_vision,
+ has_csam_image=states[i].has_csam_image,
+ )
+
+ _write_to_json(
+ filename,
+ start_tstamp,
+ finish_tstamp,
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+
+ # Remove the last message: the user input
+ states[i].conv.messages.pop()
+ states[i].content_moderator.update_last_moderation_response(None)
- if state0 is None or state0.skip_next:
- # This generate call is skipped due to invalid inputs
yield (
- state0,
- state1,
- state0.to_gradio_chatbot(),
- state1.to_gradio_chatbot(),
+ states[0],
+ states[1],
+ states[0].to_gradio_chatbot(),
+ states[1].to_gradio_chatbot(),
) + (no_change_btn,) * 6
return
- states = [state0, state1]
gen = []
for i in range(num_sides):
gen.append(
diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py
index 7ee19b041..6a102d0e2 100644
--- a/fastchat/serve/gradio_block_arena_named.py
+++ b/fastchat/serve/gradio_block_arena_named.py
@@ -28,22 +28,25 @@
acknowledgment_md,
get_ip,
get_model_description_md,
+ _write_to_json,
)
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
- moderation_filter,
)
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
num_sides = 2
enable_moderation = False
+use_remote_storage = False
-def set_global_vars_named(enable_moderation_):
- global enable_moderation
+def set_global_vars_named(enable_moderation_, use_remote_storage_):
+ global enable_moderation, use_remote_storage
enable_moderation = enable_moderation_
+ use_remote_storage = use_remote_storage_
def load_demo_side_by_side_named(models, url_params):
@@ -175,19 +178,27 @@ def add_text(
no_change_btn,
]
* 6
+ + [True]
)
model_list = [states[i].model_name for i in range(num_sides)]
- all_conv_text_left = states[0].conv.get_prompt()
- all_conv_text_right = states[1].conv.get_prompt()
- all_conv_text = (
- all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text
- )
- flagged = moderation_filter(all_conv_text, model_list)
- if flagged:
- logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
- # overwrite the original text
- text = MODERATION_MSG
+ text_flagged = states[0].content_moderator.text_moderation_filter(text, model_list)
+
+ if text_flagged:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+ for i in range(num_sides):
+ states[i].skip_next = True
+ gr.Warning(MODERATION_MSG)
+ return (
+ states
+ + [x.to_gradio_chatbot() for x in states]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 6
+ + [True]
+ )
conv = states[0].conv
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
@@ -202,6 +213,7 @@ def add_text(
no_change_btn,
]
* 6
+ + [True]
)
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
@@ -218,6 +230,7 @@ def add_text(
disable_btn,
]
* 6
+ + [False]
)
@@ -231,17 +244,49 @@ def bot_response_multi(
):
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
- if state0.skip_next:
- # This generate call is skipped due to invalid inputs
+ states = [state0, state1]
+ if states[0].skip_next:
+ if (
+ states[0].content_moderator.text_flagged
+ or states[0].content_moderator.nsfw_flagged
+ ):
+ for i in range(num_sides):
+ # This generate call is skipped due to invalid inputs
+ start_tstamp = time.time()
+ finish_tstamp = start_tstamp
+ states[i].conv.save_new_images(
+ has_csam_images=states[i].has_csam_image,
+ use_remote_storage=use_remote_storage,
+ )
+
+ filename = get_conv_log_filename(
+ is_vision=states[i].is_vision,
+ has_csam_image=states[i].has_csam_image,
+ )
+
+ _write_to_json(
+ filename,
+ start_tstamp,
+ finish_tstamp,
+ states[i],
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+
+ # Remove the last message: the user input
+ states[i].conv.messages.pop()
+ states[i].content_moderator.update_last_moderation_response(None)
+
yield (
- state0,
- state1,
- state0.to_gradio_chatbot(),
- state1.to_gradio_chatbot(),
+ states[0],
+ states[1],
+ states[0].to_gradio_chatbot(),
+ states[1].to_gradio_chatbot(),
) + (no_change_btn,) * 6
return
- states = [state0, state1]
gen = []
for i in range(num_sides):
gen.append(
@@ -296,14 +341,19 @@ def bot_response_multi(
break
-def flash_buttons():
+def flash_buttons(show_vote_buttons: bool = True):
btn_updates = [
[disable_btn] * 4 + [enable_btn] * 2,
[enable_btn] * 6,
]
- for i in range(4):
- yield btn_updates[i % 2]
- time.sleep(0.3)
+
+ if show_vote_buttons:
+ for i in range(4):
+ yield btn_updates[i % 2]
+ time.sleep(0.3)
+ else:
+ yield [no_change_btn] * 4 + [enable_btn] * 2
+ return
def build_side_by_side_ui_named(models):
@@ -323,6 +373,7 @@ def build_side_by_side_ui_named(models):
states = [gr.State() for _ in range(num_sides)]
model_selectors = [None] * num_sides
chatbots = [None] * num_sides
+ dont_show_vote_buttons = gr.State(False)
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
@@ -478,24 +529,24 @@ def build_side_by_side_ui_named(models):
textbox.submit(
add_text,
states + model_selectors + [textbox],
- states + chatbots + [textbox] + btn_list,
+ states + chatbots + [textbox] + btn_list + [dont_show_vote_buttons],
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + chatbots + btn_list,
).then(
- flash_buttons, [], btn_list
+ flash_buttons, [dont_show_vote_buttons], btn_list
)
send_btn.click(
add_text,
states + model_selectors + [textbox],
- states + chatbots + [textbox] + btn_list,
+ states + chatbots + [textbox] + btn_list + [dont_show_vote_buttons],
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + chatbots + btn_list,
).then(
- flash_buttons, [], btn_list
+ flash_buttons, [dont_show_vote_buttons], btn_list
)
return states + model_selectors
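flash_buttons above now takes a show_vote_buttons flag: when it is False, the generator yields a single non-flashing update instead of the four-step animation. A small sketch of the two modes, assuming FastChat is importable:

from fastchat.serve.gradio_block_arena_named import flash_buttons

# Voting suppressed: exactly one button update, no flashing loop.
print(len(list(flash_buttons(show_vote_buttons=False))))  # 1
# Default behavior: four alternating updates with a short sleep between them.
print(len(list(flash_buttons(show_vote_buttons=True))))   # 4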
diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py
index 25ff78c08..dadb360e0 100644
--- a/fastchat/serve/gradio_block_arena_vision.py
+++ b/fastchat/serve/gradio_block_arena_vision.py
@@ -10,6 +10,7 @@
import json
import os
import time
+from typing import List, Union
import gradio as gr
from gradio.data_classes import FileData
@@ -27,6 +28,7 @@
from fastchat.model.model_adapter import (
get_conversation_template,
)
+from fastchat.serve.gradio_global_state import Context
from fastchat.serve.gradio_web_server import (
get_model_description_md,
acknowledgment_md,
@@ -37,11 +39,10 @@
get_conv_log_filename,
get_remote_logger,
)
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
from fastchat.serve.vision.image import ImageFormat, Image
from fastchat.utils import (
build_logger,
- moderation_filter,
- image_moderation_filter,
)
logger = build_logger("gradio_web_server", "gradio_web_server.log")
@@ -52,8 +53,16 @@
invisible_btn = gr.Button(interactive=False, visible=False)
visible_image_column = gr.Image(visible=True)
invisible_image_column = gr.Image(visible=False)
-enable_multimodal = gr.MultimodalTextbox(
- interactive=True, visible=True, placeholder="Enter your prompt or add image here"
+enable_multimodal_keep_input = gr.MultimodalTextbox(
+ interactive=True,
+ visible=True,
+ placeholder="Enter your prompt or add image here",
+)
+enable_multimodal_clear_input = gr.MultimodalTextbox(
+ interactive=True,
+ visible=True,
+ placeholder="Enter your prompt or add image here",
+ value={"text": "", "files": []},
)
invisible_text = gr.Textbox(visible=False, value="", interactive=False)
visible_text = gr.Textbox(
@@ -76,10 +85,6 @@ def set_visible_image(textbox):
images = textbox["files"]
if len(images) == 0:
return invisible_image_column
- elif len(images) > 1:
- gr.Warning(
- "We only support single image conversations. Please start a new round if you would like to chat using this image."
- )
return visible_image_column
@@ -138,6 +143,7 @@ def regenerate(state, request: gr.Request):
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
state.conv.update_last_message(None)
+ state.content_moderator.update_last_moderation_response(None)
return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
@@ -145,14 +151,14 @@ def clear_history(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history. ip: {ip}")
state = None
- return (state, [], None) + (disable_btn,) * 5
+ return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5
def clear_history_example(request: gr.Request):
ip = get_ip(request)
logger.info(f"clear_history_example. ip: {ip}")
state = None
- return (state, [], enable_multimodal) + (disable_btn,) * 5
+ return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5
# TODO(Chris): At some point, we would like this to be a live-reporting feature.
@@ -160,13 +166,32 @@ def report_csam_image(state, image):
pass
-def _prepare_text_with_image(state, text, images, csam_flag):
+def _prepare_text_with_image(
+ state: State, text: str, images: List[Image], context: Context
+):
if len(images) > 0:
- if len(state.conv.get_images()) > 0:
- # reset convo with new image
- state.conv = get_conversation_template(state.model_name)
+ model_supports_multi_image = context.api_endpoint_info[state.model_name].get(
+ "multi_image", False
+ )
+ num_previous_images = len(state.conv.get_images())
+ images_interleaved_with_text_exists_but_model_does_not_support = (
+ num_previous_images > 0 and not model_supports_multi_image
+ )
+ multiple_image_one_turn_but_model_does_not_support = (
+ len(images) > 1 and not model_supports_multi_image
+ )
+ if images_interleaved_with_text_exists_but_model_does_not_support:
+ gr.Warning(
+ f"The model does not support interleaved image/text. We only use the very first image."
+ )
+ return text
+ elif multiple_image_one_turn_but_model_does_not_support:
+ gr.Warning(
+ f"The model does not support multiple images. Only the first image will be used."
+ )
+ return text, [images[0]]
- text = text, [images[0]]
+ text = text, images
return text
@@ -178,45 +203,35 @@ def convert_images_to_conversation_format(images):
MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5
conv_images = []
if len(images) > 0:
- conv_image = Image(url=images[0])
- conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
- conv_images.append(conv_image)
+ for image in images:
+ conv_image = Image(url=image)
+ conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB)
+ conv_images.append(conv_image)
return conv_images
-def moderate_input(state, text, all_conv_text, model_list, images, ip):
- text_flagged = moderation_filter(all_conv_text, model_list)
- # flagged = moderation_filter(text, [state.model_name])
- nsfw_flagged, csam_flagged = False, False
- if len(images) > 0:
- nsfw_flagged, csam_flagged = image_moderation_filter(images[0])
-
- image_flagged = nsfw_flagged or csam_flagged
- if text_flagged or image_flagged:
- logger.info(f"violate moderation. ip: {ip}. text: {all_conv_text}")
- if text_flagged and not image_flagged:
- # overwrite the original text
- text = TEXT_MODERATION_MSG
- elif not text_flagged and image_flagged:
- text = IMAGE_MODERATION_MSG
- elif text_flagged and image_flagged:
- text = MODERATION_MSG
-
- if csam_flagged:
- state.has_csam_image = True
- report_csam_image(state, images[0])
-
- return text, image_flagged, csam_flagged
-
+def add_text(state, model_selector, chat_input, context: Context, request: gr.Request):
+ if isinstance(chat_input, dict):
+ text, images = chat_input["text"], chat_input["files"]
+ else:
+ text, images = chat_input, []
-def add_text(state, model_selector, chat_input, request: gr.Request):
- text, images = chat_input["text"], chat_input["files"]
+ if (
+ len(images) > 0
+ and model_selector in context.text_models
+ and model_selector not in context.vision_models
+ ):
+ gr.Warning(f"{model_selector} is a text-only model. Image is ignored.")
+ images = []
ip = get_ip(request)
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
if state is None:
- state = State(model_selector, is_vision=True)
+ if len(images) == 0:
+ state = State(model_selector, is_vision=False)
+ else:
+ state = State(model_selector, is_vision=True)
if len(text) <= 0:
state.skip_next = True
@@ -227,33 +242,69 @@ def add_text(state, model_selector, chat_input, request: gr.Request):
images = convert_images_to_conversation_format(images)
- text, image_flagged, csam_flag = moderate_input(
- state, text, all_conv_text, [state.model_name], images, ip
+ # Moderation depends only on the user input, so it is independent of the model
+ moderation_type_to_response_map = (
+ state.content_moderator.image_and_text_moderation_filter(
+ images, text, [state.model_name], do_moderation=False
+ )
+ )
+
+ text_flagged, nsfw_flag, csam_flag = (
+ moderation_type_to_response_map["text_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
- if image_flagged:
- logger.info(f"image flagged. ip: {ip}. text: {text}")
+ if csam_flag:
+ state.has_csam_image = True
+
+ state.content_moderator.append_moderation_response(moderation_type_to_response_map)
+
+ if text_flagged or nsfw_flag:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+ gradio_chatbot_before_user_input = state.to_gradio_chatbot()
+ post_processed_text = _prepare_text_with_image(state, text, images, context)
+ state.conv.append_message(state.conv.roles[0], post_processed_text)
state.skip_next = True
- return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + (
- no_change_btn,
- ) * 5
+ gr.Warning(MODERATION_MSG)
+ return (
+ state,
+ gradio_chatbot_before_user_input,
+ None,
+ ) + (no_change_btn,) * 5
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
state.skip_next = True
- return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + (
- no_change_btn,
- ) * 5
+ return (
+ state,
+ state.to_gradio_chatbot(),
+ {"text": CONVERSATION_LIMIT_MSG},
+ ) + (no_change_btn,) * 5
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
- text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag)
+ text = _prepare_text_with_image(state, text, images, context)
state.conv.append_message(state.conv.roles[0], text)
state.conv.append_message(state.conv.roles[1], None)
- return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5
+ return (
+ state,
+ state.to_gradio_chatbot(),
+ None,
+ ) + (disable_btn,) * 5
def build_single_vision_language_model_ui(
- models, add_promotion_links=False, random_questions=None
+ context: Context, add_promotion_links=False, random_questions=None
):
promotion = (
f"""
@@ -275,33 +326,29 @@ def build_single_vision_language_model_ui(
state = gr.State()
gr.Markdown(notice_markdown, elem_id="notice_markdown")
+ text_and_vision_models = list(set(context.text_models + context.vision_models))
+ context_state = gr.State(context)
with gr.Group():
with gr.Row(elem_id="model_selector_row"):
model_selector = gr.Dropdown(
- choices=models,
- value=models[0] if len(models) > 0 else "",
+ choices=text_and_vision_models,
+ value=text_and_vision_models[0]
+ if len(text_and_vision_models) > 0
+ else "",
interactive=True,
show_label=False,
container=False,
)
with gr.Accordion(
- f"🔍 Expand to see the descriptions of {len(models)} models", open=False
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
+ open=False,
):
- model_description_md = get_model_description_md(models)
+ model_description_md = get_model_description_md(text_and_vision_models)
gr.Markdown(model_description_md, elem_id="model_description_markdown")
with gr.Row():
- textbox = gr.MultimodalTextbox(
- file_types=["image"],
- show_label=False,
- placeholder="Enter your prompt or add image here",
- container=True,
- render=False,
- elem_id="input_box",
- )
-
with gr.Column(scale=2, visible=False) as image_column:
imagebox = gr.Image(
type="pil",
@@ -314,9 +361,13 @@ def build_single_vision_language_model_ui(
)
with gr.Row():
- textbox.render()
- # with gr.Column(scale=1, min_width=50):
- # send_btn = gr.Button(value="Send", variant="primary")
+ multimodal_textbox = gr.MultimodalTextbox(
+ file_types=["image"],
+ show_label=False,
+ placeholder="Enter your prompt or add image here",
+ container=True,
+ elem_id="input_box",
+ )
with gr.Row(elem_id="buttons"):
if random_questions:
@@ -330,22 +381,6 @@ def build_single_vision_language_model_ui(
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
- cur_dir = os.path.dirname(os.path.abspath(__file__))
-
- examples = gr.Examples(
- examples=[
- {
- "text": "How can I prepare a delicious meal using these ingredients?",
- "files": [f"{cur_dir}/example_images/fridge.jpg"],
- },
- {
- "text": "What might the woman on the right be thinking about?",
- "files": [f"{cur_dir}/example_images/distracted.jpg"],
- },
- ],
- inputs=[textbox],
- )
-
with gr.Accordion("Parameters", open=False) as parameter_row:
temperature = gr.Slider(
minimum=0.0,
@@ -380,40 +415,45 @@ def build_single_vision_language_model_ui(
upvote_btn.click(
upvote_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
downvote_btn.click(
downvote_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
flag_btn.click(
flag_last_response,
[state, model_selector],
- [textbox, upvote_btn, downvote_btn, flag_btn],
+ [multimodal_textbox, upvote_btn, downvote_btn, flag_btn],
)
- regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
+ regenerate_btn.click(
+ regenerate, state, [state, chatbot, multimodal_textbox] + btn_list
+ ).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
[state, chatbot] + btn_list,
)
- clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
+ clear_btn.click(
+ clear_history,
+ None,
+ [state, chatbot, multimodal_textbox] + btn_list,
+ )
model_selector.change(
- clear_history, None, [state, chatbot, textbox] + btn_list
- ).then(set_visible_image, [textbox], [image_column])
- examples.dataset.click(
- clear_history_example, None, [state, chatbot, textbox] + btn_list
- )
+ clear_history,
+ None,
+ [state, chatbot, multimodal_textbox] + btn_list,
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
- textbox.input(add_image, [textbox], [imagebox]).then(
- set_visible_image, [textbox], [image_column]
- ).then(clear_history_example, None, [state, chatbot, textbox] + btn_list)
+ multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
+ set_visible_image, [multimodal_textbox], [image_column]
+ )
- textbox.submit(
+ multimodal_textbox.submit(
add_text,
- [state, model_selector, textbox],
- [state, chatbot, textbox] + btn_list,
+ [state, model_selector, multimodal_textbox, context_state],
+ [state, chatbot, multimodal_textbox] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response,
[state, temperature, top_p, max_output_tokens],
@@ -424,9 +464,7 @@ def build_single_vision_language_model_ui(
random_btn.click(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
- [textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [textbox], [image_column]).then(
- clear_history_example, None, [state, chatbot, textbox] + btn_list
- )
+ [multimodal_textbox, imagebox], # Outputs are textbox and imagebox
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return [state, model_selector]
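convert_images_to_conversation_format above now wraps every uploaded file instead of only the first one. A hedged usage sketch; the paths are placeholders for real image files, which in the UI come from the MultimodalTextbox's files list:

from fastchat.serve.gradio_block_arena_vision import (
    convert_images_to_conversation_format,
)

# Placeholder paths; replace with image files that actually exist on disk.
uploaded_files = ["/path/to/first.png", "/path/to/second.png"]
conv_images = convert_images_to_conversation_format(uploaded_files)
print(len(conv_images))  # one conversation Image per upload, not just the first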
diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py
index 76bc47329..d24a54723 100644
--- a/fastchat/serve/gradio_block_arena_vision_anony.py
+++ b/fastchat/serve/gradio_block_arena_vision_anony.py
@@ -8,6 +8,7 @@
import gradio as gr
import numpy as np
+from typing import Union
from fastchat.constants import (
TEXT_MODERATION_MSG,
@@ -34,6 +35,9 @@
get_model_description_md,
disable_text,
enable_text,
+ use_remote_storage,
+ show_vote_button,
+ dont_show_vote_button,
)
from fastchat.serve.gradio_block_arena_anony import (
flash_buttons,
@@ -45,7 +49,6 @@
regenerate,
clear_history,
share_click,
- add_text,
bot_response_multi,
set_global_vars_anony,
load_demo_side_by_side_anony,
@@ -60,19 +63,22 @@
set_invisible_image,
set_visible_image,
add_image,
- moderate_input,
- enable_multimodal,
+ enable_multimodal_keep_input,
_prepare_text_with_image,
convert_images_to_conversation_format,
invisible_text,
visible_text,
disable_multimodal,
+ enable_multimodal_clear_input,
)
+from fastchat.serve.moderation.moderator import (
+ BaseContentModerator,
+ AzureAndOpenAIContentModerator,
+)
+from fastchat.serve.gradio_global_state import Context
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
- moderation_filter,
- image_moderation_filter,
)
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
@@ -96,6 +102,7 @@
"llava-v1.6-34b": 4,
"reka-core-20240501": 4,
"reka-flash-preview-20240611": 4,
+ "reka-flash": 4,
}
# TODO(chris): Find battle targets that make sense
@@ -115,16 +122,12 @@ def get_vqa_sample():
return (res, path)
-def load_demo_side_by_side_vision_anony(all_text_models, all_vl_models, url_params):
- global text_models, vl_models
- text_models = all_text_models
- vl_models = all_vl_models
-
- states = (None,) * num_sides
- selector_updates = (
+def load_demo_side_by_side_vision_anony():
+ states = [None] * num_sides
+ selector_updates = [
gr.Markdown(visible=True),
gr.Markdown(visible=True),
- )
+ ]
return states + selector_updates
@@ -135,7 +138,7 @@ def clear_history_example(request: gr.Request):
[None] * num_sides
+ [None] * num_sides
+ anony_names
- + [enable_multimodal, invisible_text]
+ + [enable_multimodal_keep_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
+ [enable_btn]
@@ -221,6 +224,7 @@ def regenerate(state0, state1, request: gr.Request):
if state0.regen_support and state1.regen_support:
for i in range(num_sides):
states[i].conv.update_last_message(None)
+ states[i].content_moderator.update_last_moderation_response(None)
return (
states
+ [x.to_gradio_chatbot() for x in states]
@@ -240,7 +244,7 @@ def clear_history(request: gr.Request):
[None] * num_sides
+ [None] * num_sides
+ anony_names
- + [enable_multimodal, invisible_text]
+ + [enable_multimodal_clear_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
+ [enable_btn]
@@ -249,7 +253,13 @@ def clear_history(request: gr.Request):
def add_text(
- state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request
+ state0,
+ state1,
+ model_selector0,
+ model_selector1,
+ chat_input: Union[str, dict],
+ context: Context,
+ request: gr.Request,
):
if isinstance(chat_input, dict):
text, images = chat_input["text"], chat_input["files"]
@@ -268,7 +278,7 @@ def add_text(
if len(images) > 0:
model_left, model_right = get_battle_pair(
- vl_models,
+ context.all_vision_models,
VISION_BATTLE_TARGETS,
VISION_OUTAGE_MODELS,
VISION_SAMPLING_WEIGHTS,
@@ -280,7 +290,7 @@ def add_text(
]
else:
model_left, model_right = get_battle_pair(
- text_models,
+ context.all_text_models,
BATTLE_TARGETS,
OUTAGE_MODELS,
SAMPLING_WEIGHTS,
@@ -298,22 +308,49 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [None, ""]
+ + [None]
+ [
no_change_btn,
]
* 7
+ [""]
+ + [dont_show_vote_button]
)
model_list = [states[i].model_name for i in range(num_sides)]
images = convert_images_to_conversation_format(images)
- text, image_flagged, csam_flag = moderate_input(
- state0, text, text, model_list, images, ip
+ # Use the first state's moderator: moderation is based on the user input, so it is independent of the model
+ moderation_type_to_response_map = states[
+ 0
+ ].content_moderator.image_and_text_moderation_filter(
+ images, text, model_list, do_moderation=True
+ )
+ text_flagged, nsfw_flag, csam_flag = (
+ moderation_type_to_response_map["text_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
+ if csam_flag:
+ states[0].has_csam_image, states[1].has_csam_image = True, True
+
+ for state in states:
+ state.content_moderator.append_moderation_response(
+ moderation_type_to_response_map
+ )
+
conv = states[0].conv
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
@@ -322,37 +359,40 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [{"text": CONVERSATION_LIMIT_MSG}, ""]
+ + [{"text": CONVERSATION_LIMIT_MSG}]
+ [
no_change_btn,
]
* 7
+ [""]
+ + [dont_show_vote_button]
)
- if image_flagged:
- logger.info(f"image flagged. ip: {ip}. text: {text}")
+ if text_flagged or nsfw_flag:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+ # We call this before appending the text so it does not appear in the UI
+ gradio_chatbot_list = [x.to_gradio_chatbot() for x in states]
for i in range(num_sides):
+ post_processed_text = _prepare_text_with_image(
+ states[i], text, images, context
+ )
+ states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].skip_next = True
+ gr.Warning(MODERATION_MSG)
return (
states
- + [x.to_gradio_chatbot() for x in states]
+ + gradio_chatbot_list
+ [
- {
- "text": IMAGE_MODERATION_MSG
- + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION."
- },
- "",
+ None,
]
- + [no_change_btn] * 7
+ + [disable_btn] * 7
+ [""]
+ + [dont_show_vote_button]
)
text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
- post_processed_text = _prepare_text_with_image(
- states[i], text, images, csam_flag=csam_flag
- )
+ post_processed_text = _prepare_text_with_image(states[i], text, images, context)
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
states[i].skip_next = False
@@ -364,17 +404,18 @@ def add_text(
return (
states
+ [x.to_gradio_chatbot() for x in states]
- + [disable_multimodal, visible_text]
+ + [None]
+ [
disable_btn,
]
* 7
+ [hint_msg]
+ + [show_vote_button]
)
-def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=None):
- notice_markdown = f"""
+def build_side_by_side_vision_ui_anony(context: Context, random_questions=None):
+ notice_markdown = """
# ⚔️ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild
[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena)
@@ -395,8 +436,11 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
states = [gr.State() for _ in range(num_sides)]
model_selectors = [None] * num_sides
chatbots = [None] * num_sides
+ show_vote_buttons = gr.State(True)
+ context_state = gr.State(context)
gr.Markdown(notice_markdown, elem_id="notice_markdown")
+ text_and_vision_models = list(set(context.text_models + context.vision_models))
with gr.Row():
with gr.Column(scale=2, visible=False) as image_column:
@@ -409,11 +453,11 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
with gr.Column(scale=5):
with gr.Group(elem_id="share-region-anony"):
with gr.Accordion(
- f"🔍 Expand to see the descriptions of {len(text_models) + len(vl_models)} models",
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
open=False,
):
model_description_md = get_model_description_md(
- text_models + vl_models
+ text_and_vision_models
)
gr.Markdown(
model_description_md, elem_id="model_description_markdown"
@@ -452,13 +496,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
)
with gr.Row():
- textbox = gr.Textbox(
- show_label=False,
- placeholder="👉 Enter your prompt and press ENTER",
- elem_id="input_box",
- visible=False,
- )
-
multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
@@ -466,7 +503,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
placeholder="Enter your prompt or add image here",
elem_id="input_box",
)
- # send_btn = gr.Button(value="Send", variant="primary", scale=0)
with gr.Row() as button_row:
if random_questions:
@@ -518,25 +554,29 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
leftvote_btn.click(
leftvote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
rightvote_btn.click(
rightvote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
tie_btn.click(
tievote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
bothbad_btn.click(
bothbad_vote_last_response,
states + model_selectors,
- model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ model_selectors
+ + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
regenerate_btn.click(
- regenerate, states, states + chatbots + [textbox] + btn_list
+ regenerate, states, states + chatbots + [multimodal_textbox] + btn_list
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -550,7 +590,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
states
+ chatbots
+ model_selectors
- + [multimodal_textbox, textbox]
+ + [multimodal_textbox]
+ btn_list
+ [random_btn]
+ [slow_warning],
@@ -580,47 +620,25 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
set_visible_image, [multimodal_textbox], [image_column]
- ).then(
- clear_history_example,
- None,
- states + chatbots + model_selectors + [multimodal_textbox, textbox] + btn_list,
)
multimodal_textbox.submit(
add_text,
- states + model_selectors + [multimodal_textbox],
+ states + model_selectors + [multimodal_textbox, context_state],
states
+ chatbots
- + [multimodal_textbox, textbox]
+ + [multimodal_textbox]
+ btn_list
+ [random_btn]
- + [slow_warning],
+ + [slow_warning]
+ + [show_vote_buttons],
).then(set_invisible_image, [], [image_column]).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
states + chatbots + btn_list,
).then(
flash_buttons,
- [],
- btn_list,
- )
-
- textbox.submit(
- add_text,
- states + model_selectors + [textbox],
- states
- + chatbots
- + [multimodal_textbox, textbox]
- + btn_list
- + [random_btn]
- + [slow_warning],
- ).then(
- bot_response_multi,
- states + [temperature, top_p, max_output_tokens],
- states + chatbots + btn_list,
- ).then(
- flash_buttons,
- [],
+ [show_vote_buttons],
btn_list,
)
@@ -629,15 +647,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[multimodal_textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [multimodal_textbox], [image_column]).then(
- clear_history_example,
- None,
- states
- + chatbots
- + model_selectors
- + [multimodal_textbox, textbox]
- + btn_list
- + [random_btn],
- )
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return states + model_selectors
diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py
index ecca169ca..218eb344e 100644
--- a/fastchat/serve/gradio_block_arena_vision_named.py
+++ b/fastchat/serve/gradio_block_arena_vision_named.py
@@ -6,6 +6,7 @@
import json
import os
import time
+from typing import List, Union
import gradio as gr
import numpy as np
@@ -31,11 +32,20 @@
set_invisible_image,
set_visible_image,
add_image,
- moderate_input,
_prepare_text_with_image,
convert_images_to_conversation_format,
- enable_multimodal,
+ enable_multimodal_keep_input,
+ enable_multimodal_clear_input,
+ disable_multimodal,
+ invisible_text,
+ invisible_btn,
+ visible_text,
+)
+from fastchat.serve.moderation.moderator import (
+ BaseContentModerator,
+ AzureAndOpenAIContentModerator,
)
+from fastchat.serve.gradio_global_state import Context
from fastchat.serve.gradio_web_server import (
State,
bot_response,
@@ -52,8 +62,6 @@
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
- moderation_filter,
- image_moderation_filter,
)
@@ -63,12 +71,35 @@
enable_moderation = False
+def load_demo_side_by_side_vision_named(context: Context):
+ states = [None] * num_sides
+
+ # default to the text models
+ models = context.text_models
+
+ model_left = models[0] if len(models) > 0 else ""
+ if len(models) > 1:
+ weights = ([8] * 4 + [4] * 8 + [1] * 64)[: len(models) - 1]
+ weights = weights / np.sum(weights)
+ model_right = np.random.choice(models[1:], p=weights)
+ else:
+ model_right = model_left
+
+ all_models = list(set(context.text_models + context.vision_models))
+ selector_updates = [
+ gr.Dropdown(choices=all_models, value=model_left, visible=True),
+ gr.Dropdown(choices=all_models, value=model_right, visible=True),
+ ]
+
+ return states + selector_updates
+
+
def clear_history_example(request: gr.Request):
logger.info(f"clear_history_example (named). ip: {get_ip(request)}")
return (
[None] * num_sides
+ [None] * num_sides
- + [enable_multimodal]
+ + [enable_multimodal_keep_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
)
@@ -134,6 +165,7 @@ def regenerate(state0, state1, request: gr.Request):
if state0.regen_support and state1.regen_support:
for i in range(num_sides):
states[i].conv.update_last_message(None)
+ states[i].content_moderator.update_last_moderation_response(None)
return (
states
+ [x.to_gradio_chatbot() for x in states]
@@ -152,16 +184,40 @@ def clear_history(request: gr.Request):
return (
[None] * num_sides
+ [None] * num_sides
- + [enable_multimodal]
+ + [enable_multimodal_clear_input]
+ [invisible_btn] * 4
+ [disable_btn] * 2
)
def add_text(
- state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request
+ state0,
+ state1,
+ model_selector0,
+ model_selector1,
+ chat_input: Union[str, dict],
+ context: Context,
+ request: gr.Request,
):
- text, images = chat_input["text"], chat_input["files"]
+ if isinstance(chat_input, dict):
+ text, images = chat_input["text"], chat_input["files"]
+ else:
+ text, images = chat_input, []
+
+ if len(images) > 0:
+ if (
+ model_selector0 in context.text_models
+ and model_selector0 not in context.vision_models
+ ):
+ gr.Warning(f"{model_selector0} is a text-only model. Image is ignored.")
+ images = []
+ if (
+ model_selector1 in context.text_models
+ and model_selector1 not in context.vision_models
+ ):
+ gr.Warning(f"{model_selector1} is a text-only model. Image is ignored.")
+ images = []
+
ip = get_ip(request)
logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
states = [state0, state1]
@@ -169,7 +225,9 @@ def add_text(
# Init states if necessary
for i in range(num_sides):
- if states[i] is None:
+ if states[i] is None and len(images) == 0:
+ states[i] = State(model_selectors[i], is_vision=False)
+ elif states[i] is None and len(images) > 0:
states[i] = State(model_selectors[i], is_vision=True)
if len(text) <= 0:
@@ -194,10 +252,41 @@ def add_text(
images = convert_images_to_conversation_format(images)
- text, image_flagged, csam_flag = moderate_input(
- state0, text, all_conv_text, model_list, images, ip
+ # Use the first state's moderator: moderation is based on the user input, so it is independent of the model
+ moderation_image_input = images if len(images) > 0 else None
+ moderation_type_to_response_map = states[
+ 0
+ ].content_moderator.image_and_text_moderation_filter(
+ moderation_image_input, text, model_list, do_moderation=False
+ )
+
+ text_flagged, nsfw_flag, csam_flag = (
+ moderation_type_to_response_map["text_moderation"]["flagged"],
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["nsfw_moderation"]
+ ]
+ ),
+ any(
+ [
+ response["flagged"]
+ for response in moderation_type_to_response_map["csam_moderation"]
+ ]
+ ),
)
+ if csam_flag:
+ states[0].has_csam_image, states[1].has_csam_image = True, True
+
+ for state in states:
+ state.content_moderator.append_moderation_response(
+ moderation_type_to_response_map
+ )
+
+ if text_flagged or nsfw_flag:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+
conv = states[0].conv
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
@@ -213,14 +302,20 @@ def add_text(
* 6
)
- if image_flagged:
- logger.info(f"image flagged. ip: {ip}. text: {text}")
+ if text_flagged or nsfw_flag:
+ logger.info(f"violate moderation. ip: {ip}. text: {text}")
+ gradio_chatbot_list = [x.to_gradio_chatbot() for x in states]
for i in range(num_sides):
+ post_processed_text = _prepare_text_with_image(
+ states[i], text, images, context
+ )
+ states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].skip_next = True
+ gr.Warning(MODERATION_MSG)
return (
states
- + [x.to_gradio_chatbot() for x in states]
- + [{"text": IMAGE_MODERATION_MSG}]
+ + gradio_chatbot_list
+ + [None]
+ [
no_change_btn,
]
@@ -230,7 +325,10 @@ def add_text(
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
for i in range(num_sides):
post_processed_text = _prepare_text_with_image(
- states[i], text, images, csam_flag=csam_flag
+ states[i],
+ text,
+ images,
+ context,
)
states[i].conv.append_message(states[i].conv.roles[0], post_processed_text)
states[i].conv.append_message(states[i].conv.roles[1], None)
@@ -247,8 +345,8 @@ def add_text(
)
-def build_side_by_side_vision_ui_named(models, random_questions=None):
- notice_markdown = f"""
+def build_side_by_side_vision_ui_named(context: Context, random_questions=None):
+ notice_markdown = """
# ⚔️ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild
[Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx)
@@ -271,6 +369,9 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
notice = gr.Markdown(notice_markdown, elem_id="notice_markdown")
+ text_and_vision_models = list(set(context.text_models + context.vision_models))
+ context_state = gr.State(context)
+
with gr.Row():
with gr.Column(scale=2, visible=False) as image_column:
imagebox = gr.Image(
@@ -282,10 +383,12 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
with gr.Column(scale=5):
with gr.Group(elem_id="share-region-anony"):
with gr.Accordion(
- f"🔍 Expand to see the descriptions of {len(models)} models",
+ f"🔍 Expand to see the descriptions of {len(text_and_vision_models)} models",
open=False,
):
- model_description_md = get_model_description_md(models)
+ model_description_md = get_model_description_md(
+ text_and_vision_models
+ )
gr.Markdown(
model_description_md, elem_id="model_description_markdown"
)
@@ -294,8 +397,10 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
for i in range(num_sides):
with gr.Column():
model_selectors[i] = gr.Dropdown(
- choices=models,
- value=models[i] if len(models) > i else "",
+ choices=text_and_vision_models,
+ value=text_and_vision_models[i]
+ if len(text_and_vision_models) > i
+ else "",
interactive=True,
show_label=False,
container=False,
@@ -325,7 +430,7 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
)
with gr.Row():
- textbox = gr.MultimodalTextbox(
+ multimodal_textbox = gr.MultimodalTextbox(
file_types=["image"],
show_label=False,
placeholder="Enter your prompt or add image here",
@@ -383,25 +488,25 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
leftvote_btn.click(
leftvote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
rightvote_btn.click(
rightvote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
tie_btn.click(
tievote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
bothbad_btn.click(
bothbad_vote_last_response,
states + model_selectors,
- [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
+ [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
)
regenerate_btn.click(
- regenerate, states, states + chatbots + [textbox] + btn_list
+ regenerate, states, states + chatbots + [multimodal_textbox] + btn_list
).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -409,7 +514,11 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
).then(
flash_buttons, [], btn_list
)
- clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
+ clear_btn.click(
+ clear_history,
+ None,
+ states + chatbots + [multimodal_textbox] + btn_list,
+ )
share_js = """
function (a, b, c, d) {
@@ -435,17 +544,19 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
for i in range(num_sides):
model_selectors[i].change(
- clear_history, None, states + chatbots + [textbox] + btn_list
- ).then(set_visible_image, [textbox], [image_column])
+ clear_history,
+ None,
+ states + chatbots + [multimodal_textbox] + btn_list,
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
- textbox.input(add_image, [textbox], [imagebox]).then(
- set_visible_image, [textbox], [image_column]
- ).then(clear_history_example, None, states + chatbots + [textbox] + btn_list)
+ multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then(
+ set_visible_image, [multimodal_textbox], [image_column]
+ )
- textbox.submit(
+ multimodal_textbox.submit(
add_text,
- states + model_selectors + [textbox],
- states + chatbots + [textbox] + btn_list,
+ states + model_selectors + [multimodal_textbox, context_state],
+ states + chatbots + [multimodal_textbox] + btn_list,
).then(set_invisible_image, [], [image_column]).then(
bot_response_multi,
states + [temperature, top_p, max_output_tokens],
@@ -458,9 +569,7 @@ def build_side_by_side_vision_ui_named(models, random_questions=None):
random_btn.click(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
- [textbox, imagebox], # Outputs are textbox and imagebox
- ).then(set_visible_image, [textbox], [image_column]).then(
- clear_history_example, None, states + chatbots + [textbox] + btn_list
- )
+ [multimodal_textbox, imagebox], # Outputs are textbox and imagebox
+ ).then(set_visible_image, [multimodal_textbox], [image_column])
return states + model_selectors
diff --git a/fastchat/serve/gradio_global_state.py b/fastchat/serve/gradio_global_state.py
new file mode 100644
index 000000000..de911985d
--- /dev/null
+++ b/fastchat/serve/gradio_global_state.py
@@ -0,0 +1,11 @@
+from dataclasses import dataclass, field
+from typing import List
+
+
+@dataclass
+class Context:
+ text_models: List[str] = field(default_factory=list)
+ all_text_models: List[str] = field(default_factory=list)
+ vision_models: List[str] = field(default_factory=list)
+ all_vision_models: List[str] = field(default_factory=list)
+ api_endpoint_info: dict = field(default_factory=dict)
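A short sketch of how this new Context is meant to be populated and queried; the model names are placeholders, and multi_image is the api_endpoint_info key that _prepare_text_with_image reads:

from fastchat.serve.gradio_global_state import Context

context = Context(
    text_models=["some-text-model"],            # placeholder names
    all_text_models=["some-text-model"],
    vision_models=["llava-v1.6-34b"],
    all_vision_models=["llava-v1.6-34b"],
    api_endpoint_info={"llava-v1.6-34b": {"multi_image": False}},
)
print(context.api_endpoint_info["llava-v1.6-34b"].get("multi_image", False))  # False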
diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py
index 2ef47b14d..46ff0c421 100644
--- a/fastchat/serve/gradio_web_server.py
+++ b/fastchat/serve/gradio_web_server.py
@@ -11,6 +11,7 @@
import random
import time
import uuid
+from typing import List
import gradio as gr
import requests
@@ -33,12 +34,13 @@
)
from fastchat.model.model_registry import get_model_info, model_info
from fastchat.serve.api_provider import get_api_provider_stream_iter
+from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator
+from fastchat.serve.gradio_global_state import Context
from fastchat.serve.remote_logger import get_remote_logger
from fastchat.utils import (
build_logger,
get_window_url_params_js,
get_window_url_params_with_tos_js,
- moderation_filter,
parse_gradio_auth_creds,
load_image,
)
@@ -59,6 +61,8 @@
visible=True,
placeholder='Press "🎲 New Round" to start over👇 (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)',
)
+show_vote_button = True
+dont_show_vote_button = False
controller_url = None
enable_moderation = False
@@ -117,6 +121,7 @@ def __init__(self, model_name, is_vision=False):
self.model_name = model_name
self.oai_thread_id = None
self.is_vision = is_vision
+ self.content_moderator = AzureAndOpenAIContentModerator()
# NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes.
self.has_csam_image = False
@@ -143,6 +148,7 @@ def dict(self):
{
"conv_id": self.conv_id,
"model_name": self.model_name,
+ "moderation": self.content_moderator.conv_moderation_responses,
}
)
@@ -151,7 +157,11 @@ def dict(self):
return base
-def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_):
+def set_global_vars(
+ controller_url_,
+ enable_moderation_,
+ use_remote_storage_,
+):
global controller_url, enable_moderation, use_remote_storage
controller_url = controller_url_
enable_moderation = enable_moderation_
@@ -218,16 +228,27 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
return visible_models, models
-def load_demo_single(models, url_params):
+def _get_api_endpoint_info():
+ global api_endpoint_info
+ return api_endpoint_info
+
+
+def load_demo_single(context: Context, query_params):
+ # default to text models
+ models = context.text_models
+
selected_model = models[0] if len(models) > 0 else ""
- if "model" in url_params:
- model = url_params["model"]
+ if "model" in query_params:
+ model = query_params["model"]
if model in models:
selected_model = model
- dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True)
+ all_models = list(set(context.text_models + context.vision_models))
+ dropdown_update = gr.Dropdown(
+ choices=all_models, value=selected_model, visible=True
+ )
state = None
- return state, dropdown_update
+ return [state, dropdown_update]
def load_demo(url_params, request: gr.Request):
@@ -289,6 +310,7 @@ def regenerate(state, request: gr.Request):
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
state.conv.update_last_message(None)
+ state.content_moderator.update_last_moderation_response(None)
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
@@ -322,14 +344,24 @@ def add_text(state, model_selector, text, request: gr.Request):
state.skip_next = True
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
- all_conv_text = state.conv.get_prompt()
- all_conv_text = all_conv_text[-2000:] + "\nuser: " + text
- flagged = moderation_filter(all_conv_text, [state.model_name])
- # flagged = moderation_filter(text, [state.model_name])
- if flagged:
+ content_moderator = AzureAndOpenAIContentModerator()
+ text_flagged = content_moderator.text_moderation_filter(text, [state.model_name])
+
+ if text_flagged:
logger.info(f"violate moderation. ip: {ip}. text: {text}")
# overwrite the original text
- text = MODERATION_MSG
+ content_moderator.write_to_json(get_ip(request))
+ state.skip_next = True
+ gr.Warning(MODERATION_MSG)
+ return (
+ [state]
+ + [state.to_gradio_chatbot()]
+ + [""]
+ + [
+ no_change_btn,
+ ]
+ * 5
+ )
if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
@@ -400,6 +432,34 @@ def is_limit_reached(model_name, ip):
return None
+def _write_to_json(
+ filename: str,
+ start_tstamp: float,
+ finish_tstamp: float,
+ state: State,
+ temperature: float,
+ top_p: float,
+ max_new_tokens: int,
+ request: gr.Request,
+):
+ with open(filename, "a") as fout:
+ data = {
+ "tstamp": round(finish_tstamp, 4),
+ "type": "chat",
+ "model": state.model_name,
+ "gen_params": {
+ "temperature": temperature,
+ "top_p": top_p,
+ "max_new_tokens": max_new_tokens,
+ },
+ "start": round(start_tstamp, 4),
+ "finish": round(finish_tstamp, 4),
+ "state": state.dict(),
+ "ip": get_ip(request),
+ }
+ fout.write(json.dumps(data) + "\n")
+ return data
+
+
def bot_response(
state,
temperature,
@@ -419,6 +479,32 @@ def bot_response(
if state.skip_next:
# This generate call is skipped due to invalid inputs
state.skip_next = False
+ if state.content_moderator.text_flagged or state.content_moderator.nsfw_flagged:
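+            # Flagged input: log the skipped turn, then drop the pending user message below.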
+ start_tstamp = time.time()
+ finish_tstamp = start_tstamp
+ state.conv.save_new_images(
+ has_csam_images=state.has_csam_image,
+ use_remote_storage=use_remote_storage,
+ )
+
+ filename = get_conv_log_filename(
+ is_vision=state.is_vision, has_csam_image=state.has_csam_image
+ )
+
+ _write_to_json(
+ filename,
+ start_tstamp,
+ finish_tstamp,
+ state,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
+
+ # Remove the last message: the user input
+ state.conv.messages.pop()
+ state.content_moderator.update_last_moderation_response(None)
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
return
@@ -570,22 +656,23 @@ def bot_response(
is_vision=state.is_vision, has_csam_image=state.has_csam_image
)
- with open(filename, "a") as fout:
- data = {
- "tstamp": round(finish_tstamp, 4),
- "type": "chat",
- "model": model_name,
- "gen_params": {
- "temperature": temperature,
- "top_p": top_p,
- "max_new_tokens": max_new_tokens,
- },
- "start": round(start_tstamp, 4),
- "finish": round(finish_tstamp, 4),
- "state": state.dict(),
- "ip": get_ip(request),
- }
- fout.write(json.dumps(data) + "\n")
+ moderation_type_to_response_map = (
+ state.content_moderator.image_and_text_moderation_filter(
+ None, output, [state.model_name], do_moderation=True
+ )
+ )
+ state.content_moderator.append_moderation_response(moderation_type_to_response_map)
+
+    data = _write_to_json(
+ filename,
+ start_tstamp,
+ finish_tstamp,
+ state,
+ temperature,
+ top_p,
+ max_new_tokens,
+ request,
+ )
get_remote_logger().log(data)
diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py
index 14f254bf3..f19fab6ec 100644
--- a/fastchat/serve/gradio_web_server_multi.py
+++ b/fastchat/serve/gradio_web_server_multi.py
@@ -6,6 +6,7 @@
import argparse
import pickle
import time
+from typing import List
import gradio as gr
@@ -28,7 +29,9 @@
)
from fastchat.serve.gradio_block_arena_vision_named import (
build_side_by_side_vision_ui_named,
+ load_demo_side_by_side_vision_named,
)
+from fastchat.serve.gradio_global_state import Context
from fastchat.serve.gradio_web_server import (
set_global_vars,
@@ -38,6 +41,7 @@
get_model_list,
load_demo_single,
get_ip,
+ _get_api_endpoint_info,
)
from fastchat.serve.monitor.monitor import build_leaderboard_tab
from fastchat.utils import (
@@ -51,57 +55,66 @@
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
-def load_demo(url_params, request: gr.Request):
- global models, all_models, vl_models, all_vl_models
-
+def load_demo(context: Context, url_params, request: gr.Request):
ip = get_ip(request)
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
inner_selected = 0
- if "arena" in url_params:
+ if "arena" in request.query_params:
inner_selected = 0
- elif "vision" in url_params:
- inner_selected = 1
- elif "compare" in url_params:
+ elif "vision" in request.query_params:
+ inner_selected = 0
+ elif "compare" in request.query_params:
inner_selected = 1
- elif "direct" in url_params or "model" in url_params:
+ elif "direct" in request.query_params or "model" in request.query_params:
+ inner_selected = 2
+ elif "leaderboard" in request.query_params:
inner_selected = 3
- elif "leaderboard" in url_params:
+ elif "about" in request.query_params:
inner_selected = 4
- elif "about" in url_params:
- inner_selected = 5
if args.model_list_mode == "reload":
- models, all_models = get_model_list(
+ context.text_models, context.all_text_models = get_model_list(
args.controller_url,
args.register_api_endpoint_file,
vision_arena=False,
)
- vl_models, all_vl_models = get_model_list(
+ context.vision_models, context.all_vision_models = get_model_list(
args.controller_url,
args.register_api_endpoint_file,
vision_arena=True,
)
- single_updates = load_demo_single(models, url_params)
- side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params)
- side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
+ # Text models
+ if args.vision_arena:
+ side_by_side_anony_updates = load_demo_side_by_side_vision_anony()
- side_by_side_vision_anony_updates = load_demo_side_by_side_vision_anony(
- all_models, all_vl_models, url_params
- )
+ side_by_side_named_updates = load_demo_side_by_side_vision_named(
+ context,
+ )
- return (
- (gr.Tabs(selected=inner_selected),)
- + single_updates
+ direct_chat_updates = load_demo_single(context, request.query_params)
+ else:
+ direct_chat_updates = load_demo_single(context, request.query_params)
+ side_by_side_anony_updates = load_demo_side_by_side_anony(
+ context.all_text_models, request.query_params
+ )
+ side_by_side_named_updates = load_demo_side_by_side_named(
+ context.text_models, request.query_params
+ )
+
+ tabs_list = (
+ [gr.Tabs(selected=inner_selected)]
+        + side_by_side_anony_updates
+        + side_by_side_named_updates
- + side_by_side_vision_anony_updates
+ + direct_chat_updates
)
+ return tabs_list
-def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
+
+def build_demo(context: Context, elo_results_file: str, leaderboard_table_file):
if args.show_terms_of_use:
load_js = get_window_url_params_with_tos_js
else:
@@ -134,41 +147,62 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
with gr.Tab("βοΈ Arena (battle)", id=0) as arena_tab:
arena_tab.select(None, None, None, js=load_js)
side_by_side_anony_list = build_side_by_side_vision_ui_anony(
- all_models,
- all_vl_models,
+ context,
random_questions=args.random_questions,
)
+ with gr.Tab("βοΈ Arena (side-by-side)", id=1) as side_by_side_tab:
+ side_by_side_tab.select(None, None, None, js=alert_js)
+ side_by_side_named_list = build_side_by_side_vision_ui_named(
+ context, random_questions=args.random_questions
+ )
+
+ with gr.Tab("π¬ Direct Chat", id=2) as direct_tab:
+ direct_tab.select(None, None, None, js=alert_js)
+ single_model_list = build_single_vision_language_model_ui(
+ context,
+ add_promotion_links=True,
+ random_questions=args.random_questions,
+ )
+
else:
with gr.Tab("βοΈ Arena (battle)", id=0) as arena_tab:
arena_tab.select(None, None, None, js=load_js)
- side_by_side_anony_list = build_side_by_side_ui_anony(models)
+ side_by_side_anony_list = build_side_by_side_ui_anony(
+ context.all_text_models
+ )
- with gr.Tab("βοΈ Arena (side-by-side)", id=2) as side_by_side_tab:
- side_by_side_tab.select(None, None, None, js=alert_js)
- side_by_side_named_list = build_side_by_side_ui_named(models)
+ with gr.Tab("βοΈ Arena (side-by-side)", id=1) as side_by_side_tab:
+ side_by_side_tab.select(None, None, None, js=alert_js)
+ side_by_side_named_list = build_side_by_side_ui_named(
+ context.text_models
+ )
- with gr.Tab("π¬ Direct Chat", id=3) as direct_tab:
- direct_tab.select(None, None, None, js=alert_js)
- single_model_list = build_single_model_ui(
- models, add_promotion_links=True
- )
+ with gr.Tab("π¬ Direct Chat", id=2) as direct_tab:
+ direct_tab.select(None, None, None, js=alert_js)
+ single_model_list = build_single_model_ui(
+ context.text_models, add_promotion_links=True
+ )
demo_tabs = (
[inner_tabs]
- + single_model_list
+ side_by_side_anony_list
+ side_by_side_named_list
+ + single_model_list
)
if elo_results_file:
- with gr.Tab("π Leaderboard", id=4):
+ with gr.Tab("π Leaderboard", id=3):
build_leaderboard_tab(
- elo_results_file, leaderboard_table_file, show_plot=True
+ elo_results_file,
+ leaderboard_table_file,
+ arena_hard_leaderboard=None,
+ show_plot=True,
)
- with gr.Tab("βΉοΈ About Us", id=5):
+ with gr.Tab("βΉοΈ About Us", id=4):
about = build_about()
+ context_state = gr.State(context)
url_params = gr.JSON(visible=False)
if args.model_list_mode not in ["once", "reload"]:
@@ -176,7 +210,7 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
demo.load(
load_demo,
- [url_params],
+ [context_state, url_params],
demo_tabs,
js=load_js,
)
@@ -272,19 +306,28 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
# Set global variables
set_global_vars(args.controller_url, args.moderate, args.use_remote_storage)
- set_global_vars_named(args.moderate)
- set_global_vars_anony(args.moderate)
- models, all_models = get_model_list(
+ set_global_vars_named(args.moderate, args.use_remote_storage)
+ set_global_vars_anony(args.moderate, args.use_remote_storage)
+ text_models, all_text_models = get_model_list(
args.controller_url,
args.register_api_endpoint_file,
vision_arena=False,
)
- vl_models, all_vl_models = get_model_list(
+ vision_models, all_vision_models = get_model_list(
args.controller_url,
args.register_api_endpoint_file,
vision_arena=True,
)
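+    # api_endpoint_info is populated when get_model_list loads the register_api_endpoint_file.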
+ api_endpoint_info = _get_api_endpoint_info()
+
+ context = Context(
+ text_models,
+ all_text_models,
+ vision_models,
+ all_vision_models,
+ api_endpoint_info,
+ )
# Set authorization credentials
auth = None
@@ -293,8 +336,7 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
# Launch the demo
demo = build_demo(
- models,
- all_vl_models,
+ context,
args.elo_results_file,
args.leaderboard_table_file,
)
diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py
new file mode 100644
index 000000000..efafdcd88
--- /dev/null
+++ b/fastchat/serve/moderation/moderator.py
@@ -0,0 +1,268 @@
+import datetime
+import hashlib
+import os
+import json
+import time
+import base64
+import requests
+from typing import Tuple, Dict, List, Union
+
+from fastchat.constants import LOGDIR
+from fastchat.serve.vision.image import Image
+from fastchat.utils import load_image, upload_image_file_to_gcs
+
+
+class BaseContentModerator:
+ def __init__(self):
+ self.conv_moderation_responses: List[
+ Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+ ] = []
+
+ def _image_moderation_filter(self, image: Image) -> Tuple[bool, bool]:
+ """Function that detects whether image violates moderation policies.
+
+ Returns:
+ Tuple[bool, bool]: A tuple of two boolean values indicating whether the image was flagged for nsfw and csam respectively.
+ """
+ raise NotImplementedError
+
+ def _text_moderation_filter(self, text: str) -> bool:
+ """Function that detects whether text violates moderation policies.
+
+ Returns:
+ bool: A boolean value indicating whether the text was flagged.
+ """
+ raise NotImplementedError
+
+ def image_and_text_moderation_filter(
+ self, images: List[Image], text: str
+ ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]:
+ """Function that detects whether image and text violate moderation policies.
+
+ Returns:
+ Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response.
+ """
+ raise NotImplementedError
+
+ def append_moderation_response(
+ self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+ ):
+ """Function that appends the moderation response to the list of moderation responses."""
+ if (
+ len(self.conv_moderation_responses) == 0
+ or self.conv_moderation_responses[-1] is not None
+ ):
+ self.conv_moderation_responses.append(moderation_response)
+ else:
+ self.update_last_moderation_response(moderation_response)
+
+ def update_last_moderation_response(
+ self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+ ):
+ """Function that updates the last moderation response."""
+ self.conv_moderation_responses[-1] = moderation_response
+
+
+class AzureAndOpenAIContentModerator(BaseContentModerator):
+ _NON_TOXIC_IMAGE_MODERATION_MAP = {
+ "nsfw_moderation": [{"flagged": False}],
+ "csam_moderation": [{"flagged": False}],
+ }
+
+ def __init__(self, use_remote_storage: bool = False):
+ """This class is used to moderate content using Azure and OpenAI.
+
+ conv_to_moderation_responses: A dictionary that is a map from the type of moderation
+ (text, nsfw, csam) moderation to the moderation response returned from the request sent
+ to the moderation API.
+ """
+ super().__init__()
+ self.text_flagged = False
+ self.csam_flagged = False
+ self.nsfw_flagged = False
+
+ def _image_moderation_request(
+ self, image_bytes: bytes, endpoint: str, api_key: str
+ ) -> dict:
+ headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key}
+
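+        # Retry a few times; a response with Status.Code == 3000 is treated as success.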
+ MAX_RETRIES = 3
+ for _ in range(MAX_RETRIES):
+ response = requests.post(endpoint, headers=headers, data=image_bytes).json()
+ try:
+ if response["Status"]["Code"] == 3000:
+ break
+ except:
+ time.sleep(0.5)
+ return response
+
+ def _image_moderation_provider(self, image_bytes: bytes, api_type: str) -> bool:
+ if api_type == "nsfw":
+ endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
+ api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
+ response = self._image_moderation_request(image_bytes, endpoint, api_key)
+ flagged = response["IsImageAdultClassified"]
+ elif api_type == "csam":
+ endpoint = (
+ "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
+ )
+ api_key = os.environ["PHOTODNA_API_KEY"]
+ response = self._image_moderation_request(image_bytes, endpoint, api_key)
+ flagged = response["IsMatch"]
+
+ image_md5_hash = hashlib.md5(image_bytes).hexdigest()
+ moderation_response_map = {
+ "image_hash": image_md5_hash,
+ "response": response,
+ "flagged": False,
+ }
+ if flagged:
+ moderation_response_map["flagged"] = True
+
+ return moderation_response_map
+
+ def image_moderation_filter(self, images: List[Image]):
+ print(f"moderating images")
+
+ images_moderation_response: Dict[
+ str, List[Dict[str, Union[str, Dict[str, float]]]]
+ ] = {
+ "nsfw_moderation": [],
+ "csam_moderation": [],
+ }
+
+ for image in images:
+ image_bytes = base64.b64decode(image.base64_str)
+
+ nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw")
+
+ if nsfw_flagged_map["flagged"]:
+ csam_flagged_map = self._image_moderation_provider(image_bytes, "csam")
+ else:
+ csam_flagged_map = {"flagged": False}
+
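+            # Accumulate flags so that one flagged image marks the whole request.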
+ self.nsfw_flagged |= nsfw_flagged_map["flagged"]
+ self.csam_flagged |= csam_flagged_map["flagged"]
+
+ images_moderation_response["nsfw_moderation"].append(nsfw_flagged_map)
+ images_moderation_response["csam_moderation"].append(csam_flagged_map)
+
+ return images_moderation_response
+
+ def _openai_moderation_filter(
+ self, text: str, custom_thresholds: dict = None
+ ) -> bool:
+ """
+ Check whether the text violates OpenAI moderation API.
+ """
+ import openai
+
+ client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
+
+        # Default to flagged=True so that API failures err on the side of caution.
+        flagged = True
+        MAX_RETRY = 3
+        moderation_response_map = {"content": text, "response": None, "flagged": flagged}
+ for _ in range(MAX_RETRY):
+ try:
+ res = client.moderations.create(input=text)
+ flagged = res.results[0].flagged
+ if custom_thresholds is not None:
+ for category, threshold in custom_thresholds.items():
+ if (
+ getattr(res.results[0].category_scores, category)
+ > threshold
+ ):
+ flagged = True
+ moderation_response_map = {
+ "response": dict(res.results[0].category_scores),
+ "flagged": flagged,
+ }
+ break
+ except (openai.OpenAIError, KeyError, IndexError) as e:
+ print(f"MODERATION ERROR: {e}\nInput: {text}")
+
+ return moderation_response_map
+
+ def text_moderation_filter(
+ self, text: str, model_list: List[str], do_moderation: bool = False
+ ):
+ # Apply moderation for below models
+ MODEL_KEYWORDS = [
+ "claude",
+ "gpt",
+ "bard",
+ "mistral-large",
+ "command-r",
+ "dbrx",
+ "gemini",
+ "reka",
+ ]
+
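+        # Flag the text when any of these category scores is exceeded, even if the API itself does not flag it.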
+ custom_thresholds = {"sexual": 0.3}
+ # set a stricter threshold for claude
+ for model in model_list:
+ if "claude" in model:
+ custom_thresholds = {"sexual": 0.2}
+
+ for keyword in MODEL_KEYWORDS:
+ for model in model_list:
+ if keyword in model:
+ do_moderation = True
+ break
+
+ moderation_response_map = {"flagged": False}
+ if do_moderation:
+ moderation_response_map = self._openai_moderation_filter(
+ text, custom_thresholds
+ )
+ self.text_flagged = moderation_response_map["flagged"]
+
+ return {"text_moderation": moderation_response_map}
+
+ def image_and_text_moderation_filter(
+ self, images: List[Image], text: str, model_list: List[str], do_moderation=True
+    ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]:
+ """Function that detects whether image and text violate moderation policies using the Azure and OpenAI moderation APIs.
+
+ Returns:
+ Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response.
+
+ Example:
+ {
+ "text_moderation": {
+ "content": "This is a test",
+ "response": {
+ "sexual": 0.1
+ },
+ "flagged": True
+ },
+ "nsfw_moderation": {
+ "image_hash": "1234567890",
+ "response": {
+ "IsImageAdultClassified": True
+ },
+ "flagged": True
+ },
+ "csam_moderation": {
+ "image_hash": "1234567890",
+ "response": {
+ "IsMatch": True
+ },
+ "flagged": True
+ }
+ }
+ """
+ print("moderating text: ", text)
+ text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation)
+
+ if images:
+ image_flagged_map = self.image_moderation_filter(images)
+ else:
+ image_flagged_map = self._NON_TOXIC_IMAGE_MODERATION_MAP
+
+ res = {}
+ res.update(text_flagged_map)
+ res.update(image_flagged_map)
+
+ return res
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
index f5842c7c4..85ad2b647 100644
--- a/fastchat/serve/monitor/monitor.py
+++ b/fastchat/serve/monitor/monitor.py
@@ -455,7 +455,7 @@ def update_leaderboard_and_plots(category):
"Knowledge Cutoff",
],
datatype=[
- "str",
+ "number",
"markdown",
"number",
"str",
@@ -534,7 +534,7 @@ def update_leaderboard_and_plots(category):
"Knowledge Cutoff",
],
datatype=[
- "str",
+ "number",
"markdown",
"number",
"str",
@@ -562,9 +562,7 @@ def update_leaderboard_and_plots(category):
elem_id="leaderboard_markdown",
)
- if not vision:
- # only live update the text tab
- leader_component_values[:] = [default_md, p1, p2, p3, p4]
+ leader_component_values[:] = [default_md, p1, p2, p3, p4]
if show_plot:
more_stats_md = gr.Markdown(
@@ -777,6 +775,7 @@ def build_leaderboard_tab(
elo_results_file,
leaderboard_table_file,
arena_hard_leaderboard,
+ vision=True,
show_plot=False,
mirror=False,
):
@@ -855,14 +854,18 @@ def build_leaderboard_tab(
default_md,
show_plot=show_plot,
)
- with gr.Tab("Arena (Vision)", id=2):
- build_arena_tab(
- elo_results_vision,
- model_table_df,
- default_md,
- vision=True,
- show_plot=show_plot,
- )
+
+ if vision:
+ with gr.Tab("Arena (Vision)", id=2):
+ build_arena_tab(
+ elo_results_vision,
+ model_table_df,
+ default_md,
+ vision=True,
+ show_plot=show_plot,
+ )
+
+ model_to_score = {}
if arena_hard_leaderboard is not None:
with gr.Tab("Arena-Hard-Auto", id=3):
dataFrame = arena_hard_process(
@@ -882,7 +885,6 @@ def build_leaderboard_tab(
"avg_tokens": "Average Tokens",
}
)
- model_to_score = {}
for i in range(len(dataFrame)):
model_to_score[dataFrame.loc[i, "Model"]] = dataFrame.loc[
i, "Win-rate"
diff --git a/fastchat/utils.py b/fastchat/utils.py
index 545e01414..4ec8249e1 100644
--- a/fastchat/utils.py
+++ b/fastchat/utils.py
@@ -149,61 +149,6 @@ def get_gpu_memory(max_gpus=None):
return gpu_memory
-def oai_moderation(text, custom_thresholds=None):
- """
- Check whether the text violates OpenAI moderation API.
- """
- import openai
-
- client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"])
-
- # default to true to be conservative
- flagged = True
- MAX_RETRY = 3
- for _ in range(MAX_RETRY):
- try:
- res = client.moderations.create(input=text)
- flagged = res.results[0].flagged
- if custom_thresholds is not None:
- for category, threshold in custom_thresholds.items():
- if getattr(res.results[0].category_scores, category) > threshold:
- flagged = True
- break
- except (openai.OpenAIError, KeyError, IndexError) as e:
- print(f"MODERATION ERROR: {e}\nInput: {text}")
- return flagged
-
-
-def moderation_filter(text, model_list, do_moderation=False):
- # Apply moderation for below models
- MODEL_KEYWORDS = [
- "claude",
- "gpt",
- "bard",
- "mistral-large",
- "command-r",
- "dbrx",
- "gemini",
- "reka",
- ]
-
- custom_thresholds = {"sexual": 0.3}
- # set a stricter threshold for claude
- for model in model_list:
- if "claude" in model:
- custom_thresholds = {"sexual": 0.2}
-
- for keyword in MODEL_KEYWORDS:
- for model in model_list:
- if keyword in model:
- do_moderation = True
- break
-
- if do_moderation:
- return oai_moderation(text, custom_thresholds)
- return False
-
-
def clean_flant5_ckpt(ckpt_path):
"""
Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
@@ -438,47 +383,3 @@ def get_image_file_from_gcs(filename):
contents = blob.download_as_bytes()
return contents
-
-
-def image_moderation_request(image_bytes, endpoint, api_key):
- headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key}
-
- MAX_RETRIES = 3
- for _ in range(MAX_RETRIES):
- response = requests.post(endpoint, headers=headers, data=image_bytes).json()
- try:
- if response["Status"]["Code"] == 3000:
- break
- except:
- time.sleep(0.5)
- return response
-
-
-def image_moderation_provider(image, api_type):
- if api_type == "nsfw":
- endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"]
- api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"]
- response = image_moderation_request(image, endpoint, api_key)
- print(response)
- return response["IsImageAdultClassified"]
- elif api_type == "csam":
- endpoint = (
- "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false"
- )
- api_key = os.environ["PHOTODNA_API_KEY"]
- response = image_moderation_request(image, endpoint, api_key)
- return response["IsMatch"]
-
-
-def image_moderation_filter(image):
- print(f"moderating image")
-
- image_bytes = base64.b64decode(image.base64_str)
-
- nsfw_flagged = image_moderation_provider(image_bytes, "nsfw")
- csam_flagged = False
-
- if nsfw_flagged:
- csam_flagged = image_moderation_provider(image_bytes, "csam")
-
- return nsfw_flagged, csam_flagged
diff --git a/playground/__init__.py b/playground/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/playground/benchmark/benchmark_api_provider.py b/playground/benchmark/benchmark_api_provider.py
new file mode 100644
index 000000000..89ca02ece
--- /dev/null
+++ b/playground/benchmark/benchmark_api_provider.py
@@ -0,0 +1,135 @@
+"""
+Usage:
+python3 -m playground.benchmark.benchmark_api_provider --api-endpoint-file api_endpoints.json --output-file ./benchmark_results.json --random-questions metadata_sampled.json
+"""
+import argparse
+import json
+import time
+
+import numpy as np
+
+from fastchat.serve.api_provider import get_api_provider_stream_iter
+from fastchat.serve.gradio_web_server import State
+from fastchat.serve.vision.image import Image
+
+
+class Metrics:
+ def __init__(self):
+ self.ttft = None
+ self.avg_token_time = None
+
+ def to_dict(self):
+ return {"ttft": self.ttft, "avg_token_time": self.avg_token_time}
+
+
+def sample_image_and_question(random_questions_dict, index):
+ # message = np.random.choice(random_questions_dict)
+ message = random_questions_dict[index]
+ question = message["question"]
+ path = message["path"]
+
+ if isinstance(question, list):
+ question = question[0]
+
+ return (question, path)
+
+
+def call_model(
+ conv,
+ model_name,
+ model_api_dict,
+ state,
+ temperature=0.4,
+ top_p=0.9,
+ max_new_tokens=2048,
+):
+ prev_message = ""
+ prev_time = time.time()
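+    # Rough heuristic for converting streamed text length into a token count.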
+ CHARACTERS_PER_TOKEN = 4
+ metrics = Metrics()
+
+ stream_iter = get_api_provider_stream_iter(
+ conv, model_name, model_api_dict, temperature, top_p, max_new_tokens, state
+ )
+ call_time = time.time()
+ token_times = []
+ for i, data in enumerate(stream_iter):
+ output = data["text"].strip()
+ if i == 0:
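+            # First streamed chunk: record time-to-first-token.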
+ metrics.ttft = time.time() - call_time
+ prev_message = output
+ prev_time = time.time()
+ else:
+ token_diff_length = (len(output) - len(prev_message)) / CHARACTERS_PER_TOKEN
+ if token_diff_length == 0:
+ continue
+
+            token_diff_time = time.time() - prev_time
+            token_time = token_diff_time / token_diff_length
+            token_times.append(token_time)
+            prev_message = output
+            prev_time = time.time()
+
+    metrics.avg_token_time = np.mean(token_times) if token_times else None
+ return metrics
+
+
+def run_benchmark(model_name, model_api_dict, random_questions_dict, num_calls=20):
+ model_results = []
+
+ for index in range(num_calls):
+ state = State(model_name)
+ text, image_path = sample_image_and_question(random_questions_dict, index)
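+        # Assumed ~5 MB request budget; leave headroom for base64 encoding overhead.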
+ max_image_size_mb = 5 / 1.5
+
+ images = [
+ Image(url=image_path).to_conversation_format(
+ max_image_size_mb=max_image_size_mb
+ )
+ ]
+ message = (text, images)
+
+ state.conv.append_message(state.conv.roles[0], message)
+ state.conv.append_message(state.conv.roles[1], None)
+
+ metrics = call_model(state.conv, model_name, model_api_dict, state)
+ model_results.append(metrics.to_dict())
+
+ return model_results
+
+
+def benchmark_models(api_endpoint_info, random_questions_dict, models):
+ results = {model_name: [] for model_name in models}
+
+ for model_name in models:
+ model_results = run_benchmark(
+ model_name,
+ api_endpoint_info[model_name],
+ random_questions_dict,
+ num_calls=20,
+ )
+ results[model_name] = model_results
+
+ print(results)
+ return results
+
+
+def main(api_endpoint_file, random_questions, output_file):
+ api_endpoint_info = json.load(open(api_endpoint_file))
+ random_questions_dict = json.load(open(random_questions))
+ models = ["reka-core-20240501", "gpt-4o-2024-05-13"]
+
+ models_results = benchmark_models(api_endpoint_info, random_questions_dict, models)
+
+ with open(output_file, "w") as f:
+ json.dump(models_results, f)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--api-endpoint-file", required=True)
+ parser.add_argument("--random-questions", required=True)
+ parser.add_argument("--output-file", required=True)
+
+ args = parser.parse_args()
+
+ main(args.api_endpoint_file, args.random_questions, args.output_file)