From 10394f3e12aab39ba1df7cf52ca1d54d72de116c Mon Sep 17 00:00:00 2001 From: Louis Beaumont Date: Thu, 22 Aug 2024 16:28:16 +0200 Subject: [PATCH] feat: add ocr accurary benches --- .github/workflows/benchmark.yml | 40 ++++++ screenpipe-server/src/core.rs | 26 ---- screenpipe-vision/Cargo.toml | 4 + screenpipe-vision/benches/ocr_benchmark.rs | 136 +++++++++++++++++++++ 4 files changed, 180 insertions(+), 26 deletions(-) create mode 100644 screenpipe-vision/benches/ocr_benchmark.rs diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index 43d80a2b..6e16476a 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -51,3 +51,43 @@ jobs: with: name: benchmark-results path: output.txt + + apple_ocr_benchmark: + name: Run OCR benchmark + runs-on: macos-latest + steps: + - uses: actions/checkout@v3 + - uses: dtolnay/rust-toolchain@stable + + - name: Install dependencies + run: | + brew install ffmpeg + + - name: Run OCR benchmarks + run: | + cargo bench --bench ocr_benchmark -- --output-format bencher | tee ocr_output.txt + + - name: Download previous OCR benchmark data + uses: actions/cache@v4 + with: + path: ./ocr_cache + key: ${{ runner.os }}-ocr-benchmark + + - name: Store OCR benchmark result + uses: benchmark-action/github-action-benchmark@v1 + with: + name: OCR Benchmark + tool: "cargo" + output-file-path: ocr_output.txt + external-data-json-path: ./ocr_cache/ocr-benchmark-data.json + fail-on-alert: true + github-token: ${{ secrets.GITHUB_TOKEN }} + comment-on-alert: true + summary-always: true + alert-comment-cc-users: "@louis030195" + + - name: Upload OCR benchmark results + uses: actions/upload-artifact@v3 + with: + name: ocr-benchmark-results + path: ocr_output.txt diff --git a/screenpipe-server/src/core.rs b/screenpipe-server/src/core.rs index 6f6649ef..eba98e99 100644 --- a/screenpipe-server/src/core.rs +++ b/screenpipe-server/src/core.rs @@ -17,32 +17,6 @@ use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use tokio::task::JoinHandle; -pub enum RecorderControl { - Pause, - Resume, - Stop, -} - -// Wrapper struct for DataOutput -pub struct DataOutputWrapper { - pub data_output: rusty_tesseract::tesseract::output_data::DataOutput, -} - -impl DataOutputWrapper { - pub fn to_json(&self) -> String { - let data_json: Vec = self.data_output.data.iter().map(|d| { - format!( - r#"{{"level": {}, "page_num": {}, "block_num": {}, "par_num": {}, "line_num": {}, "word_num": {}, "left": {}, "top": {}, "width": {}, "height": {}, "conf": {}, "text": "{}"}}"#, - d.level, d.page_num, d.block_num, d.par_num, d.line_num, d.word_num, d.left, d.top, d.width, d.height, d.conf, d.text - ) - }).collect(); - format!( - r#"{{"output": "{}", "data": [{}]}}"#, - self.data_output.output, - data_json.join(", ") - ) - } -} pub async fn start_continuous_recording( db: Arc, diff --git a/screenpipe-vision/Cargo.toml b/screenpipe-vision/Cargo.toml index c7001e41..d8e154b6 100644 --- a/screenpipe-vision/Cargo.toml +++ b/screenpipe-vision/Cargo.toml @@ -81,6 +81,10 @@ path = "src/bin/screenpipe-vision.rs" name = "vision_benchmark" harness = false +[[bench]] +name = "ocr_benchmark" +harness = false + [target.'cfg(target_os = "windows")'.dependencies] windows = { version = "0.58", features = ["Graphics_Imaging", "Media_Ocr", "Storage", "Storage_Streams"] } diff --git a/screenpipe-vision/benches/ocr_benchmark.rs b/screenpipe-vision/benches/ocr_benchmark.rs new file mode 100644 index 00000000..4530a8c1 --- /dev/null +++ b/screenpipe-vision/benches/ocr_benchmark.rs @@ -0,0 +1,136 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use std::path::PathBuf; +use std::time::Duration; + +#[cfg(target_os = "macos")] +use screenpipe_vision::perform_ocr_apple; + +const EXPECTED_KEYWORDS: &[&str] = &[ + "ocr_handles", + "Vec", + "pool_size", + "task::spawn", + "async move", + "should_stop.lock().await", + "RecvError::Lagged", + "debug!", + "error!", + "frame_counter", + "start_time", + "last_processed_frame", + "control_rx.try_recv()", + "ControlMessage::Pause", + "ControlMessage::Resume", + "ControlMessage::Stop", + "is_paused.lock().await", + "tokio::time::sleep", + "capture_start", + "monitor.capture_image()", + "DynamicImage::ImageRgba8", + "capture_duration", + "image_hash", + "calculate_hash", + "result_tx_clone", + "image_arc", + "Arc::new", + "queue_size", + "ocr_tx.receiver_count()", + "MAX_QUEUE_SIZE", + "frames_to_skip", +]; + +// Helper function to load test image +fn load_test_image() -> image::DynamicImage { + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("tests"); + path.push("testing_OCR.png"); + image::open(&path).expect("Failed to open image") +} + +// Performance test for Apple Vision OCR +#[cfg(target_os = "macos")] +fn bench_apple_vision_ocr(c: &mut Criterion) { + let image = load_test_image(); + let mut group = c.benchmark_group("Apple Vision OCR"); + group.sample_size(10); // Reduce sample size to address warning + group.measurement_time(Duration::from_secs(10)); // Increase measurement time + + group.bench_function(BenchmarkId::new("Performance", ""), |b| { + b.iter(|| { + let result = perform_ocr_apple(black_box(&image)); + assert!(!result.is_empty(), "OCR failed"); + }) + }); + + group.finish(); +} + +// Accuracy test for Apple Vision OCR +#[cfg(target_os = "macos")] +fn test_apple_vision_ocr_accuracy() { + let image = load_test_image(); + let result = perform_ocr_apple(&image); + + let matched_keywords = EXPECTED_KEYWORDS + .iter() + .filter(|&&keyword| result.contains(keyword)) + .count(); + let accuracy = matched_keywords as f32 / EXPECTED_KEYWORDS.len() as f32; + + println!("Apple Vision OCR Accuracy: {:.2}", accuracy); + println!( + "Matched keywords: {}/{}", + matched_keywords, + EXPECTED_KEYWORDS.len() + ); + assert!(accuracy > 0.3, "Accuracy below threshold"); // Adjusted threshold based on observed results +} + +// Combined performance and accuracy test +#[cfg(target_os = "macos")] +fn bench_apple_vision_ocr_with_accuracy(c: &mut Criterion) { + let image = load_test_image(); + let mut group = c.benchmark_group("Apple Vision OCR with Accuracy"); + group.sample_size(10); // Reduce sample size to address warning + group.measurement_time(Duration::from_secs(10)); // Increase measurement time + + group.bench_function(BenchmarkId::new("Performance and Accuracy", ""), |b| { + b.iter_custom(|iters| { + let mut total_duration = Duration::new(0, 0); + let mut total_accuracy = 0.0; + + for _ in 0..iters { + let start = std::time::Instant::now(); + let result = perform_ocr_apple(black_box(&image)); + total_duration += start.elapsed(); + + let matched_keywords = EXPECTED_KEYWORDS + .iter() + .filter(|&&keyword| result.contains(keyword)) + .count(); + let accuracy = matched_keywords as f32 / EXPECTED_KEYWORDS.len() as f32; + total_accuracy += accuracy; + } + + println!("Average Accuracy: {:.2}", total_accuracy / iters as f32); + total_duration + }) + }); + + group.finish(); +} + +#[cfg(target_os = "macos")] +criterion_group!( + benches, + bench_apple_vision_ocr, + bench_apple_vision_ocr_with_accuracy +); +#[cfg(target_os = "macos")] +criterion_main!(benches); + +#[cfg(target_os = "macos")] +#[test] +fn run_accuracy_test() { + test_apple_vision_ocr_accuracy(); +}