diff --git a/fastchat/conversation.py b/fastchat/conversation.py index af916c97f..21078938e 100644 --- a/fastchat/conversation.py +++ b/fastchat/conversation.py @@ -365,12 +365,16 @@ def to_gradio_chatbot(self): if i % 2 == 0: if type(msg) is tuple: msg, images = msg - image = images[0] # Only one image on gradio at one time - if image.image_format == ImageFormat.URL: - img_str = f'user upload image' - elif image.image_format == ImageFormat.BYTES: - img_str = f'user upload image' - msg = img_str + msg.replace("\n", "").strip() + combined_image_str = "" + for image in images: + if image.image_format == ImageFormat.URL: + img_str = ( + f'user upload image' + ) + elif image.image_format == ImageFormat.BYTES: + img_str = f'user upload image' + combined_image_str += img_str + msg = combined_image_str + msg.replace("\n", "").strip() ret.append([msg, None]) else: @@ -524,6 +528,7 @@ def to_anthropic_vision_api_messages(self): def to_reka_api_messages(self): from fastchat.serve.vision.image import ImageFormat + from reka import ChatMessage, TypedMediaContent, TypedText ret = [] for i, (_, msg) in enumerate(self.messages[self.offset :]): @@ -531,23 +536,47 @@ def to_reka_api_messages(self): if type(msg) == tuple: text, images = msg for image in images: - if image.image_format == ImageFormat.URL: - ret.append( - {"type": "human", "text": text, "media_url": image.url} - ) - elif image.image_format == ImageFormat.BYTES: + if image.image_format == ImageFormat.BYTES: ret.append( - { - "type": "human", - "text": text, - "media_url": f"data:image/{image.filetype};base64,{image.base64_str}", - } + ChatMessage( + content=[ + TypedText( + type="text", + text=text, + ), + TypedMediaContent( + type="image_url", + image_url=f"data:image/{image.filetype};base64,{image.base64_str}", + ), + ], + role="user", + ) ) else: - ret.append({"type": "human", "text": msg}) + ret.append( + ChatMessage( + content=[ + TypedText( + type="text", + text=msg, + ) + ], + role="user", + ) + ) else: if msg is not None: - ret.append({"type": "model", "text": msg}) + ret.append( + ChatMessage( + content=[ + TypedText( + type="text", + text=msg, + ) + ], + role="assistant", + ) + ) return ret @@ -557,7 +586,11 @@ def save_new_images(self, has_csam_images=False, use_remote_storage=False): from fastchat.utils import load_image, upload_image_file_to_gcs from PIL import Image - _, last_user_message = self.messages[-2] + last_user_message = None + for role, message in reversed(self.messages): + if role == "user": + last_user_message = message + break if type(last_user_message) == tuple: text, images = last_user_message[0], last_user_message[1] diff --git a/fastchat/serve/api_provider.py b/fastchat/serve/api_provider.py index a2b7979be..c3c64173b 100644 --- a/fastchat/serve/api_provider.py +++ b/fastchat/serve/api_provider.py @@ -1076,8 +1076,13 @@ def reka_api_stream_iter( api_key: Optional[str] = None, # default is env var CO_API_KEY api_base: Optional[str] = None, ): + from reka.client import Reka + from reka import TypedText + api_key = api_key or os.environ["REKA_API_KEY"] + client = Reka(api_key=api_key) + use_search_engine = False if "-online" in model_name: model_name = model_name.replace("-online", "") @@ -1094,34 +1099,27 @@ def reka_api_stream_iter( # Make requests for logging text_messages = [] - for message in messages: - text_messages.append({"type": message["type"], "text": message["text"]}) + for turn in messages: + for message in turn.content: + if isinstance(message, TypedText): + text_messages.append({"type": message.type, 
"text": message.text}) logged_request = dict(request) logged_request["conversation_history"] = text_messages logger.info(f"==== request ====\n{logged_request}") - response = requests.post( - api_base, - stream=True, - json=request, - headers={ - "X-Api-Key": api_key, - }, + response = client.chat.create_stream( + messages=messages, + max_tokens=max_new_tokens, + top_p=top_p, + model=model_name, ) - if response.status_code != 200: - error_message = response.text - logger.error(f"==== error from reka api: {error_message} ====") - yield { - "text": f"**API REQUEST ERROR** Reason: {error_message}", - "error_code": 1, - } - return - - for line in response.iter_lines(): - line = line.decode("utf8") - if not line.startswith("data: "): - continue - gen = json.loads(line[6:]) - yield {"text": gen["text"], "error_code": 0} + for chunk in response: + try: + yield {"text": chunk.responses[0].chunk.content, "error_code": 0} + except: + yield { + "text": f"**API REQUEST ERROR** ", + "error_code": 1, + } diff --git a/fastchat/serve/gradio_block_arena_anony.py b/fastchat/serve/gradio_block_arena_anony.py index dc9b89a0c..bfba8485e 100644 --- a/fastchat/serve/gradio_block_arena_anony.py +++ b/fastchat/serve/gradio_block_arena_anony.py @@ -32,24 +32,27 @@ acknowledgment_md, get_ip, get_model_description_md, + _write_to_json, ) +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") num_sides = 2 enable_moderation = False +use_remote_storage = False anony_names = ["", ""] models = [] -def set_global_vars_anony(enable_moderation_): - global enable_moderation +def set_global_vars_anony(enable_moderation_, use_remote_storage_): + global enable_moderation, use_remote_storage enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ def load_demo_side_by_side_anony(models_, url_params): @@ -202,6 +205,9 @@ def get_battle_pair( if len(models) == 1: return models[0], models[0] + if len(models) == 0: + raise ValueError("There are no models provided. Cannot get battle pair.") + model_weights = [] for model in models: weight = get_sample_weight( @@ -290,7 +296,11 @@ def add_text( all_conv_text = ( all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text ) - flagged = moderation_filter(all_conv_text, model_list, do_moderation=True) + + content_moderator = AzureAndOpenAIContentModerator() + flagged = content_moderator.text_moderation_filter( + all_conv_text, model_list, do_moderation=True + ) if flagged: logger.info(f"violate moderation (anony). ip: {ip}. text: {text}") # overwrite the original text @@ -343,18 +353,50 @@ def bot_response_multi( request: gr.Request, ): logger.info(f"bot_response_multi (anony). 
ip: {get_ip(request)}") + states = [state0, state1] + + if states[0] is None or states[0].skip_next: + if ( + states[0].content_moderator.text_flagged + or states[0].content_moderator.nsfw_flagged + ): + for i in range(num_sides): + # This generate call is skipped due to invalid inputs + start_tstamp = time.time() + finish_tstamp = start_tstamp + states[i].conv.save_new_images( + has_csam_images=states[i].has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=states[i].is_vision, + has_csam_image=states[i].has_csam_image, + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + states[i].conv.messages.pop() + states[i].content_moderator.update_last_moderation_response(None) - if state0 is None or state0.skip_next: - # This generate call is skipped due to invalid inputs yield ( - state0, - state1, - state0.to_gradio_chatbot(), - state1.to_gradio_chatbot(), + states[0], + states[1], + states[0].to_gradio_chatbot(), + states[1].to_gradio_chatbot(), ) + (no_change_btn,) * 6 return - states = [state0, state1] gen = [] for i in range(num_sides): gen.append( diff --git a/fastchat/serve/gradio_block_arena_named.py b/fastchat/serve/gradio_block_arena_named.py index 7ee19b041..6a102d0e2 100644 --- a/fastchat/serve/gradio_block_arena_named.py +++ b/fastchat/serve/gradio_block_arena_named.py @@ -28,22 +28,25 @@ acknowledgment_md, get_ip, get_model_description_md, + _write_to_json, ) +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") num_sides = 2 enable_moderation = False +use_remote_storage = False -def set_global_vars_named(enable_moderation_): - global enable_moderation +def set_global_vars_named(enable_moderation_, use_remote_storage_): + global enable_moderation, use_remote_storage enable_moderation = enable_moderation_ + use_remote_storage = use_remote_storage_ def load_demo_side_by_side_named(models, url_params): @@ -175,19 +178,27 @@ def add_text( no_change_btn, ] * 6 + + [True] ) model_list = [states[i].model_name for i in range(num_sides)] - all_conv_text_left = states[0].conv.get_prompt() - all_conv_text_right = states[1].conv.get_prompt() - all_conv_text = ( - all_conv_text_left[-1000:] + all_conv_text_right[-1000:] + "\nuser: " + text - ) - flagged = moderation_filter(all_conv_text, model_list) - if flagged: - logger.info(f"violate moderation (named). ip: {ip}. text: {text}") - # overwrite the original text - text = MODERATION_MSG + text_flagged = states[0].content_moderator.text_moderation_filter(text, model_list) + + if text_flagged: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + for i in range(num_sides): + states[i].skip_next = True + gr.Warning(MODERATION_MSG) + return ( + states + + [x.to_gradio_chatbot() for x in states] + + [""] + + [ + no_change_btn, + ] + * 6 + + [True] + ) conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: @@ -202,6 +213,7 @@ def add_text( no_change_btn, ] * 6 + + [True] ) text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off @@ -218,6 +230,7 @@ def add_text( disable_btn, ] * 6 + + [False] ) @@ -231,17 +244,49 @@ def bot_response_multi( ): logger.info(f"bot_response_multi (named). 
ip: {get_ip(request)}") - if state0.skip_next: - # This generate call is skipped due to invalid inputs + states = [state0, state1] + if states[0].skip_next: + if ( + states[0].content_moderator.text_flagged + or states[0].content_moderator.nsfw_flagged + ): + for i in range(num_sides): + # This generate call is skipped due to invalid inputs + start_tstamp = time.time() + finish_tstamp = start_tstamp + states[i].conv.save_new_images( + has_csam_images=states[i].has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=states[i].is_vision, + has_csam_image=states[i].has_csam_image, + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + states[i], + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + states[i].conv.messages.pop() + states[i].content_moderator.update_last_moderation_response(None) + yield ( - state0, - state1, - state0.to_gradio_chatbot(), - state1.to_gradio_chatbot(), + states[0], + states[1], + states[0].to_gradio_chatbot(), + states[1].to_gradio_chatbot(), ) + (no_change_btn,) * 6 return - states = [state0, state1] gen = [] for i in range(num_sides): gen.append( @@ -296,14 +341,19 @@ def bot_response_multi( break -def flash_buttons(): +def flash_buttons(show_vote_buttons: bool = True): btn_updates = [ [disable_btn] * 4 + [enable_btn] * 2, [enable_btn] * 6, ] - for i in range(4): - yield btn_updates[i % 2] - time.sleep(0.3) + + if show_vote_buttons: + for i in range(4): + yield btn_updates[i % 2] + time.sleep(0.3) + else: + yield [no_change_btn] * 4 + [enable_btn] * 2 + return def build_side_by_side_ui_named(models): @@ -323,6 +373,7 @@ def build_side_by_side_ui_named(models): states = [gr.State() for _ in range(num_sides)] model_selectors = [None] * num_sides chatbots = [None] * num_sides + dont_show_vote_buttons = gr.State(False) notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") @@ -478,24 +529,24 @@ def build_side_by_side_ui_named(models): textbox.submit( add_text, states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + chatbots + [textbox] + btn_list + [dont_show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [dont_show_vote_buttons], btn_list ) send_btn.click( add_text, states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + chatbots + [textbox] + btn_list + [dont_show_vote_buttons], ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( - flash_buttons, [], btn_list + flash_buttons, [dont_show_vote_buttons], btn_list ) return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision.py b/fastchat/serve/gradio_block_arena_vision.py index 25ff78c08..dadb360e0 100644 --- a/fastchat/serve/gradio_block_arena_vision.py +++ b/fastchat/serve/gradio_block_arena_vision.py @@ -10,6 +10,7 @@ import json import os import time +from typing import List, Union import gradio as gr from gradio.data_classes import FileData @@ -27,6 +28,7 @@ from fastchat.model.model_adapter import ( get_conversation_template, ) +from fastchat.serve.gradio_global_state import Context from fastchat.serve.gradio_web_server import ( get_model_description_md, acknowledgment_md, @@ -37,11 +39,10 @@ get_conv_log_filename, get_remote_logger, ) +from fastchat.serve.moderation.moderator import 
AzureAndOpenAIContentModerator from fastchat.serve.vision.image import ImageFormat, Image from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) logger = build_logger("gradio_web_server", "gradio_web_server.log") @@ -52,8 +53,16 @@ invisible_btn = gr.Button(interactive=False, visible=False) visible_image_column = gr.Image(visible=True) invisible_image_column = gr.Image(visible=False) -enable_multimodal = gr.MultimodalTextbox( - interactive=True, visible=True, placeholder="Enter your prompt or add image here" +enable_multimodal_keep_input = gr.MultimodalTextbox( + interactive=True, + visible=True, + placeholder="Enter your prompt or add image here", +) +enable_multimodal_clear_input = gr.MultimodalTextbox( + interactive=True, + visible=True, + placeholder="Enter your prompt or add image here", + value={"text": "", "files": []}, ) invisible_text = gr.Textbox(visible=False, value="", interactive=False) visible_text = gr.Textbox( @@ -76,10 +85,6 @@ def set_visible_image(textbox): images = textbox["files"] if len(images) == 0: return invisible_image_column - elif len(images) > 1: - gr.Warning( - "We only support single image conversations. Please start a new round if you would like to chat using this image." - ) return visible_image_column @@ -138,6 +143,7 @@ def regenerate(state, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 state.conv.update_last_message(None) + state.content_moderator.update_last_moderation_response(None) return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 @@ -145,14 +151,14 @@ def clear_history(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history. ip: {ip}") state = None - return (state, [], None) + (disable_btn,) * 5 + return (state, [], enable_multimodal_clear_input) + (disable_btn,) * 5 def clear_history_example(request: gr.Request): ip = get_ip(request) logger.info(f"clear_history_example. ip: {ip}") state = None - return (state, [], enable_multimodal) + (disable_btn,) * 5 + return (state, [], enable_multimodal_keep_input) + (disable_btn,) * 5 # TODO(Chris): At some point, we would like this to be a live-reporting feature. @@ -160,13 +166,32 @@ def report_csam_image(state, image): pass -def _prepare_text_with_image(state, text, images, csam_flag): +def _prepare_text_with_image( + state: State, text: str, images: List[Image], context: Context +): if len(images) > 0: - if len(state.conv.get_images()) > 0: - # reset convo with new image - state.conv = get_conversation_template(state.model_name) + model_supports_multi_image = context.api_endpoint_info[state.model_name].get( + "multi_image", False + ) + num_previous_images = len(state.conv.get_images()) + images_interleaved_with_text_exists_but_model_does_not_support = ( + num_previous_images > 0 and not model_supports_multi_image + ) + multiple_image_one_turn_but_model_does_not_support = ( + len(images) > 1 and not model_supports_multi_image + ) + if images_interleaved_with_text_exists_but_model_does_not_support: + gr.Warning( + f"The model does not support interleaved image/text. We only use the very first image." + ) + return text + elif multiple_image_one_turn_but_model_does_not_support: + gr.Warning( + f"The model does not support multiple images. Only the first image will be used." 
+ ) + return text, [images[0]] - text = text, [images[0]] + text = text, images return text @@ -178,45 +203,35 @@ def convert_images_to_conversation_format(images): MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB = 5 / 1.5 conv_images = [] if len(images) > 0: - conv_image = Image(url=images[0]) - conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB) - conv_images.append(conv_image) + for image in images: + conv_image = Image(url=image) + conv_image.to_conversation_format(MAX_NSFW_ENDPOINT_IMAGE_SIZE_IN_MB) + conv_images.append(conv_image) return conv_images -def moderate_input(state, text, all_conv_text, model_list, images, ip): - text_flagged = moderation_filter(all_conv_text, model_list) - # flagged = moderation_filter(text, [state.model_name]) - nsfw_flagged, csam_flagged = False, False - if len(images) > 0: - nsfw_flagged, csam_flagged = image_moderation_filter(images[0]) - - image_flagged = nsfw_flagged or csam_flagged - if text_flagged or image_flagged: - logger.info(f"violate moderation. ip: {ip}. text: {all_conv_text}") - if text_flagged and not image_flagged: - # overwrite the original text - text = TEXT_MODERATION_MSG - elif not text_flagged and image_flagged: - text = IMAGE_MODERATION_MSG - elif text_flagged and image_flagged: - text = MODERATION_MSG - - if csam_flagged: - state.has_csam_image = True - report_csam_image(state, images[0]) - - return text, image_flagged, csam_flagged - +def add_text(state, model_selector, chat_input, context: Context, request: gr.Request): + if isinstance(chat_input, dict): + text, images = chat_input["text"], chat_input["files"] + else: + text, images = chat_input, [] -def add_text(state, model_selector, chat_input, request: gr.Request): - text, images = chat_input["text"], chat_input["files"] + if ( + len(images) > 0 + and model_selector in context.text_models + and model_selector not in context.vision_models + ): + gr.Warning(f"{model_selector} is a text-only model. Image is ignored.") + images = [] ip = get_ip(request) logger.info(f"add_text. ip: {ip}. len: {len(text)}") if state is None: - state = State(model_selector, is_vision=True) + if len(images) == 0: + state = State(model_selector, is_vision=False) + else: + state = State(model_selector, is_vision=True) if len(text) <= 0: state.skip_next = True @@ -227,33 +242,69 @@ def add_text(state, model_selector, chat_input, request: gr.Request): images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state, text, all_conv_text, [state.model_name], images, ip + # Use the first state to get the moderation response because this is based on user input so it is independent of the model + moderation_type_to_response_map = ( + state.content_moderator.image_and_text_moderation_filter( + images, text, [state.model_name], do_moderation=False + ) + ) + + text_flagged, nsfw_flag, csam_flag = ( + moderation_type_to_response_map["text_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if csam_flag: + state.has_csam_image = True + + state.content_moderator.append_moderation_response(moderation_type_to_response_map) + + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. 
text: {text}") + gradio_chatbot_before_user_input = state.to_gradio_chatbot() + post_processed_text = _prepare_text_with_image(state, text, images, context) + state.conv.append_message(state.conv.roles[0], post_processed_text) state.skip_next = True - return (state, state.to_gradio_chatbot(), {"text": IMAGE_MODERATION_MSG}) + ( - no_change_btn, - ) * 5 + gr.Warning(MODERATION_MSG) + return ( + state, + gradio_chatbot_before_user_input, + None, + ) + (no_change_btn,) * 5 if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. text: {text}") state.skip_next = True - return (state, state.to_gradio_chatbot(), {"text": CONVERSATION_LIMIT_MSG}) + ( - no_change_btn, - ) * 5 + return ( + state, + state.to_gradio_chatbot(), + {"text": CONVERSATION_LIMIT_MSG}, + ) + (no_change_btn,) * 5 text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off - text = _prepare_text_with_image(state, text, images, csam_flag=csam_flag) + text = _prepare_text_with_image(state, text, images, context) state.conv.append_message(state.conv.roles[0], text) state.conv.append_message(state.conv.roles[1], None) - return (state, state.to_gradio_chatbot(), None) + (disable_btn,) * 5 + return ( + state, + state.to_gradio_chatbot(), + None, + ) + (disable_btn,) * 5 def build_single_vision_language_model_ui( - models, add_promotion_links=False, random_questions=None + context: Context, add_promotion_links=False, random_questions=None ): promotion = ( f""" @@ -275,33 +326,29 @@ def build_single_vision_language_model_ui( state = gr.State() gr.Markdown(notice_markdown, elem_id="notice_markdown") + text_and_vision_models = list(set(context.text_models + context.vision_models)) + context_state = gr.State(context) with gr.Group(): with gr.Row(elem_id="model_selector_row"): model_selector = gr.Dropdown( - choices=models, - value=models[0] if len(models) > 0 else "", + choices=text_and_vision_models, + value=text_and_vision_models[0] + if len(text_and_vision_models) > 0 + else "", interactive=True, show_label=False, container=False, ) with gr.Accordion( - f"πŸ” Expand to see the descriptions of {len(models)} models", open=False + f"πŸ” Expand to see the descriptions of {len(text_and_vision_models)} models", + open=False, ): - model_description_md = get_model_description_md(models) + model_description_md = get_model_description_md(text_and_vision_models) gr.Markdown(model_description_md, elem_id="model_description_markdown") with gr.Row(): - textbox = gr.MultimodalTextbox( - file_types=["image"], - show_label=False, - placeholder="Enter your prompt or add image here", - container=True, - render=False, - elem_id="input_box", - ) - with gr.Column(scale=2, visible=False) as image_column: imagebox = gr.Image( type="pil", @@ -314,9 +361,13 @@ def build_single_vision_language_model_ui( ) with gr.Row(): - textbox.render() - # with gr.Column(scale=1, min_width=50): - # send_btn = gr.Button(value="Send", variant="primary") + multimodal_textbox = gr.MultimodalTextbox( + file_types=["image"], + show_label=False, + placeholder="Enter your prompt or add image here", + container=True, + elem_id="input_box", + ) with gr.Row(elem_id="buttons"): if random_questions: @@ -330,22 +381,6 @@ def build_single_vision_language_model_ui( regenerate_btn = gr.Button(value="πŸ”„ Regenerate", interactive=False) clear_btn = gr.Button(value="πŸ—‘οΈ Clear", interactive=False) - cur_dir = os.path.dirname(os.path.abspath(__file__)) - - examples = gr.Examples( - examples=[ - { - "text": "How can I prepare 
a delicious meal using these ingredients?", - "files": [f"{cur_dir}/example_images/fridge.jpg"], - }, - { - "text": "What might the woman on the right be thinking about?", - "files": [f"{cur_dir}/example_images/distracted.jpg"], - }, - ], - inputs=[textbox], - ) - with gr.Accordion("Parameters", open=False) as parameter_row: temperature = gr.Slider( minimum=0.0, @@ -380,40 +415,45 @@ def build_single_vision_language_model_ui( upvote_btn.click( upvote_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) downvote_btn.click( downvote_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) flag_btn.click( flag_last_response, [state, model_selector], - [textbox, upvote_btn, downvote_btn, flag_btn], + [multimodal_textbox, upvote_btn, downvote_btn, flag_btn], ) - regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then( + regenerate_btn.click( + regenerate, state, [state, chatbot, multimodal_textbox] + btn_list + ).then( bot_response, [state, temperature, top_p, max_output_tokens], [state, chatbot] + btn_list, ) - clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list) + clear_btn.click( + clear_history, + None, + [state, chatbot, multimodal_textbox] + btn_list, + ) model_selector.change( - clear_history, None, [state, chatbot, textbox] + btn_list - ).then(set_visible_image, [textbox], [image_column]) - examples.dataset.click( - clear_history_example, None, [state, chatbot, textbox] + btn_list - ) + clear_history, + None, + [state, chatbot, multimodal_textbox] + btn_list, + ).then(set_visible_image, [multimodal_textbox], [image_column]) - textbox.input(add_image, [textbox], [imagebox]).then( - set_visible_image, [textbox], [image_column] - ).then(clear_history_example, None, [state, chatbot, textbox] + btn_list) + multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( + set_visible_image, [multimodal_textbox], [image_column] + ) - textbox.submit( + multimodal_textbox.submit( add_text, - [state, model_selector, textbox], - [state, chatbot, textbox] + btn_list, + [state, model_selector, multimodal_textbox, context_state], + [state, chatbot, multimodal_textbox] + btn_list, ).then(set_invisible_image, [], [image_column]).then( bot_response, [state, temperature, top_p, max_output_tokens], @@ -424,9 +464,7 @@ def build_single_vision_language_model_ui( random_btn.click( get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples - [textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [textbox], [image_column]).then( - clear_history_example, None, [state, chatbot, textbox] + btn_list - ) + [multimodal_textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [multimodal_textbox], [image_column]) return [state, model_selector] diff --git a/fastchat/serve/gradio_block_arena_vision_anony.py b/fastchat/serve/gradio_block_arena_vision_anony.py index 76bc47329..d24a54723 100644 --- a/fastchat/serve/gradio_block_arena_vision_anony.py +++ b/fastchat/serve/gradio_block_arena_vision_anony.py @@ -8,6 +8,7 @@ import gradio as gr import numpy as np +from typing import Union from fastchat.constants import ( TEXT_MODERATION_MSG, @@ -34,6 +35,9 @@ get_model_description_md, disable_text, enable_text, + use_remote_storage, + show_vote_button, + dont_show_vote_button, ) from 
fastchat.serve.gradio_block_arena_anony import ( flash_buttons, @@ -45,7 +49,6 @@ regenerate, clear_history, share_click, - add_text, bot_response_multi, set_global_vars_anony, load_demo_side_by_side_anony, @@ -60,19 +63,22 @@ set_invisible_image, set_visible_image, add_image, - moderate_input, - enable_multimodal, + enable_multimodal_keep_input, _prepare_text_with_image, convert_images_to_conversation_format, invisible_text, visible_text, disable_multimodal, + enable_multimodal_clear_input, ) +from fastchat.serve.moderation.moderator import ( + BaseContentModerator, + AzureAndOpenAIContentModerator, +) +from fastchat.serve.gradio_global_state import Context from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log") @@ -96,6 +102,7 @@ "llava-v1.6-34b": 4, "reka-core-20240501": 4, "reka-flash-preview-20240611": 4, + "reka-flash": 4, } # TODO(chris): Find battle targets that make sense @@ -115,16 +122,12 @@ def get_vqa_sample(): return (res, path) -def load_demo_side_by_side_vision_anony(all_text_models, all_vl_models, url_params): - global text_models, vl_models - text_models = all_text_models - vl_models = all_vl_models - - states = (None,) * num_sides - selector_updates = ( +def load_demo_side_by_side_vision_anony(): + states = [None] * num_sides + selector_updates = [ gr.Markdown(visible=True), gr.Markdown(visible=True), - ) + ] return states + selector_updates @@ -135,7 +138,7 @@ def clear_history_example(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal, invisible_text] + + [enable_multimodal_keep_input] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -221,6 +224,7 @@ def regenerate(state0, state1, request: gr.Request): if state0.regen_support and state1.regen_support: for i in range(num_sides): states[i].conv.update_last_message(None) + states[i].content_moderator.update_last_moderation_response(None) return ( states + [x.to_gradio_chatbot() for x in states] @@ -240,7 +244,7 @@ def clear_history(request: gr.Request): [None] * num_sides + [None] * num_sides + anony_names - + [enable_multimodal, invisible_text] + + [enable_multimodal_clear_input] + [invisible_btn] * 4 + [disable_btn] * 2 + [enable_btn] @@ -249,7 +253,13 @@ def clear_history(request: gr.Request): def add_text( - state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request + state0, + state1, + model_selector0, + model_selector1, + chat_input: Union[str, dict], + context: Context, + request: gr.Request, ): if isinstance(chat_input, dict): text, images = chat_input["text"], chat_input["files"] @@ -268,7 +278,7 @@ def add_text( if len(images) > 0: model_left, model_right = get_battle_pair( - vl_models, + context.all_vision_models, VISION_BATTLE_TARGETS, VISION_OUTAGE_MODELS, VISION_SAMPLING_WEIGHTS, @@ -280,7 +290,7 @@ def add_text( ] else: model_left, model_right = get_battle_pair( - text_models, + context.all_text_models, BATTLE_TARGETS, OUTAGE_MODELS, SAMPLING_WEIGHTS, @@ -298,22 +308,49 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [None, ""] + + [None] + [ no_change_btn, ] * 7 + [""] + + [dont_show_vote_button] ) model_list = [states[i].model_name for i in range(num_sides)] images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state0, text, text, model_list, images, ip + # Use the 
first state to get the moderation response because this is based on user input so it is independent of the model + moderation_type_to_response_map = states[ + 0 + ].content_moderator.image_and_text_moderation_filter( + images, text, model_list, do_moderation=True + ) + text_flagged, nsfw_flag, csam_flag = ( + moderation_type_to_response_map["text_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) + if csam_flag: + states[0].has_csam_image, states[1].has_csam_image = True, True + + for state in states: + state.content_moderator.append_moderation_response( + moderation_type_to_response_map + ) + conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}") @@ -322,37 +359,40 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [{"text": CONVERSATION_LIMIT_MSG}, ""] + + [{"text": CONVERSATION_LIMIT_MSG}] + [ no_change_btn, ] * 7 + [""] + + [dont_show_vote_button] ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + # We call this before appending the text so it does not appear in the UI + gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, context + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True + gr.Warning(MODERATION_MSG) return ( states - + [x.to_gradio_chatbot() for x in states] + + gradio_chatbot_list + [ - { - "text": IMAGE_MODERATION_MSG - + " PLEASE CLICK 🎲 NEW ROUND TO START A NEW CONVERSATION." 
- }, - "", + None, ] - + [no_change_btn] * 7 + + [disable_btn] * 7 + [""] + + [dont_show_vote_button] ) text = text[:BLIND_MODE_INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): - post_processed_text = _prepare_text_with_image( - states[i], text, images, csam_flag=csam_flag - ) + post_processed_text = _prepare_text_with_image(states[i], text, images, context) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) states[i].skip_next = False @@ -364,17 +404,18 @@ def add_text( return ( states + [x.to_gradio_chatbot() for x in states] - + [disable_multimodal, visible_text] + + [None] + [ disable_btn, ] * 7 + [hint_msg] + + [show_vote_button] ) -def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions=None): - notice_markdown = f""" +def build_side_by_side_vision_ui_anony(context: Context, random_questions=None): + notice_markdown = """ # βš”οΈ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) | [Kaggle Competition](https://www.kaggle.com/competitions/lmsys-chatbot-arena) @@ -395,8 +436,11 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= states = [gr.State() for _ in range(num_sides)] model_selectors = [None] * num_sides chatbots = [None] * num_sides + show_vote_buttons = gr.State(True) + context_state = gr.State(context) gr.Markdown(notice_markdown, elem_id="notice_markdown") + text_and_vision_models = list(set(context.text_models + context.vision_models)) with gr.Row(): with gr.Column(scale=2, visible=False) as image_column: @@ -409,11 +453,11 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= with gr.Column(scale=5): with gr.Group(elem_id="share-region-anony"): with gr.Accordion( - f"πŸ” Expand to see the descriptions of {len(text_models) + len(vl_models)} models", + f"πŸ” Expand to see the descriptions of {len(text_and_vision_models)} models", open=False, ): model_description_md = get_model_description_md( - text_models + vl_models + text_and_vision_models ) gr.Markdown( model_description_md, elem_id="model_description_markdown" @@ -452,13 +496,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= ) with gr.Row(): - textbox = gr.Textbox( - show_label=False, - placeholder="πŸ‘‰ Enter your prompt and press ENTER", - elem_id="input_box", - visible=False, - ) - multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], show_label=False, @@ -466,7 +503,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= placeholder="Enter your prompt or add image here", elem_id="input_box", ) - # send_btn = gr.Button(value="Send", variant="primary", scale=0) with gr.Row() as button_row: if random_questions: @@ -518,25 +554,29 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= leftvote_btn.click( leftvote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) rightvote_btn.click( rightvote_last_response, states + model_selectors, - 
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) tie_btn.click( tievote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) bothbad_btn.click( bothbad_vote_last_response, states + model_selectors, - model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + model_selectors + + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) regenerate_btn.click( - regenerate, states, states + chatbots + [textbox] + btn_list + regenerate, states, states + chatbots + [multimodal_textbox] + btn_list ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -550,7 +590,7 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= states + chatbots + model_selectors - + [multimodal_textbox, textbox] + + [multimodal_textbox] + btn_list + [random_btn] + [slow_warning], @@ -580,47 +620,25 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( set_visible_image, [multimodal_textbox], [image_column] - ).then( - clear_history_example, - None, - states + chatbots + model_selectors + [multimodal_textbox, textbox] + btn_list, ) multimodal_textbox.submit( add_text, - states + model_selectors + [multimodal_textbox], + states + model_selectors + [multimodal_textbox, context_state], states + chatbots - + [multimodal_textbox, textbox] + + [multimodal_textbox] + btn_list + [random_btn] - + [slow_warning], + + [slow_warning] + + [show_vote_buttons], ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], states + chatbots + btn_list, ).then( flash_buttons, - [], - btn_list, - ) - - textbox.submit( - add_text, - states + model_selectors + [textbox], - states - + chatbots - + [multimodal_textbox, textbox] - + btn_list - + [random_btn] - + [slow_warning], - ).then( - bot_response_multi, - states + [temperature, top_p, max_output_tokens], - states + chatbots + btn_list, - ).then( - flash_buttons, - [], + [show_vote_buttons], btn_list, ) @@ -629,15 +647,6 @@ def build_side_by_side_vision_ui_anony(text_models, vl_models, random_questions= get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples [multimodal_textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [multimodal_textbox], [image_column]).then( - clear_history_example, - None, - states - + chatbots - + model_selectors - + [multimodal_textbox, textbox] - + btn_list - + [random_btn], - ) + ).then(set_visible_image, [multimodal_textbox], [image_column]) return states + model_selectors diff --git a/fastchat/serve/gradio_block_arena_vision_named.py b/fastchat/serve/gradio_block_arena_vision_named.py index ecca169ca..218eb344e 100644 --- a/fastchat/serve/gradio_block_arena_vision_named.py +++ b/fastchat/serve/gradio_block_arena_vision_named.py @@ -6,6 +6,7 @@ import json import os import time +from typing import List, Union import gradio as gr import numpy as np @@ -31,11 +32,20 @@ set_invisible_image, set_visible_image, add_image, - moderate_input, _prepare_text_with_image, convert_images_to_conversation_format, - enable_multimodal, + enable_multimodal_keep_input, + 
enable_multimodal_clear_input, + disable_multimodal, + invisible_text, + invisible_btn, + visible_text, +) +from fastchat.serve.moderation.moderator import ( + BaseContentModerator, + AzureAndOpenAIContentModerator, ) +from fastchat.serve.gradio_global_state import Context from fastchat.serve.gradio_web_server import ( State, bot_response, @@ -52,8 +62,6 @@ from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, - moderation_filter, - image_moderation_filter, ) @@ -63,12 +71,35 @@ enable_moderation = False +def load_demo_side_by_side_vision_named(context: Context): + states = [None] * num_sides + + # default to the text models + models = context.text_models + + model_left = models[0] if len(models) > 0 else "" + if len(models) > 1: + weights = ([8] * 4 + [4] * 8 + [1] * 64)[: len(models) - 1] + weights = weights / np.sum(weights) + model_right = np.random.choice(models[1:], p=weights) + else: + model_right = model_left + + all_models = list(set(context.text_models + context.vision_models)) + selector_updates = [ + gr.Dropdown(choices=all_models, value=model_left, visible=True), + gr.Dropdown(choices=all_models, value=model_right, visible=True), + ] + + return states + selector_updates + + def clear_history_example(request: gr.Request): logger.info(f"clear_history_example (named). ip: {get_ip(request)}") return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal] + + [enable_multimodal_keep_input] + [invisible_btn] * 4 + [disable_btn] * 2 ) @@ -134,6 +165,7 @@ def regenerate(state0, state1, request: gr.Request): if state0.regen_support and state1.regen_support: for i in range(num_sides): states[i].conv.update_last_message(None) + states[i].content_moderator.update_last_moderation_response(None) return ( states + [x.to_gradio_chatbot() for x in states] @@ -152,16 +184,40 @@ def clear_history(request: gr.Request): return ( [None] * num_sides + [None] * num_sides - + [enable_multimodal] + + [enable_multimodal_clear_input] + [invisible_btn] * 4 + [disable_btn] * 2 ) def add_text( - state0, state1, model_selector0, model_selector1, chat_input, request: gr.Request + state0, + state1, + model_selector0, + model_selector1, + chat_input: Union[str, dict], + context: Context, + request: gr.Request, ): - text, images = chat_input["text"], chat_input["files"] + if isinstance(chat_input, dict): + text, images = chat_input["text"], chat_input["files"] + else: + text, images = chat_input, [] + + if len(images) > 0: + if ( + model_selector0 in context.text_models + and model_selector0 not in context.vision_models + ): + gr.Warning(f"{model_selector0} is a text-only model. Image is ignored.") + images = [] + if ( + model_selector1 in context.text_models + and model_selector1 not in context.vision_models + ): + gr.Warning(f"{model_selector1} is a text-only model. Image is ignored.") + images = [] + ip = get_ip(request) logger.info(f"add_text (named). ip: {ip}. 
len: {len(text)}") states = [state0, state1] @@ -169,7 +225,9 @@ def add_text( # Init states if necessary for i in range(num_sides): - if states[i] is None: + if states[i] is None and len(images) == 0: + states[i] = State(model_selectors[i], is_vision=False) + elif states[i] is None and len(images) > 0: states[i] = State(model_selectors[i], is_vision=True) if len(text) <= 0: @@ -194,10 +252,41 @@ def add_text( images = convert_images_to_conversation_format(images) - text, image_flagged, csam_flag = moderate_input( - state0, text, all_conv_text, model_list, images, ip + # Use the first state to get the moderation response because this is based on user input so it is independent of the model + moderation_image_input = images if len(images) > 0 else None + moderation_type_to_response_map = states[ + 0 + ].content_moderator.image_and_text_moderation_filter( + moderation_image_input, text, model_list, do_moderation=False + ) + + text_flagged, nsfw_flag, csam_flag = ( + moderation_type_to_response_map["text_moderation"]["flagged"], + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["nsfw_moderation"] + ] + ), + any( + [ + response["flagged"] + for response in moderation_type_to_response_map["csam_moderation"] + ] + ), ) + if csam_flag: + states[0].has_csam_image, states[1].has_csam_image = True, True + + for state in states: + state.content_moderator.append_moderation_response( + moderation_type_to_response_map + ) + + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. text: {text}") + conv = states[0].conv if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. text: {text}") @@ -213,14 +302,20 @@ def add_text( * 6 ) - if image_flagged: - logger.info(f"image flagged. ip: {ip}. text: {text}") + if text_flagged or nsfw_flag: + logger.info(f"violate moderation. ip: {ip}. 
text: {text}") + gradio_chatbot_list = [x.to_gradio_chatbot() for x in states] for i in range(num_sides): + post_processed_text = _prepare_text_with_image( + states[i], text, images, context + ) + states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].skip_next = True + gr.Warning(MODERATION_MSG) return ( states - + [x.to_gradio_chatbot() for x in states] - + [{"text": IMAGE_MODERATION_MSG}] + + gradio_chatbot_list + + [None] + [ no_change_btn, ] @@ -230,7 +325,10 @@ def add_text( text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off for i in range(num_sides): post_processed_text = _prepare_text_with_image( - states[i], text, images, csam_flag=csam_flag + states[i], + text, + images, + context, ) states[i].conv.append_message(states[i].conv.roles[0], post_processed_text) states[i].conv.append_message(states[i].conv.roles[1], None) @@ -247,8 +345,8 @@ def add_text( ) -def build_side_by_side_vision_ui_named(models, random_questions=None): - notice_markdown = f""" +def build_side_by_side_vision_ui_named(context: Context, random_questions=None): + notice_markdown = """ # βš”οΈ LMSYS Chatbot Arena (Multimodal): Benchmarking LLMs and VLMs in the Wild [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2403.04132) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) @@ -271,6 +369,9 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): notice = gr.Markdown(notice_markdown, elem_id="notice_markdown") + text_and_vision_models = list(set(context.text_models + context.vision_models)) + context_state = gr.State(context) + with gr.Row(): with gr.Column(scale=2, visible=False) as image_column: imagebox = gr.Image( @@ -282,10 +383,12 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): with gr.Column(scale=5): with gr.Group(elem_id="share-region-anony"): with gr.Accordion( - f"πŸ” Expand to see the descriptions of {len(models)} models", + f"πŸ” Expand to see the descriptions of {len(text_and_vision_models)} models", open=False, ): - model_description_md = get_model_description_md(models) + model_description_md = get_model_description_md( + text_and_vision_models + ) gr.Markdown( model_description_md, elem_id="model_description_markdown" ) @@ -294,8 +397,10 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): for i in range(num_sides): with gr.Column(): model_selectors[i] = gr.Dropdown( - choices=models, - value=models[i] if len(models) > i else "", + choices=text_and_vision_models, + value=text_and_vision_models[i] + if len(text_and_vision_models) > i + else "", interactive=True, show_label=False, container=False, @@ -325,7 +430,7 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): ) with gr.Row(): - textbox = gr.MultimodalTextbox( + multimodal_textbox = gr.MultimodalTextbox( file_types=["image"], show_label=False, placeholder="Enter your prompt or add image here", @@ -383,25 +488,25 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): leftvote_btn.click( leftvote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) rightvote_btn.click( rightvote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + 
[multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) tie_btn.click( tievote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) bothbad_btn.click( bothbad_vote_last_response, states + model_selectors, - [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], + [multimodal_textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn], ) regenerate_btn.click( - regenerate, states, states + chatbots + [textbox] + btn_list + regenerate, states, states + chatbots + [multimodal_textbox] + btn_list ).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -409,7 +514,11 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): ).then( flash_buttons, [], btn_list ) - clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list) + clear_btn.click( + clear_history, + None, + states + chatbots + [multimodal_textbox] + btn_list, + ) share_js = """ function (a, b, c, d) { @@ -435,17 +544,19 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): for i in range(num_sides): model_selectors[i].change( - clear_history, None, states + chatbots + [textbox] + btn_list - ).then(set_visible_image, [textbox], [image_column]) + clear_history, + None, + states + chatbots + [multimodal_textbox] + btn_list, + ).then(set_visible_image, [multimodal_textbox], [image_column]) - textbox.input(add_image, [textbox], [imagebox]).then( - set_visible_image, [textbox], [image_column] - ).then(clear_history_example, None, states + chatbots + [textbox] + btn_list) + multimodal_textbox.input(add_image, [multimodal_textbox], [imagebox]).then( + set_visible_image, [multimodal_textbox], [image_column] + ) - textbox.submit( + multimodal_textbox.submit( add_text, - states + model_selectors + [textbox], - states + chatbots + [textbox] + btn_list, + states + model_selectors + [multimodal_textbox, context_state], + states + chatbots + [multimodal_textbox] + btn_list, ).then(set_invisible_image, [], [image_column]).then( bot_response_multi, states + [temperature, top_p, max_output_tokens], @@ -458,9 +569,7 @@ def build_side_by_side_vision_ui_named(models, random_questions=None): random_btn.click( get_vqa_sample, # First, get the VQA sample [], # Pass the path to the VQA samples - [textbox, imagebox], # Outputs are textbox and imagebox - ).then(set_visible_image, [textbox], [image_column]).then( - clear_history_example, None, states + chatbots + [textbox] + btn_list - ) + [multimodal_textbox, imagebox], # Outputs are textbox and imagebox + ).then(set_visible_image, [multimodal_textbox], [image_column]) return states + model_selectors diff --git a/fastchat/serve/gradio_global_state.py b/fastchat/serve/gradio_global_state.py new file mode 100644 index 000000000..de911985d --- /dev/null +++ b/fastchat/serve/gradio_global_state.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass, field +from typing import List + + +@dataclass +class Context: + text_models: List[str] = field(default_factory=list) + all_text_models: List[str] = field(default_factory=list) + vision_models: List[str] = field(default_factory=list) + all_vision_models: List[str] = field(default_factory=list) + api_endpoint_info: dict = field(default_factory=dict) diff --git a/fastchat/serve/gradio_web_server.py b/fastchat/serve/gradio_web_server.py index 2ef47b14d..46ff0c421 100644 --- a/fastchat/serve/gradio_web_server.py +++ 
b/fastchat/serve/gradio_web_server.py @@ -11,6 +11,7 @@ import random import time import uuid +from typing import List import gradio as gr import requests @@ -33,12 +34,13 @@ ) from fastchat.model.model_registry import get_model_info, model_info from fastchat.serve.api_provider import get_api_provider_stream_iter +from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator +from fastchat.serve.gradio_global_state import Context from fastchat.serve.remote_logger import get_remote_logger from fastchat.utils import ( build_logger, get_window_url_params_js, get_window_url_params_with_tos_js, - moderation_filter, parse_gradio_auth_creds, load_image, ) @@ -59,6 +61,8 @@ visible=True, placeholder='Press "🎲 New Round" to start overπŸ‘‡ (Note: Your vote shapes the leaderboard, please vote RESPONSIBLY!)', ) +show_vote_button = True +dont_show_vote_button = False controller_url = None enable_moderation = False @@ -117,6 +121,7 @@ def __init__(self, model_name, is_vision=False): self.model_name = model_name self.oai_thread_id = None self.is_vision = is_vision + self.content_moderator = AzureAndOpenAIContentModerator() # NOTE(chris): This could be sort of a hack since it assumes the user only uploads one image. If they can upload multiple, we should store a list of image hashes. self.has_csam_image = False @@ -143,6 +148,7 @@ def dict(self): { "conv_id": self.conv_id, "model_name": self.model_name, + "moderation": self.content_moderator.conv_moderation_responses, } ) @@ -151,7 +157,11 @@ def dict(self): return base -def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_): +def set_global_vars( + controller_url_, + enable_moderation_, + use_remote_storage_, +): global controller_url, enable_moderation, use_remote_storage controller_url = controller_url_ enable_moderation = enable_moderation_ @@ -218,16 +228,27 @@ def get_model_list(controller_url, register_api_endpoint_file, vision_arena): return visible_models, models -def load_demo_single(models, url_params): +def _get_api_endpoint_info(): + global api_endpoint_info + return api_endpoint_info + + +def load_demo_single(context: Context, query_params): + # default to text models + models = context.text_models + selected_model = models[0] if len(models) > 0 else "" - if "model" in url_params: - model = url_params["model"] + if "model" in query_params: + model = query_params["model"] if model in models: selected_model = model - dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True) + all_models = list(set(context.text_models + context.vision_models)) + dropdown_update = gr.Dropdown( + choices=all_models, value=selected_model, visible=True + ) state = None - return state, dropdown_update + return [state, dropdown_update] def load_demo(url_params, request: gr.Request): @@ -289,6 +310,7 @@ def regenerate(state, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 state.conv.update_last_message(None) + state.content_moderator.update_last_moderation_response(None) return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5 @@ -322,14 +344,24 @@ def add_text(state, model_selector, text, request: gr.Request): state.skip_next = True return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 - all_conv_text = state.conv.get_prompt() - all_conv_text = all_conv_text[-2000:] + "\nuser: " + text - flagged = moderation_filter(all_conv_text, [state.model_name]) - # flagged = moderation_filter(text, 
[state.model_name]) - if flagged: + content_moderator = AzureAndOpenAIContentModerator() + text_flagged = content_moderator.text_moderation_filter(text, [state.model_name]) + + if text_flagged: logger.info(f"violate moderation. ip: {ip}. text: {text}") # overwrite the original text - text = MODERATION_MSG + content_moderator.write_to_json(get_ip(request)) + state.skip_next = True + gr.Warning(MODERATION_MSG) + return ( + [state] + + [state.to_gradio_chatbot()] + + [""] + + [ + no_change_btn, + ] + * 5 + ) if (len(state.conv.messages) - state.conv.offset) // 2 >= CONVERSATION_TURN_LIMIT: logger.info(f"conversation turn limit. ip: {ip}. text: {text}") @@ -400,6 +432,34 @@ def is_limit_reached(model_name, ip): return None +def _write_to_json( + filename: str, + start_tstamp: float, + finish_tstamp: float, + state: State, + temperature: float, + top_p: float, + max_new_tokens: int, + request: gr.Request, +): + with open(filename, "a") as fout: + data = { + "tstamp": round(finish_tstamp, 4), + "type": "chat", + "model": state.model_name, + "gen_params": { + "temperature": temperature, + "top_p": top_p, + "max_new_tokens": max_new_tokens, + }, + "start": round(start_tstamp, 4), + "finish": round(finish_tstamp, 4), + "state": state.dict(), + "ip": get_ip(request), + } + fout.write(json.dumps(data) + "\n") + + def bot_response( state, temperature, @@ -419,6 +479,32 @@ def bot_response( if state.skip_next: # This generate call is skipped due to invalid inputs state.skip_next = False + if state.content_moderator.text_flagged or state.content_moderator.nsfw_flagged: + start_tstamp = time.time() + finish_tstamp = start_tstamp + state.conv.save_new_images( + has_csam_images=state.has_csam_image, + use_remote_storage=use_remote_storage, + ) + + filename = get_conv_log_filename( + is_vision=state.is_vision, has_csam_image=state.has_csam_image + ) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + state, + temperature, + top_p, + max_new_tokens, + request, + ) + + # Remove the last message: the user input + state.conv.messages.pop() + state.content_moderator.update_last_moderation_response(None) yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 return @@ -570,22 +656,23 @@ def bot_response( is_vision=state.is_vision, has_csam_image=state.has_csam_image ) - with open(filename, "a") as fout: - data = { - "tstamp": round(finish_tstamp, 4), - "type": "chat", - "model": model_name, - "gen_params": { - "temperature": temperature, - "top_p": top_p, - "max_new_tokens": max_new_tokens, - }, - "start": round(start_tstamp, 4), - "finish": round(finish_tstamp, 4), - "state": state.dict(), - "ip": get_ip(request), - } - fout.write(json.dumps(data) + "\n") + moderation_type_to_response_map = ( + state.content_moderator.image_and_text_moderation_filter( + None, output, [state.model_name], do_moderation=True + ) + ) + state.content_moderator.append_moderation_response(moderation_type_to_response_map) + + _write_to_json( + filename, + start_tstamp, + finish_tstamp, + state, + temperature, + top_p, + max_new_tokens, + request, + ) get_remote_logger().log(data) diff --git a/fastchat/serve/gradio_web_server_multi.py b/fastchat/serve/gradio_web_server_multi.py index 14f254bf3..f19fab6ec 100644 --- a/fastchat/serve/gradio_web_server_multi.py +++ b/fastchat/serve/gradio_web_server_multi.py @@ -6,6 +6,7 @@ import argparse import pickle import time +from typing import List import gradio as gr @@ -28,7 +29,9 @@ ) from fastchat.serve.gradio_block_arena_vision_named import ( 
     build_side_by_side_vision_ui_named,
+    load_demo_side_by_side_vision_named,
 )
+from fastchat.serve.gradio_global_state import Context
 
 from fastchat.serve.gradio_web_server import (
     set_global_vars,
@@ -38,6 +41,7 @@
     get_model_list,
     load_demo_single,
     get_ip,
+    _get_api_endpoint_info,
 )
 from fastchat.serve.monitor.monitor import build_leaderboard_tab
 from fastchat.utils import (
@@ -51,57 +55,66 @@ logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
 
 
-def load_demo(url_params, request: gr.Request):
-    global models, all_models, vl_models, all_vl_models
-
+def load_demo(context: Context, url_params, request: gr.Request):
     ip = get_ip(request)
     logger.info(f"load_demo. ip: {ip}. params: {url_params}")
 
     inner_selected = 0
-    if "arena" in url_params:
+    if "arena" in request.query_params:
         inner_selected = 0
-    elif "vision" in url_params:
-        inner_selected = 1
-    elif "compare" in url_params:
+    elif "vision" in request.query_params:
+        inner_selected = 0
+    elif "compare" in request.query_params:
         inner_selected = 1
-    elif "direct" in url_params or "model" in url_params:
+    elif "direct" in request.query_params or "model" in request.query_params:
+        inner_selected = 2
+    elif "leaderboard" in request.query_params:
         inner_selected = 3
-    elif "leaderboard" in url_params:
+    elif "about" in request.query_params:
         inner_selected = 4
-    elif "about" in url_params:
-        inner_selected = 5
 
     if args.model_list_mode == "reload":
-        models, all_models = get_model_list(
+        context.text_models, context.all_text_models = get_model_list(
             args.controller_url,
             args.register_api_endpoint_file,
             vision_arena=False,
         )
-        vl_models, all_vl_models = get_model_list(
+        context.vision_models, context.all_vision_models = get_model_list(
             args.controller_url,
             args.register_api_endpoint_file,
             vision_arena=True,
         )
 
-    single_updates = load_demo_single(models, url_params)
-    side_by_side_anony_updates = load_demo_side_by_side_anony(all_models, url_params)
-    side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
+    # Text models
+    if args.vision_arena:
+        side_by_side_anony_updates = load_demo_side_by_side_vision_anony()
 
-    side_by_side_vision_anony_updates = load_demo_side_by_side_vision_anony(
-        all_models, all_vl_models, url_params
-    )
+        side_by_side_named_updates = load_demo_side_by_side_vision_named(
+            context,
+        )
 
-    return (
-        (gr.Tabs(selected=inner_selected),)
-        + single_updates
+        direct_chat_updates = load_demo_single(context, request.query_params)
+    else:
+        direct_chat_updates = load_demo_single(context, request.query_params)
+        side_by_side_anony_updates = load_demo_side_by_side_anony(
+            context.all_text_models, request.query_params
+        )
+        side_by_side_named_updates = load_demo_side_by_side_named(
+            context.text_models, request.query_params
+        )
+
+    tabs_list = (
+        [gr.Tabs(selected=inner_selected)]
         + side_by_side_anony_updates
         + side_by_side_named_updates
-        + side_by_side_vision_anony_updates
+        + direct_chat_updates
     )
+    return tabs_list
 
 
-def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
+
+def build_demo(context: Context, elo_results_file: str, leaderboard_table_file):
     if args.show_terms_of_use:
         load_js = get_window_url_params_with_tos_js
     else:
@@ -134,41 +147,62 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
             with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
                 arena_tab.select(None, None, None, js=load_js)
                 side_by_side_anony_list = build_side_by_side_vision_ui_anony(
-                    all_models,
-                    all_vl_models,
+                    context,
                     random_questions=args.random_questions,
                 )
 
+            with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
+                side_by_side_tab.select(None, None, None, js=alert_js)
+                side_by_side_named_list = build_side_by_side_vision_ui_named(
+                    context, random_questions=args.random_questions
+                )
+
+            with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
+                direct_tab.select(None, None, None, js=alert_js)
+                single_model_list = build_single_vision_language_model_ui(
+                    context,
+                    add_promotion_links=True,
+                    random_questions=args.random_questions,
+                )
+
         else:
             with gr.Tab("⚔️ Arena (battle)", id=0) as arena_tab:
                 arena_tab.select(None, None, None, js=load_js)
-                side_by_side_anony_list = build_side_by_side_ui_anony(models)
+                side_by_side_anony_list = build_side_by_side_ui_anony(
+                    context.all_text_models
+                )
 
-            with gr.Tab("⚔️ Arena (side-by-side)", id=2) as side_by_side_tab:
-                side_by_side_tab.select(None, None, None, js=alert_js)
-                side_by_side_named_list = build_side_by_side_ui_named(models)
+            with gr.Tab("⚔️ Arena (side-by-side)", id=1) as side_by_side_tab:
+                side_by_side_tab.select(None, None, None, js=alert_js)
+                side_by_side_named_list = build_side_by_side_ui_named(
+                    context.text_models
+                )
 
-            with gr.Tab("💬 Direct Chat", id=3) as direct_tab:
-                direct_tab.select(None, None, None, js=alert_js)
-                single_model_list = build_single_model_ui(
-                    models, add_promotion_links=True
-                )
+            with gr.Tab("💬 Direct Chat", id=2) as direct_tab:
+                direct_tab.select(None, None, None, js=alert_js)
+                single_model_list = build_single_model_ui(
+                    context.text_models, add_promotion_links=True
+                )
 
         demo_tabs = (
             [inner_tabs]
-            + single_model_list
             + side_by_side_anony_list
             + side_by_side_named_list
+            + single_model_list
         )
 
         if elo_results_file:
-            with gr.Tab("🏆 Leaderboard", id=4):
+            with gr.Tab("🏆 Leaderboard", id=3):
                 build_leaderboard_tab(
-                    elo_results_file, leaderboard_table_file, show_plot=True
+                    elo_results_file,
+                    leaderboard_table_file,
+                    arena_hard_leaderboard=None,
+                    show_plot=True,
                 )
 
-        with gr.Tab("ℹ️ About Us", id=5):
+        with gr.Tab("ℹ️ About Us", id=4):
             about = build_about()
 
+        context_state = gr.State(context)
         url_params = gr.JSON(visible=False)
 
         if args.model_list_mode not in ["once", "reload"]:
@@ -176,7 +210,7 @@
         demo.load(
             load_demo,
-            [url_params],
+            [context_state, url_params],
             demo_tabs,
             js=load_js,
         )
@@ -272,19 +306,28 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
     # Set global variables
     set_global_vars(args.controller_url, args.moderate, args.use_remote_storage)
-    set_global_vars_named(args.moderate)
-    set_global_vars_anony(args.moderate)
-    models, all_models = get_model_list(
+    set_global_vars_named(args.moderate, args.use_remote_storage)
+    set_global_vars_anony(args.moderate, args.use_remote_storage)
+    text_models, all_text_models = get_model_list(
         args.controller_url,
         args.register_api_endpoint_file,
         vision_arena=False,
     )
-    vl_models, all_vl_models = get_model_list(
+    vision_models, all_vision_models = get_model_list(
         args.controller_url,
         args.register_api_endpoint_file,
         vision_arena=True,
     )
+    api_endpoint_info = _get_api_endpoint_info()
+
+    context = Context(
+        text_models,
+        all_text_models,
+        vision_models,
+        all_vision_models,
+        api_endpoint_info,
+    )
 
     # Set authorization credentials
     auth = None
@@ -293,8 +336,7 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
 
     # Launch the demo
     demo = build_demo(
-        models,
-        all_vl_models,
+        context,
         args.elo_results_file,
         args.leaderboard_table_file,
     )
diff --git a/fastchat/serve/moderation/moderator.py b/fastchat/serve/moderation/moderator.py
new file mode 100644
index 000000000..efafdcd88
--- /dev/null
+++ b/fastchat/serve/moderation/moderator.py
@@ -0,0 +1,268 @@
+import datetime
+import hashlib
+import os
+import json
+import time
+import base64
+import requests
+from typing import Tuple, Dict, List, Union
+
+from fastchat.constants import LOGDIR
+from fastchat.serve.vision.image import Image
+from fastchat.utils import load_image, upload_image_file_to_gcs
+
+
+class BaseContentModerator:
+    def __init__(self):
+        self.conv_moderation_responses: List[
+            Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+        ] = []
+
+    def _image_moderation_filter(self, image: Image) -> Tuple[bool, bool]:
+        """Function that detects whether image violates moderation policies.
+
+        Returns:
+            Tuple[bool, bool]: A tuple of two boolean values indicating whether the image was flagged for nsfw and csam respectively.
+        """
+        raise NotImplementedError
+
+    def _text_moderation_filter(self, text: str) -> bool:
+        """Function that detects whether text violates moderation policies.
+
+        Returns:
+            bool: A boolean value indicating whether the text was flagged.
+        """
+        raise NotImplementedError
+
+    def image_and_text_moderation_filter(
+        self, images: List[Image], text: str
+    ) -> Dict[str, Dict[str, Union[str, Dict[str, float]]]]:
+        """Function that detects whether image and text violate moderation policies.
+
+        Returns:
+            Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response.
+        """
+        raise NotImplementedError
+
+    def append_moderation_response(
+        self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+    ):
+        """Function that appends the moderation response to the list of moderation responses."""
+        if (
+            len(self.conv_moderation_responses) == 0
+            or self.conv_moderation_responses[-1] is not None
+        ):
+            self.conv_moderation_responses.append(moderation_response)
+        else:
+            self.update_last_moderation_response(moderation_response)
+
+    def update_last_moderation_response(
+        self, moderation_response: Dict[str, Dict[str, Union[str, Dict[str, float]]]]
+    ):
+        """Function that updates the last moderation response."""
+        self.conv_moderation_responses[-1] = moderation_response
+
+
+class AzureAndOpenAIContentModerator(BaseContentModerator):
+    _NON_TOXIC_IMAGE_MODERATION_MAP = {
+        "nsfw_moderation": [{"flagged": False}],
+        "csam_moderation": [{"flagged": False}],
+    }
+
+    def __init__(self, use_remote_storage: bool = False):
+        """This class is used to moderate content using Azure and OpenAI.
+
+        conv_to_moderation_responses: A dictionary that is a map from the type of moderation
+        (text, nsfw, csam) moderation to the moderation response returned from the request sent
+        to the moderation API.
+ """ + super().__init__() + self.text_flagged = False + self.csam_flagged = False + self.nsfw_flagged = False + + def _image_moderation_request( + self, image_bytes: bytes, endpoint: str, api_key: str + ) -> dict: + headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key} + + MAX_RETRIES = 3 + for _ in range(MAX_RETRIES): + response = requests.post(endpoint, headers=headers, data=image_bytes).json() + try: + if response["Status"]["Code"] == 3000: + break + except: + time.sleep(0.5) + return response + + def _image_moderation_provider(self, image_bytes: bytes, api_type: str) -> bool: + if api_type == "nsfw": + endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"] + api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"] + response = self._image_moderation_request(image_bytes, endpoint, api_key) + flagged = response["IsImageAdultClassified"] + elif api_type == "csam": + endpoint = ( + "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false" + ) + api_key = os.environ["PHOTODNA_API_KEY"] + response = self._image_moderation_request(image_bytes, endpoint, api_key) + flagged = response["IsMatch"] + + image_md5_hash = hashlib.md5(image_bytes).hexdigest() + moderation_response_map = { + "image_hash": image_md5_hash, + "response": response, + "flagged": False, + } + if flagged: + moderation_response_map["flagged"] = True + + return moderation_response_map + + def image_moderation_filter(self, images: List[Image]): + print(f"moderating images") + + images_moderation_response: Dict[ + str, List[Dict[str, Union[str, Dict[str, float]]]] + ] = { + "nsfw_moderation": [], + "csam_moderation": [], + } + + for image in images: + image_bytes = base64.b64decode(image.base64_str) + + nsfw_flagged_map = self._image_moderation_provider(image_bytes, "nsfw") + + if nsfw_flagged_map["flagged"]: + csam_flagged_map = self._image_moderation_provider(image_bytes, "csam") + else: + csam_flagged_map = {"flagged": False} + + self.nsfw_flagged |= nsfw_flagged_map["flagged"] + self.csam_flagged |= csam_flagged_map["flagged"] + + images_moderation_response["nsfw_moderation"].append(nsfw_flagged_map) + images_moderation_response["csam_moderation"].append(csam_flagged_map) + + return images_moderation_response + + def _openai_moderation_filter( + self, text: str, custom_thresholds: dict = None + ) -> bool: + """ + Check whether the text violates OpenAI moderation API. 
+ """ + import openai + + client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) + + # default to true to be conservative + flagged = True + MAX_RETRY = 3 + moderation_response_map = {"content": text, "response": None, "flagged": False} + for _ in range(MAX_RETRY): + try: + res = client.moderations.create(input=text) + flagged = res.results[0].flagged + if custom_thresholds is not None: + for category, threshold in custom_thresholds.items(): + if ( + getattr(res.results[0].category_scores, category) + > threshold + ): + flagged = True + moderation_response_map = { + "response": dict(res.results[0].category_scores), + "flagged": flagged, + } + break + except (openai.OpenAIError, KeyError, IndexError) as e: + print(f"MODERATION ERROR: {e}\nInput: {text}") + + return moderation_response_map + + def text_moderation_filter( + self, text: str, model_list: List[str], do_moderation: bool = False + ): + # Apply moderation for below models + MODEL_KEYWORDS = [ + "claude", + "gpt", + "bard", + "mistral-large", + "command-r", + "dbrx", + "gemini", + "reka", + ] + + custom_thresholds = {"sexual": 0.3} + # set a stricter threshold for claude + for model in model_list: + if "claude" in model: + custom_thresholds = {"sexual": 0.2} + + for keyword in MODEL_KEYWORDS: + for model in model_list: + if keyword in model: + do_moderation = True + break + + moderation_response_map = {"flagged": False} + if do_moderation: + moderation_response_map = self._openai_moderation_filter( + text, custom_thresholds + ) + self.text_flagged = moderation_response_map["flagged"] + + return {"text_moderation": moderation_response_map} + + def image_and_text_moderation_filter( + self, images: List[Image], text: str, model_list: List[str], do_moderation=True + ) -> Dict[str, bool]: + """Function that detects whether image and text violate moderation policies using the Azure and OpenAI moderation APIs. + + Returns: + Dict[str, Dict[str, Union[str, Dict[str, float]]]]: A dictionary that maps the type of moderation (text, nsfw, csam) to a dictionary that contains the moderation response. 
+
+        Example:
+        {
+            "text_moderation": {
+                "content": "This is a test",
+                "response": {
+                    "sexual": 0.1
+                },
+                "flagged": True
+            },
+            "nsfw_moderation": {
+                "image_hash": "1234567890",
+                "response": {
+                    "IsImageAdultClassified": True
+                },
+                "flagged": True
+            },
+            "csam_moderation": {
+                "image_hash": "1234567890",
+                "response": {
+                    "IsMatch": True
+                },
+                "flagged": True
+            }
+        }
+        """
+        print("moderating text: ", text)
+        text_flagged_map = self.text_moderation_filter(text, model_list, do_moderation)
+
+        if images:
+            image_flagged_map = self.image_moderation_filter(images)
+        else:
+            image_flagged_map = self._NON_TOXIC_IMAGE_MODERATION_MAP
+
+        res = {}
+        res.update(text_flagged_map)
+        res.update(image_flagged_map)
+
+        return res
diff --git a/fastchat/serve/monitor/monitor.py b/fastchat/serve/monitor/monitor.py
index f5842c7c4..85ad2b647 100644
--- a/fastchat/serve/monitor/monitor.py
+++ b/fastchat/serve/monitor/monitor.py
@@ -455,7 +455,7 @@ def update_leaderboard_and_plots(category):
                 "Knowledge Cutoff",
             ],
             datatype=[
-                "str",
+                "number",
                 "markdown",
                 "number",
                 "str",
@@ -534,7 +534,7 @@ def update_leaderboard_and_plots(category):
                 "Knowledge Cutoff",
             ],
            datatype=[
-                "str",
+                "number",
                 "markdown",
                 "number",
                 "str",
@@ -562,9 +562,7 @@ def update_leaderboard_and_plots(category):
             elem_id="leaderboard_markdown",
         )
 
-        if not vision:
-            # only live update the text tab
-            leader_component_values[:] = [default_md, p1, p2, p3, p4]
+        leader_component_values[:] = [default_md, p1, p2, p3, p4]
 
         if show_plot:
             more_stats_md = gr.Markdown(
@@ -777,6 +775,7 @@ def build_leaderboard_tab(
     elo_results_file,
     leaderboard_table_file,
     arena_hard_leaderboard,
+    vision=True,
     show_plot=False,
     mirror=False,
 ):
@@ -855,14 +854,18 @@ def build_leaderboard_tab(
                 default_md,
                 show_plot=show_plot,
             )
-        with gr.Tab("Arena (Vision)", id=2):
-            build_arena_tab(
-                elo_results_vision,
-                model_table_df,
-                default_md,
-                vision=True,
-                show_plot=show_plot,
-            )
+
+        if vision:
+            with gr.Tab("Arena (Vision)", id=2):
+                build_arena_tab(
+                    elo_results_vision,
+                    model_table_df,
+                    default_md,
+                    vision=True,
+                    show_plot=show_plot,
+                )
+
+    model_to_score = {}
 
     if arena_hard_leaderboard is not None:
         with gr.Tab("Arena-Hard-Auto", id=3):
             dataFrame = arena_hard_process(
@@ -882,7 +885,6 @@ def build_leaderboard_tab(
                     "avg_tokens": "Average Tokens",
                 }
             )
-            model_to_score = {}
             for i in range(len(dataFrame)):
                 model_to_score[dataFrame.loc[i, "Model"]] = dataFrame.loc[
                     i, "Win-rate"
diff --git a/fastchat/utils.py b/fastchat/utils.py
index 545e01414..4ec8249e1 100644
--- a/fastchat/utils.py
+++ b/fastchat/utils.py
@@ -149,61 +149,6 @@ def get_gpu_memory(max_gpus=None):
     return gpu_memory
 
 
-def oai_moderation(text, custom_thresholds=None):
-    """
-    Check whether the text violates OpenAI moderation API.
- """ - import openai - - client = openai.OpenAI(api_key=os.environ["OPENAI_API_KEY"]) - - # default to true to be conservative - flagged = True - MAX_RETRY = 3 - for _ in range(MAX_RETRY): - try: - res = client.moderations.create(input=text) - flagged = res.results[0].flagged - if custom_thresholds is not None: - for category, threshold in custom_thresholds.items(): - if getattr(res.results[0].category_scores, category) > threshold: - flagged = True - break - except (openai.OpenAIError, KeyError, IndexError) as e: - print(f"MODERATION ERROR: {e}\nInput: {text}") - return flagged - - -def moderation_filter(text, model_list, do_moderation=False): - # Apply moderation for below models - MODEL_KEYWORDS = [ - "claude", - "gpt", - "bard", - "mistral-large", - "command-r", - "dbrx", - "gemini", - "reka", - ] - - custom_thresholds = {"sexual": 0.3} - # set a stricter threshold for claude - for model in model_list: - if "claude" in model: - custom_thresholds = {"sexual": 0.2} - - for keyword in MODEL_KEYWORDS: - for model in model_list: - if keyword in model: - do_moderation = True - break - - if do_moderation: - return oai_moderation(text, custom_thresholds) - return False - - def clean_flant5_ckpt(ckpt_path): """ Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, @@ -438,47 +383,3 @@ def get_image_file_from_gcs(filename): contents = blob.download_as_bytes() return contents - - -def image_moderation_request(image_bytes, endpoint, api_key): - headers = {"Content-Type": "image/jpeg", "Ocp-Apim-Subscription-Key": api_key} - - MAX_RETRIES = 3 - for _ in range(MAX_RETRIES): - response = requests.post(endpoint, headers=headers, data=image_bytes).json() - try: - if response["Status"]["Code"] == 3000: - break - except: - time.sleep(0.5) - return response - - -def image_moderation_provider(image, api_type): - if api_type == "nsfw": - endpoint = os.environ["AZURE_IMG_MODERATION_ENDPOINT"] - api_key = os.environ["AZURE_IMG_MODERATION_API_KEY"] - response = image_moderation_request(image, endpoint, api_key) - print(response) - return response["IsImageAdultClassified"] - elif api_type == "csam": - endpoint = ( - "https://api.microsoftmoderator.com/photodna/v1.0/Match?enhance=false" - ) - api_key = os.environ["PHOTODNA_API_KEY"] - response = image_moderation_request(image, endpoint, api_key) - return response["IsMatch"] - - -def image_moderation_filter(image): - print(f"moderating image") - - image_bytes = base64.b64decode(image.base64_str) - - nsfw_flagged = image_moderation_provider(image_bytes, "nsfw") - csam_flagged = False - - if nsfw_flagged: - csam_flagged = image_moderation_provider(image_bytes, "csam") - - return nsfw_flagged, csam_flagged diff --git a/playground/__init__.py b/playground/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/playground/benchmark/benchmark_api_provider.py b/playground/benchmark/benchmark_api_provider.py new file mode 100644 index 000000000..89ca02ece --- /dev/null +++ b/playground/benchmark/benchmark_api_provider.py @@ -0,0 +1,135 @@ +""" +Usage: +python3 -m playground.benchmark.benchmark_api_provider --api-endpoint-file api_endpoints.json --output-file ./benchmark_results.json --random-questions metadata_sampled.json +""" +import argparse +import json +import time + +import numpy as np + +from fastchat.serve.api_provider import get_api_provider_stream_iter +from fastchat.serve.gradio_web_server import State +from fastchat.serve.vision.image import Image + + +class Metrics: + def __init__(self): + self.ttft = None + 
+        self.avg_token_time = None
+
+    def to_dict(self):
+        return {"ttft": self.ttft, "avg_token_time": self.avg_token_time}
+
+
+def sample_image_and_question(random_questions_dict, index):
+    # message = np.random.choice(random_questions_dict)
+    message = random_questions_dict[index]
+    question = message["question"]
+    path = message["path"]
+
+    if isinstance(question, list):
+        question = question[0]
+
+    return (question, path)
+
+
+def call_model(
+    conv,
+    model_name,
+    model_api_dict,
+    state,
+    temperature=0.4,
+    top_p=0.9,
+    max_new_tokens=2048,
+):
+    prev_message = ""
+    prev_time = time.time()
+    CHARACTERS_PER_TOKEN = 4
+    metrics = Metrics()
+
+    stream_iter = get_api_provider_stream_iter(
+        conv, model_name, model_api_dict, temperature, top_p, max_new_tokens, state
+    )
+    call_time = time.time()
+    token_times = []
+    for i, data in enumerate(stream_iter):
+        output = data["text"].strip()
+        if i == 0:
+            metrics.ttft = time.time() - call_time
+            prev_message = output
+            prev_time = time.time()
+        else:
+            token_diff_length = (len(output) - len(prev_message)) / CHARACTERS_PER_TOKEN
+            if token_diff_length == 0:
+                continue
+
+            token_diff_time = time.time() - prev_time
+            token_time = token_diff_time / token_diff_length
+            token_times.append(token_time)
+            prev_time = time.time()
+
+    metrics.avg_token_time = np.mean(token_times)
+    return metrics
+
+
+def run_benchmark(model_name, model_api_dict, random_questions_dict, num_calls=20):
+    model_results = []
+
+    for index in range(num_calls):
+        state = State(model_name)
+        text, image_path = sample_image_and_question(random_questions_dict, index)
+        max_image_size_mb = 5 / 1.5
+
+        images = [
+            Image(url=image_path).to_conversation_format(
+                max_image_size_mb=max_image_size_mb
+            )
+        ]
+        message = (text, images)
+
+        state.conv.append_message(state.conv.roles[0], message)
+        state.conv.append_message(state.conv.roles[1], None)
+
+        metrics = call_model(state.conv, model_name, model_api_dict, state)
+        model_results.append(metrics.to_dict())
+
+    return model_results
+
+
+def benchmark_models(api_endpoint_info, random_questions_dict, models):
+    results = {model_name: [] for model_name in models}
+
+    for model_name in models:
+        model_results = run_benchmark(
+            model_name,
+            api_endpoint_info[model_name],
+            random_questions_dict,
+            num_calls=20,
+        )
+        results[model_name] = model_results
+
+    print(results)
+    return results
+
+
+def main(api_endpoint_file, random_questions, output_file):
+    api_endpoint_info = json.load(open(api_endpoint_file))
+    random_questions_dict = json.load(open(random_questions))
+    models = ["reka-core-20240501", "gpt-4o-2024-05-13"]
+
+    models_results = benchmark_models(api_endpoint_info, random_questions_dict, models)
+
+    with open(output_file, "w") as f:
+        json.dump(models_results, f)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--api-endpoint-file", required=True)
+    parser.add_argument("--random-questions", required=True)
+    parser.add_argument("--output-file", required=True)
+
+    args = parser.parse_args()
+
+    main(args.api_endpoint_file, args.random_questions, args.output_file)
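
For reference, a minimal sketch (not part of this diff) of how the new moderator class introduced above can be exercised outside the Gradio server. It assumes OPENAI_API_KEY is set in the environment (and the Azure/PhotoDNA variables if images are passed); the model name is only an example.

# sketch_moderator_usage.py -- illustrative only, not included in the PR
from fastchat.serve.moderation.moderator import AzureAndOpenAIContentModerator

moderator = AzureAndOpenAIContentModerator()

# Text-only moderation: "gpt" in the model name triggers moderation per MODEL_KEYWORDS.
# Passing None for images falls back to _NON_TOXIC_IMAGE_MODERATION_MAP.
response_map = moderator.image_and_text_moderation_filter(
    None, "some user prompt", ["gpt-4o-2024-05-13"], do_moderation=True
)
moderator.append_moderation_response(response_map)

if moderator.text_flagged:
    print("prompt was flagged:", response_map["text_moderation"])

# Per-conversation responses are what State.dict() logs under the "moderation" key.
print(moderator.conv_moderation_responses)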