Skip to content

Commit

Permalink
only tokenize on change
Browse files Browse the repository at this point in the history
  • Loading branch information
trevyn committed Jul 30, 2024
1 parent 26d384b commit 08362c7
Showing 1 changed file with 24 additions and 130 deletions.
154 changes: 24 additions & 130 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,32 @@ use async_openai::types::Role::{self, *};
use async_openai::types::{
ChatCompletionRequestMessage, ChatCompletionToolArgs, ChatCompletionToolType, FunctionObjectArgs,
};
use bytes::{BufMut, Bytes, BytesMut};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::Sample;
use crossbeam::channel::RecvError;
use egui::text::LayoutJob;
use egui::*;
use futures::channel::mpsc::{self, Receiver, Sender};
use futures::stream::StreamExt as _;
use futures::SinkExt;
use futures::channel::mpsc::{self, Sender};
use once_cell::sync::Lazy;
use poll_promise::Promise;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
use std::{env, string};
use stream_cancel::{StreamExt as _, Trigger, Tripwire};
use tokio::fs::File;
use turbosql::{execute, now_ms, select, update, Blob, Turbosql};
use turbosql::*;

mod audiofile;
mod self_update;
mod session;

static DURATION: Lazy<Mutex<f64>> = Lazy::new(Default::default);
static TOKENIZER: Lazy<Mutex<tiktoken_rs::CoreBPE>> =
Lazy::new(|| Mutex::new(tiktoken_rs::o200k_base().unwrap()));
static COMPLETION: Lazy<Mutex<String>> = Lazy::new(Default::default);
static COMPLETION_PROMPT: Lazy<Mutex<String>> = Lazy::new(|| Mutex::new(String::from("")));

#[derive(Clone)]
struct ChatMessage {
role: Role,
content: String,
token_count: usize,
}

struct WheelWindow {
Expand All @@ -53,7 +45,7 @@ impl Default for WheelWindow {
Self {
open: true,
request_close: false,
messages: vec![ChatMessage { role: User, content: String::new() }],
messages: vec![ChatMessage { role: User, content: String::new(), token_count: 0 }],
}
}
}
Expand Down Expand Up @@ -142,9 +134,6 @@ struct Resource {

/// If set, the response was an image.
image: Option<Image<'static>>,

/// If set, the response was text with some supported syntax highlighting (e.g. ".rs" or ".md").
colored_text: Option<ColoredText>,
}

impl Resource {
Expand All @@ -154,13 +143,9 @@ impl Resource {
ctx.include_bytes(response.url.clone(), response.bytes.clone());
let image = Image::from_uri(response.url.clone());

Self { response, text: None, colored_text: None, image: Some(image) }
Self { response, text: None, image: Some(image) }
} else {
let text = response.text();
let colored_text = text.and_then(|text| syntax_highlighting(ctx, &response, text));
let text = text.map(|text| text.to_owned());

Self { response, text, colored_text, image: None }
Self { response, text: None, image: None }
}
}
}
Expand Down Expand Up @@ -483,6 +468,8 @@ impl eframe::App for App {
)
.changed()
{
entry.token_count =
self.tokenizer.as_ref().unwrap().encode_with_special_tokens(&entry.content).len();
// eprintln!("{}", entry.content);
// let debounce_tx = self.debounce_tx.clone();
// let entry_content = entry.content.clone();
Expand All @@ -494,8 +481,7 @@ impl eframe::App for App {
// WHEEL_WINDOWS.lock().unwrap().get_mut(i).unwrap().0.remove(j);
// }
});
let tokens = self.tokenizer.as_ref().unwrap().encode_with_special_tokens(&entry.content);
ui.label(format!("{} tokens", tokens.len()));
ui.label(format!("{} tokens", entry.token_count));
}

