diff --git a/browser/browser.py b/browser/browser.py
index 9ecc697b2..8d2ace446 100644
--- a/browser/browser.py
+++ b/browser/browser.py
@@ -1,36 +1,82 @@
 import time
+
 import streamlit as st
 from openai import OpenAI
 
 st.title("torchchat")
 
-start_state = [
-    {
-        "role": "system",
-        "content": "You're an assistant. Answer questions directly, be brief, and have fun.",
-    },
-    {"role": "assistant", "content": "How can I help you?"},
-]
+client = OpenAI(
+    base_url="http://127.0.0.1:5000/v1",
+    api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
+)
+
 with st.sidebar:
     response_max_tokens = st.slider(
         "Max Response Tokens", min_value=10, max_value=1000, value=250, step=10
     )
+    st.divider()
+
+    # Build model list and allow user to change the model running on the server.
+    try:
+        models = client.models.list().data
+        model_keys = [model.id for model in models]
+    except Exception:
+        models = []
+        model_keys = []
+    selected_model = st.selectbox(
+        label="Model",
+        options=model_keys,
+    )
+    is_instruct_model = selected_model is not None and "instruct" in selected_model.lower()
+
+    st.divider()
+
+    # Change system prompt and default chat message.
+    system_prompt = st.text_area(
+        label="System Prompt",
+        value=(
+            "You're an assistant. Answer questions directly, be brief, and have fun."
+            if is_instruct_model
+            else f'Selected model "{selected_model}" doesn\'t support chat.'
+        ),
+        disabled=not is_instruct_model,
+    )
+    assistant_prompt = st.text_area(
+        label="Assistant Prompt",
+        value=(
+            "How can I help you?"
+            if is_instruct_model
+            else f'Selected model "{selected_model}" doesn\'t support chat.'
+        ),
+        disabled=not is_instruct_model,
+    )
+
+    st.divider()
+
+    # Manage chat history and prompts.
+    start_state = (
+        [
+            {
+                "role": "system",
+                "content": system_prompt,
+            },
+            {"role": "assistant", "content": assistant_prompt},
+        ]
+        if is_instruct_model
+        else []
+    )
 
     if st.button("Reset Chat", type="primary"):
         st.session_state["messages"] = start_state
 
-if "messages" not in st.session_state:
-    st.session_state["messages"] = start_state
+
+st.session_state["messages"] = start_state
 
 
 for msg in st.session_state.messages:
     st.chat_message(msg["role"]).write(msg["content"])
 
 
 if prompt := st.chat_input():
-    client = OpenAI(
-        base_url="http://127.0.0.1:5000/v1",
-        api_key="813",  # The OpenAI API requires an API key, but since we don't consume it, this can be any non-empty string.
-    )
     st.session_state.messages.append({"role": "user", "content": prompt})
     st.chat_message("user").write(prompt)
@@ -56,7 +102,7 @@ def get_streamed_completion(completion_generator):
         response = st.write_stream(
             get_streamed_completion(
                 client.chat.completions.create(
-                    model="llama3",
+                    model=selected_model,
                     messages=st.session_state.messages,
                     max_tokens=response_max_tokens,
                     stream=True,
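
For reviewers who want to exercise the patch by hand, below is a minimal sketch, not part of the patch itself, that hits the two server endpoints browser.py now depends on: model listing and streamed chat completions. It assumes a torchchat OpenAI-compatible server is already listening on 127.0.0.1:5000; the base_url and placeholder api_key mirror the values hardcoded in the diff, and the model id is simply the first one the server reports.

# smoke_test.py (hypothetical helper, not part of this patch)
from openai import OpenAI

# Same endpoint and placeholder key as browser.py; the server ignores
# the key, but the client requires a non-empty string.
client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="813")

# 1) Model listing: this is what populates the sidebar selectbox.
models = client.models.list().data
assert models, "Server reported no models; is the torchchat server running?"
model_id = models[0].id
print(f"Using model: {model_id}")

# 2) Streamed chat completion: this is what st.write_stream consumes
#    via get_streamed_completion in browser.py.
stream = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "Say hello in five words."}],
    max_tokens=50,
    stream=True,
)
for chunk in stream:
    # Guard against chunks with no choices or an empty delta.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
print()

If the listing step fails, the patched sidebar renders an empty Model selectbox and st.selectbox returns None, which is why is_instruct_model guards against a None selection before calling .lower().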