From 273bc363ef61006b47940650cd6cf7096648a24c Mon Sep 17 00:00:00 2001 From: "Jonathan C. McKinney" Date: Thu, 20 Jun 2024 00:44:04 -0700 Subject: [PATCH] Instruction multi-modal --- src/gradio_runner.py | 52 ++++++++++++++++++++++++++++++++++++++++---- src/version.py | 2 +- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/src/gradio_runner.py b/src/gradio_runner.py index 3a4f8fbd4..b809a02f4 100644 --- a/src/gradio_runner.py +++ b/src/gradio_runner.py @@ -132,6 +132,16 @@ def ask_block(kwargs, instruction_label, visible_upload, file_types, mic_sources # info=None, elem_id='prompt-form', container=True, + visible=False, + ) + instruction_mm = gr.MultimodalTextbox( + lines=kwargs['input_lines'], + label=label_instruction, + info=instruction_label, + # info=None, + elem_id='prompt-form', + container=True, + submit_btn=False, ) mw0 = 20 mic_button = gr.Button( @@ -263,7 +273,7 @@ def clear_audio_state(): .then(fn=lambda: None, **submit_kwargs) stop_text.change(fn=clear_audio_state, outputs=audio_state) \ .then(fn=lambda: None, **stop_kwargs) - return attach_button, add_button, submit_buttons, instruction, submit, retry_btn, undo, clear_chat_btn, save_chat_btn, stop_btn + return attach_button, add_button, submit_buttons, instruction, instruction_mm, submit, retry_btn, undo, clear_chat_btn, save_chat_btn, stop_btn def go_gradio(**kwargs): @@ -1093,17 +1103,25 @@ def get_df_langchain_mode_paths(selection_docs_state1, db1s, dbs1=None): text_viewable_doc_count = gr.Textbox(lines=2, label=None, visible=False) with gr.Accordion("Image/Video Query", open=False, visible=have_vision_models): - image_file = gr.Image(value=kwargs['image_file'] if kwargs['image_file'] and any( + image_file = gr.Gallery(value=kwargs['image_file'] if kwargs['image_file'] and any( kwargs['image_file'].endswith(y) for y in IMAGE_EXTENSIONS) else None, label='Upload', show_label=False, type='filepath', elem_id="warning", elem_classes="feedback", + visible=False, ) video_file = 
def fun_mm(instruction_mm1, instruction1, image_file1, video_file1, audio_file1):
    """Unpack one multimodal-textbox submission into the separate inputs.

    Used as the ``instruction_mm.submit`` handler. It receives the current
    values of the multimodal textbox, the plain instruction textbox, the
    image gallery, the video component and the audio component, and returns
    new values for the last four, in that order.

    Args:
        instruction_mm1: Submitted multimodal value. Expected to be a mapping
            that may carry ``'text'``, ``'url'`` and ``'files'``; any other
            value (e.g. ``None``) is treated as an empty submission.
        instruction1: Current instruction text, kept when the submission
            carries no text.
        image_file1: Current image value, kept when no image is attached.
        video_file1: Current video value, kept when no video is attached.
        audio_file1: Current audio value; returned unchanged.

    Returns:
        Tuple ``(instruction, image_files, video_file, audio_file)`` where
        ``image_files`` is a list of paths when images were attached and
        ``video_file`` is always a single value or ``None``, never a list.
    """
    # A missing/None payload must not raise; treat it as "nothing submitted".
    payload = instruction_mm1 if isinstance(instruction_mm1, dict) else {}

    # Text replaces the instruction (an empty string is a valid replacement),
    # but a None text keeps whatever instruction was already there.
    if payload.get('text') is not None:
        instruction1 = payload['text']
    # A non-empty url still takes precedence over text, as before; an empty
    # or None url no longer wipes the text the user typed.
    if payload.get('url'):
        instruction1 = payload['url']

    def _as_path(entry):
        # Entries are presumably path strings; some Gradio releases appear to
        # hand back dicts with a 'path' key instead -- TODO confirm against
        # the installed version. Anything else is ignored rather than raising.
        if isinstance(entry, dict):
            entry = entry.get('path')
        return entry if isinstance(entry, str) else None

    def _with_extension(candidates, extensions):
        # Case-insensitive so that e.g. "photo.JPG" is still recognised;
        # str.endswith accepts a tuple, avoiding a Python-level any() loop.
        suffixes = tuple(str(ext).lower() for ext in extensions)
        return [path for path in candidates if path.lower().endswith(suffixes)]

    paths = [path for path in map(_as_path, payload.get('files') or []) if path]
    if paths:
        images = _with_extension(paths, IMAGE_EXTENSIONS)
        if images:
            image_file1 = images
        videos = _with_extension(paths, VIDEO_EXTENSIONS)
        if videos:
            video_file1 = videos
        # NOTE(review): attached files that are neither image nor video are
        # dropped here, and audio_file1 is never populated from the upload --
        # no audio extension list is visible in this scope. Confirm intent.

    # The video output takes a single value: collapse a list to its first
    # element, or to None when the list is empty.
    if isinstance(video_file1, list):
        video_file1 = video_file1[0] if video_file1 else None

    return instruction1, image_file1, video_file1, audio_file1