Skip to content

Commit

Permalink
feat: add ocr accurary benches
Browse files Browse the repository at this point in the history
  • Loading branch information
louis030195 committed Aug 22, 2024
1 parent f764669 commit 10394f3
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 26 deletions.
40 changes: 40 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,43 @@ jobs:
with:
name: benchmark-results
path: output.txt

apple_ocr_benchmark:
name: Run OCR benchmark
runs-on: macos-latest
steps:
- uses: actions/checkout@v3
- uses: dtolnay/rust-toolchain@stable

- name: Install dependencies
run: |
brew install ffmpeg
- name: Run OCR benchmarks
run: |
cargo bench --bench ocr_benchmark -- --output-format bencher | tee ocr_output.txt
- name: Download previous OCR benchmark data
uses: actions/cache@v4
with:
path: ./ocr_cache
key: ${{ runner.os }}-ocr-benchmark

- name: Store OCR benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
name: OCR Benchmark
tool: "cargo"
output-file-path: ocr_output.txt
external-data-json-path: ./ocr_cache/ocr-benchmark-data.json
fail-on-alert: true
github-token: ${{ secrets.GITHUB_TOKEN }}
comment-on-alert: true
summary-always: true
alert-comment-cc-users: "@louis030195"

- name: Upload OCR benchmark results
uses: actions/upload-artifact@v3
with:
name: ocr-benchmark-results
path: ocr_output.txt
26 changes: 0 additions & 26 deletions screenpipe-server/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,6 @@ use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::task::JoinHandle;
pub enum RecorderControl {
Pause,
Resume,
Stop,
}

// Wrapper struct for DataOutput
pub struct DataOutputWrapper {
pub data_output: rusty_tesseract::tesseract::output_data::DataOutput,
}

impl DataOutputWrapper {
pub fn to_json(&self) -> String {
let data_json: Vec<String> = self.data_output.data.iter().map(|d| {
format!(
r#"{{"level": {}, "page_num": {}, "block_num": {}, "par_num": {}, "line_num": {}, "word_num": {}, "left": {}, "top": {}, "width": {}, "height": {}, "conf": {}, "text": "{}"}}"#,
d.level, d.page_num, d.block_num, d.par_num, d.line_num, d.word_num, d.left, d.top, d.width, d.height, d.conf, d.text
)
}).collect();
format!(
r#"{{"output": "{}", "data": [{}]}}"#,
self.data_output.output,
data_json.join(", ")
)
}
}

pub async fn start_continuous_recording(
db: Arc<DatabaseManager>,
Expand Down
4 changes: 4 additions & 0 deletions screenpipe-vision/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,10 @@ path = "src/bin/screenpipe-vision.rs"
name = "vision_benchmark"
harness = false

[[bench]]
name = "ocr_benchmark"
harness = false


[target.'cfg(target_os = "windows")'.dependencies]
windows = { version = "0.58", features = ["Graphics_Imaging", "Media_Ocr", "Storage", "Storage_Streams"] }
Expand Down
136 changes: 136 additions & 0 deletions screenpipe-vision/benches/ocr_benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion};
use std::path::PathBuf;
use std::time::Duration;

#[cfg(target_os = "macos")]
use screenpipe_vision::perform_ocr_apple;

const EXPECTED_KEYWORDS: &[&str] = &[
"ocr_handles",
"Vec",
"pool_size",
"task::spawn",
"async move",
"should_stop.lock().await",
"RecvError::Lagged",
"debug!",
"error!",
"frame_counter",
"start_time",
"last_processed_frame",
"control_rx.try_recv()",
"ControlMessage::Pause",
"ControlMessage::Resume",
"ControlMessage::Stop",
"is_paused.lock().await",
"tokio::time::sleep",
"capture_start",
"monitor.capture_image()",
"DynamicImage::ImageRgba8",
"capture_duration",
"image_hash",
"calculate_hash",
"result_tx_clone",
"image_arc",
"Arc::new",
"queue_size",
"ocr_tx.receiver_count()",
"MAX_QUEUE_SIZE",
"frames_to_skip",
];

// Helper function to load test image
fn load_test_image() -> image::DynamicImage {
let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
path.push("tests");
path.push("testing_OCR.png");
image::open(&path).expect("Failed to open image")
}

// Performance test for Apple Vision OCR
#[cfg(target_os = "macos")]
fn bench_apple_vision_ocr(c: &mut Criterion) {
let image = load_test_image();
let mut group = c.benchmark_group("Apple Vision OCR");
group.sample_size(10); // Reduce sample size to address warning
group.measurement_time(Duration::from_secs(10)); // Increase measurement time

group.bench_function(BenchmarkId::new("Performance", ""), |b| {
b.iter(|| {
let result = perform_ocr_apple(black_box(&image));
assert!(!result.is_empty(), "OCR failed");
})
});

group.finish();
}

// Accuracy test for Apple Vision OCR
#[cfg(target_os = "macos")]
fn test_apple_vision_ocr_accuracy() {
let image = load_test_image();
let result = perform_ocr_apple(&image);

let matched_keywords = EXPECTED_KEYWORDS
.iter()
.filter(|&&keyword| result.contains(keyword))
.count();
let accuracy = matched_keywords as f32 / EXPECTED_KEYWORDS.len() as f32;

println!("Apple Vision OCR Accuracy: {:.2}", accuracy);
println!(
"Matched keywords: {}/{}",
matched_keywords,
EXPECTED_KEYWORDS.len()
);
assert!(accuracy > 0.3, "Accuracy below threshold"); // Adjusted threshold based on observed results
}

// Combined performance and accuracy test
#[cfg(target_os = "macos")]
fn bench_apple_vision_ocr_with_accuracy(c: &mut Criterion) {
let image = load_test_image();
let mut group = c.benchmark_group("Apple Vision OCR with Accuracy");
group.sample_size(10); // Reduce sample size to address warning
group.measurement_time(Duration::from_secs(10)); // Increase measurement time

group.bench_function(BenchmarkId::new("Performance and Accuracy", ""), |b| {
b.iter_custom(|iters| {
let mut total_duration = Duration::new(0, 0);
let mut total_accuracy = 0.0;

for _ in 0..iters {
let start = std::time::Instant::now();
let result = perform_ocr_apple(black_box(&image));
total_duration += start.elapsed();

let matched_keywords = EXPECTED_KEYWORDS
.iter()
.filter(|&&keyword| result.contains(keyword))
.count();
let accuracy = matched_keywords as f32 / EXPECTED_KEYWORDS.len() as f32;
total_accuracy += accuracy;
}

println!("Average Accuracy: {:.2}", total_accuracy / iters as f32);
total_duration
})
});

group.finish();
}

#[cfg(target_os = "macos")]
criterion_group!(
benches,
bench_apple_vision_ocr,
bench_apple_vision_ocr_with_accuracy
);
#[cfg(target_os = "macos")]
criterion_main!(benches);

#[cfg(target_os = "macos")]
#[test]
fn run_accuracy_test() {
test_apple_vision_ocr_accuracy();
}

0 comments on commit 10394f3

Please sign in to comment.