ui.label("[command-enter to send]");
Expand All @@ -511,25 +497,20 @@ impl eframe::App for App {
.insert()
.unwrap();
let orig_messages = messages.clone();
messages.push(ChatMessage { role: Assistant, content: String::new() });
messages.push(ChatMessage { role: User, content: String::new() });
messages.push(ChatMessage { role: Assistant, content: String::new(), token_count: 0 });
messages.push(ChatMessage { role: User, content: String::new(), token_count: 0 });
ui.ctx().memory_mut(|m| m.request_focus(Id::new((window_num * 1000) + messages.len() - 1)));
let id = messages.len() - 2;
let ctx_cloned = ctx.clone();
let (trigger, tripwire) = Tripwire::new();
self.trigger = Some(trigger);
tokio::spawn(async move {
run_openai("gpt-4o-mini", tripwire, orig_messages, move |content| {
WHEEL_WINDOWS
.lock()
.unwrap()
.get_mut(window_num)
.unwrap()
.messages
.get_mut(id)
.unwrap()
.content
.push_str(content);
let mut wheel_windows = WHEEL_WINDOWS.lock().unwrap();
let entry = wheel_windows.get_mut(window_num).unwrap().messages.get_mut(id).unwrap();
entry.content.push_str(content);
entry.token_count =
TOKENIZER.lock().unwrap().encode_with_special_tokens(&entry.content).len();
ctx_cloned.request_repaint();
})
.await
Expand Down Expand Up @@ -565,7 +546,7 @@ fn ui_url(ui: &mut Ui, _frame: &mut eframe::Frame, url: &mut String) -> bool {
}

fn ui_resource(ui: &mut Ui, resource: &Resource) {
let Resource { response, text, image, colored_text } = resource;
let Resource { response, text, image } = resource;

ui.monospace(format!("url: {}", response.url));
ui.monospace(format!("status: {} ({})", response.status, response.status_text));
Expand Down Expand Up @@ -600,8 +581,6 @@ fn ui_resource(ui: &mut Ui, resource: &Resource) {

if let Some(image) = image {
ui.add(image.clone());
} else if let Some(colored_text) = colored_text {
colored_text.ui(ui);
} else if let Some(text) = &text {
selectable_text(ui, text);
} else {
Expand All @@ -614,49 +593,6 @@ fn selectable_text(ui: &mut Ui, mut text: &str) {
ui.add(TextEdit::multiline(&mut text).desired_width(f32::INFINITY).font(TextStyle::Monospace));
}

// ----------------------------------------------------------------------------
// Syntax highlighting:

fn syntax_highlighting(
ctx: &Context,
response: &ehttp::Response,
text: &str,
) -> Option<ColoredText> {
let extension_and_rest: Vec<&str> = response.url.rsplitn(2, '.').collect();
let extension = extension_and_rest.first()?;
let theme = egui_extras::syntax_highlighting::CodeTheme::from_style(&ctx.style());
Some(ColoredText(egui_extras::syntax_highlighting::highlight(ctx, &theme, text, extension)))
}

struct ColoredText(text::LayoutJob);

impl ColoredText {
pub fn ui(&self, ui: &mut Ui) {
if true {
// Selectable text:
let mut layouter = |ui: &Ui, _string: &str, wrap_width: f32| {
let mut layout_job = self.0.clone();
layout_job.wrap.max_width = wrap_width;
ui.fonts(|f| f.layout_job(layout_job))
};

let mut text = self.0.text.as_str();
ui.add(
TextEdit::multiline(&mut text)
.font(TextStyle::Monospace)
.desired_width(f32::INFINITY)
.layouter(&mut layouter),
);
} else {
let mut job = self.0.clone();
job.wrap.max_width = ui.available_width();
let galley = ui.fonts(|f| f.layout_job(job));
let (response, painter) = ui.allocate_painter(galley.size(), Sense::hover());
painter.add(Shape::galley(response.rect.min, galley, ui.visuals().text_color()));
}
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init(); // Log to stderr (if you run with `RUST_LOG=debug`).
Expand All @@ -667,7 +603,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

if let Some(document) = select!(Option<Document> "ORDER BY timestamp_ms DESC LIMIT 1")? {
WHEEL_WINDOWS.lock().unwrap().push(WheelWindow {
messages: vec![ChatMessage { role: User, content: document.content }],
messages: vec![ChatMessage {
role: User,
content: document.content.clone(),
token_count: TOKENIZER.lock().unwrap().encode_with_special_tokens(&document.content).len(),
}],
..Default::default()
});
}
Expand Down Expand Up @@ -788,49 +728,3 @@ pub(crate) async fn run_openai(

Ok(())
}

pub(crate) async fn run_openai_completion(
tripwire: Tripwire,
prompt: String,
callback: impl Fn(&String) + Send + 'static,
) -> Result<(), Box<dyn std::error::Error>> {
use async_openai::{types::CreateCompletionRequestArgs, Client};
use futures::StreamExt;

let client = Client::with_config(
async_openai::config::OpenAIConfig::new().with_api_key(Setting::get("openai_api_key").value),
);

let mut logit_bias: HashMap<String, serde_json::Value> = HashMap::new();

["198", "271", "1432", "4815", "1980", "382", "720", "627"].iter().for_each(|s| {
logit_bias
.insert(s.to_string(), serde_json::Value::Number(serde_json::Number::from_f64(-100.).unwrap()));
});

let request = CreateCompletionRequestArgs::default()
.model("gpt-3.5-turbo-instruct")
.max_tokens(100u16)
.logit_bias(logit_bias)
.prompt(prompt)
.n(1)
.stream(true)
.build()?;

let mut stream = client.completions().create_stream(request).await?.take_until_if(tripwire);

while let Some(result) = stream.next().await {
match result {
Ok(response) => {
response.choices.iter().for_each(|c| {
callback(&c.text);
});
}
Err(err) => {
panic!("error: {err}");
}
}
}

Ok(())
}

0 comments on commit 08362c7

Please sign in to comment.