Skip to content

Commit

Permalink
fixed garch && monte_carlo models
Browse files Browse the repository at this point in the history
  • Loading branch information
Liberxue committed Oct 11, 2024
1 parent e08b8a6 commit 8d84671
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 36 deletions.
8 changes: 5 additions & 3 deletions cli/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
[package]
name = "cli"
version = "0.1.0"
edition = "2021"
edition.workspace = true
version.workspace = true
readme.workspace = true
license.workspace = true

[dependencies]
rand = "0.8"
clap = { version = "4.0", features = ["derive"] }
burn = { version = "0.13.2", features = ["train", "wgpu", "vision"] }
core = { path = "../core" }
plotters = "0.3.6"
ratatui = "0.28.0"
crossterm = "0.28.1"
tokio = { version = "1.0", features = ["full"] }
116 changes: 83 additions & 33 deletions cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use ratatui::{
widgets::{Block, Cell, Row, Table, TableState},
Frame, Terminal,
};
use std::cell::RefCell;
use std::io::{self};
use std::{fs, path::Path};

Expand All @@ -29,15 +30,53 @@ struct Opts {
t: f64,
}

#[derive(Clone, PartialEq, Debug)]
struct ModelResults {
call: f64,
put: f64,
delta: f64,
gamma: f64,
vega: f64,
theta: f64,
rho: f64,
}

struct ModelWrapper {
name: String,
model: Box<dyn OptionPricingModel>,
cache: RefCell<Option<ModelResults>>,
}

impl ModelWrapper {
fn get_results(&self, params: &OptionParameters) -> ModelResults {
if let Some(cached) = self.cache.borrow().as_ref() {
return cached.clone();
}

let results = ModelResults {
call: self.model.call_price(params),
put: self.model.put_price(params),
delta: self.model.delta(params),
gamma: self.model.gamma(params),
vega: self.model.vega(params),
theta: self.model.theta(params),
rho: self.model.rho(params),
};

*self.cache.borrow_mut() = Some(results.clone());
results
}

fn invalidate_cache(&self) {
*self.cache.borrow_mut() = None;
}
}

struct App {
models: Vec<ModelWrapper>,
table_state: TableState,
params: OptionParameters,
params_changed: bool,
}

impl App {
Expand All @@ -55,6 +94,7 @@ impl App {
models: load_models(),
table_state,
params,
params_changed: true,
}
}

Expand Down Expand Up @@ -85,6 +125,16 @@ impl App {
};
self.table_state.select(Some(i));
}

fn update_params(&mut self, new_params: OptionParameters) {
if self.params != new_params {
self.params = new_params;
self.params_changed = true;
for model in &self.models {
model.invalidate_cache();
}
}
}
}

fn create_model(model_name: &str) -> Option<Box<dyn OptionPricingModel>> {
Expand All @@ -94,7 +144,7 @@ fn create_model(model_name: &str) -> Option<Box<dyn OptionPricingModel>> {
"garch" => Some(Box::new(core::models::GarchModel::default())),
"monte_carlo" => Some(Box::new(core::models::MonteCarloModel {
simulations: 1000,
epsilon: 0.01,
time_steps: 10,
})),
_ => None,
}
Expand All @@ -120,6 +170,7 @@ fn create_model_wrapper(entry: &fs::DirEntry) -> Option<ModelWrapper> {
create_model(model_name).map(|model| ModelWrapper {
name: model_name.to_string(),
model,
cache: RefCell::new(None),
})
})
}
Expand All @@ -138,25 +189,25 @@ fn load_models() -> Vec<ModelWrapper> {
.unwrap_or_else(|_| Vec::new())
}

fn main() -> Result<(), io::Error> {
#[tokio::main]
async fn main() -> Result<(), io::Error> {
let opts: Opts = Opts::parse();
let mut app = App::new(opts);

enable_raw_mode()?;
let mut stdout = io::stdout();
execute!(stdout, EnterAlternateScreen, EnableMouseCapture)?;
let backend = CrosstermBackend::new(stdout);
let mut terminal = Terminal::new(backend)?;
terminal.hide_cursor()?;

let res = run_app(&mut terminal, &mut app);
let res = run_app(&mut terminal, &mut app).await;

disable_raw_mode()?;
execute!(
terminal.backend_mut(),
LeaveAlternateScreen,
DisableMouseCapture
)?;
terminal.show_cursor()?;

if let Err(err) = res {
println!("Error: {:?}", err)
Expand All @@ -165,16 +216,23 @@ fn main() -> Result<(), io::Error> {
Ok(())
}

fn run_app<B: Backend>(terminal: &mut Terminal<B>, app: &mut App) -> io::Result<()> {
async fn run_app<B: Backend>(terminal: &mut Terminal<B>, app: &mut App) -> io::Result<()> {
loop {
terminal.draw(|f| ui(f, app))?;

if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('q') => return Ok(()),
KeyCode::Down => app.next(),
KeyCode::Up => app.previous(),
_ => {}
if app.params_changed {
terminal.draw(|f| ui(f, app))?;
app.params_changed = false;
}

if event::poll(std::time::Duration::from_millis(100))? {
if let Event::Key(key) = event::read()? {
match key.code {
KeyCode::Char('q') => return Ok(()),
KeyCode::Down => app.next(),
KeyCode::Up => app.previous(),
KeyCode::Esc => return Ok(()),
_ => {}
}
app.params_changed = true;
}
}
}
Expand All @@ -184,40 +242,32 @@ fn ui(f: &mut Frame, app: &App) {
let chunks = Layout::default()
.direction(Direction::Vertical)
.margin(1)
.constraints([Constraint::Percentage(100)].as_ref())
.constraints([Constraint::Percentage(100), Constraint::Ratio(1, 8)])
.split(f.area());

let header_cells = [
"Models",
"Call Price",
"Put Price",
"Delta",
"Gamma",
"Vega",
"Theta",
"Rho",
"Models", "Call", "Put", "Delta", "Gamma", "Vega", "Theta", "Rho",
]
.iter()
.map(|h| {
Cell::from(*h).style(
Style::default()
.fg(Color::Yellow)
.add_modifier(Modifier::BOLD)
.add_modifier(Modifier::UNDERLINED),
.add_modifier(Modifier::BOLD),
)
});
let header = Row::new(header_cells).style(Style::default().bg(Color::Black));

let rows = app.models.iter().map(|wrapper| {
let results = wrapper.get_results(&app.params);
let cells = vec![
Cell::from(wrapper.name.as_str()),
Cell::from(format!("{:.4}", wrapper.model.call_price(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.put_price(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.delta(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.gamma(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.vega(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.theta(&app.params))),
Cell::from(format!("{:.4}", wrapper.model.rho(&app.params))),
Cell::from(format!("{:.4}", results.call)),
Cell::from(format!("{:.4}", results.put)),
Cell::from(format!("{:.4}", results.delta)),
Cell::from(format!("{:.4}", results.gamma)),
Cell::from(format!("{:.4}", results.vega)),
Cell::from(format!("{:.4}", results.theta)),
Cell::from(format!("{:.4}", results.rho)),
];
Row::new(cells)
});
Expand Down

0 comments on commit 8d84671

Please sign in to comment.