diff --git a/cli/Cargo.toml b/cli/Cargo.toml index bf45d03..2b6adc0 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -1,13 +1,15 @@ [package] name = "cli" -version = "0.1.0" -edition = "2021" +edition.workspace = true +version.workspace = true +readme.workspace = true +license.workspace = true [dependencies] rand = "0.8" clap = { version = "4.0", features = ["derive"] } burn = { version = "0.13.2", features = ["train", "wgpu", "vision"] } core = { path = "../core" } -plotters = "0.3.6" ratatui = "0.28.0" crossterm = "0.28.1" +tokio = { version = "1.0", features = ["full"] } diff --git a/cli/src/main.rs b/cli/src/main.rs index 75de435..8609df1 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -12,6 +12,7 @@ use ratatui::{ widgets::{Block, Cell, Row, Table, TableState}, Frame, Terminal, }; +use std::cell::RefCell; use std::io::{self}; use std::{fs, path::Path}; @@ -29,15 +30,53 @@ struct Opts { t: f64, } +#[derive(Clone, PartialEq, Debug)] +struct ModelResults { + call: f64, + put: f64, + delta: f64, + gamma: f64, + vega: f64, + theta: f64, + rho: f64, +} + struct ModelWrapper { name: String, model: Box, + cache: RefCell>, +} + +impl ModelWrapper { + fn get_results(&self, params: &OptionParameters) -> ModelResults { + if let Some(cached) = self.cache.borrow().as_ref() { + return cached.clone(); + } + + let results = ModelResults { + call: self.model.call_price(params), + put: self.model.put_price(params), + delta: self.model.delta(params), + gamma: self.model.gamma(params), + vega: self.model.vega(params), + theta: self.model.theta(params), + rho: self.model.rho(params), + }; + + *self.cache.borrow_mut() = Some(results.clone()); + results + } + + fn invalidate_cache(&self) { + *self.cache.borrow_mut() = None; + } } struct App { models: Vec, table_state: TableState, params: OptionParameters, + params_changed: bool, } impl App { @@ -55,6 +94,7 @@ impl App { models: load_models(), table_state, params, + params_changed: true, } } @@ -85,6 +125,16 @@ impl App { }; self.table_state.select(Some(i)); } + + fn update_params(&mut self, new_params: OptionParameters) { + if self.params != new_params { + self.params = new_params; + self.params_changed = true; + for model in &self.models { + model.invalidate_cache(); + } + } + } } fn create_model(model_name: &str) -> Option> { @@ -94,7 +144,7 @@ fn create_model(model_name: &str) -> Option> { "garch" => Some(Box::new(core::models::GarchModel::default())), "monte_carlo" => Some(Box::new(core::models::MonteCarloModel { simulations: 1000, - epsilon: 0.01, + time_steps: 10, })), _ => None, } @@ -120,6 +170,7 @@ fn create_model_wrapper(entry: &fs::DirEntry) -> Option { create_model(model_name).map(|model| ModelWrapper { name: model_name.to_string(), model, + cache: RefCell::new(None), }) }) } @@ -138,17 +189,18 @@ fn load_models() -> Vec { .unwrap_or_else(|_| Vec::new()) } -fn main() -> Result<(), io::Error> { +#[tokio::main] +async fn main() -> Result<(), io::Error> { let opts: Opts = Opts::parse(); let mut app = App::new(opts); - enable_raw_mode()?; let mut stdout = io::stdout(); execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?; let backend = CrosstermBackend::new(stdout); let mut terminal = Terminal::new(backend)?; + terminal.hide_cursor()?; - let res = run_app(&mut terminal, &mut app); + let res = run_app(&mut terminal, &mut app).await; disable_raw_mode()?; execute!( @@ -156,7 +208,6 @@ fn main() -> Result<(), io::Error> { LeaveAlternateScreen, DisableMouseCapture )?; - terminal.show_cursor()?; if let Err(err) = res { println!("Error: {:?}", err) @@ -165,16 +216,23 @@ fn main() -> Result<(), io::Error> { Ok(()) } -fn run_app(terminal: &mut Terminal, app: &mut App) -> io::Result<()> { +async fn run_app(terminal: &mut Terminal, app: &mut App) -> io::Result<()> { loop { - terminal.draw(|f| ui(f, app))?; - - if let Event::Key(key) = event::read()? { - match key.code { - KeyCode::Char('q') => return Ok(()), - KeyCode::Down => app.next(), - KeyCode::Up => app.previous(), - _ => {} + if app.params_changed { + terminal.draw(|f| ui(f, app))?; + app.params_changed = false; + } + + if event::poll(std::time::Duration::from_millis(100))? { + if let Event::Key(key) = event::read()? { + match key.code { + KeyCode::Char('q') => return Ok(()), + KeyCode::Down => app.next(), + KeyCode::Up => app.previous(), + KeyCode::Esc => return Ok(()), + _ => {} + } + app.params_changed = true; } } } @@ -184,40 +242,32 @@ fn ui(f: &mut Frame, app: &App) { let chunks = Layout::default() .direction(Direction::Vertical) .margin(1) - .constraints([Constraint::Percentage(100)].as_ref()) + .constraints([Constraint::Percentage(100), Constraint::Ratio(1, 8)]) .split(f.area()); let header_cells = [ - "Models", - "Call Price", - "Put Price", - "Delta", - "Gamma", - "Vega", - "Theta", - "Rho", + "Models", "Call", "Put", "Delta", "Gamma", "Vega", "Theta", "Rho", ] .iter() .map(|h| { Cell::from(*h).style( Style::default() .fg(Color::Yellow) - .add_modifier(Modifier::BOLD) - .add_modifier(Modifier::UNDERLINED), + .add_modifier(Modifier::BOLD), ) }); let header = Row::new(header_cells).style(Style::default().bg(Color::Black)); - let rows = app.models.iter().map(|wrapper| { + let results = wrapper.get_results(&app.params); let cells = vec![ Cell::from(wrapper.name.as_str()), - Cell::from(format!("{:.4}", wrapper.model.call_price(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.put_price(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.delta(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.gamma(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.vega(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.theta(&app.params))), - Cell::from(format!("{:.4}", wrapper.model.rho(&app.params))), + Cell::from(format!("{:.4}", results.call)), + Cell::from(format!("{:.4}", results.put)), + Cell::from(format!("{:.4}", results.delta)), + Cell::from(format!("{:.4}", results.gamma)), + Cell::from(format!("{:.4}", results.vega)), + Cell::from(format!("{:.4}", results.theta)), + Cell::from(format!("{:.4}", results.rho)), ]; Row::new(cells) });