Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

try multimodal embeddings #377

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ edition = "2021"

[workspace.dependencies]
# AI
candle = { package = "candle-core", version = "0.7.1" }
candle-nn = { package = "candle-nn", version = "0.7.1" }
candle-transformers = { package = "candle-transformers", version = "0.7.1" }
candle = { package = "candle-core", git = "https://github.com/huggingface/candle.git", branch = "main" }
candle-nn = { package = "candle-nn", git = "https://github.com/huggingface/candle.git", branch = "main" }
candle-transformers = { package = "candle-transformers", git = "https://github.com/huggingface/candle.git", branch = "main" }
tokenizers = "0.20.0"
hf-hub = "0.3.0"

Expand Down
3 changes: 2 additions & 1 deletion screenpipe-audio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ chrono = { version = "0.4.31", features = ["serde"] }
candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
vad-rs = "0.1.3"
tokenizers = { workspace = true }

vad-rs = "0.1.3"
anyhow = "1.0.86"
byteorder = "1.5.0"
hf-hub = "0.3.2"
Expand Down
5 changes: 5 additions & 0 deletions screenpipe-core/src/candle_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
use candle::Device;

/// Returns the best available compute device: Metal first, then CUDA,
/// falling back to CPU when neither accelerator can be initialized.
pub fn get_device() -> Device {
    // NOTE: the original `unwrap_or(Device::new_cuda(0).unwrap_or(Device::Cpu))`
    // evaluated its argument eagerly, so a CUDA probe ran (and could log/error)
    // even on machines where Metal succeeded. `or_else` keeps each probe lazy:
    // CUDA is only attempted after the Metal probe fails, and CPU only after both.
    Device::new_metal(0)
        .or_else(|_| Device::new_cuda(0))
        .unwrap_or(Device::Cpu)
}
3 changes: 3 additions & 0 deletions screenpipe-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,6 @@ pub use pipes::*;
pub mod pii_removal;
#[cfg(feature = "security")]
pub use pii_removal::*;

pub mod candle_utils;
pub use candle_utils::*;
13 changes: 13 additions & 0 deletions screenpipe-vision/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,23 @@ clap = { version = "4.0", features = ["derive"] }

# Integrations
screenpipe-integrations = { path = "../screenpipe-integrations" }
screenpipe-core = { path = "../screenpipe-core" }

tracing-subscriber = { workspace = true }
tracing = { workspace = true }

candle = { workspace = true }
candle-nn = { workspace = true }
candle-transformers = { workspace = true }
tokenizers = { workspace = true }
hf-hub = { workspace = true, features = ["tokio"] }

[features]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
mkl = ["candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]


[dev-dependencies]
tempfile = "3.3.0"
criterion = { workspace = true }
Expand Down
1 change: 1 addition & 0 deletions screenpipe-vision/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ pub mod capture_screenshot_by_window;
#[cfg(target_os = "windows")]
pub use microsoft::perform_ocr_windows;
pub use tesseract::perform_ocr_tesseract;
pub mod multimodal_embeddings;
116 changes: 116 additions & 0 deletions screenpipe-vision/src/multimodal_embeddings.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
use std::ops::Mul;

use anyhow::Result;
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use candle_transformers::models::siglip::{Config, Model as SiglipModel};
use image::DynamicImage;
use tokenizers::Tokenizer;

/// Joint image/text embedder backed by Google's SigLIP
/// `base-patch16-224` checkpoint, executed through `candle`.
///
/// Produces paired embeddings for a screenshot and its OCR text so the two
/// can be compared (see `compute_similarity`).
pub struct MultimodalEmbedder {
    model: SiglipModel,   // SigLIP dual-encoder (vision tower + text tower)
    tokenizer: Tokenizer, // HF tokenizer matching the checkpoint's vocabulary
    device: Device,       // device all tensors are created on (CPU/CUDA/Metal)
    config: Config,       // hyperparameters: image size, max seq len, pad token id
}

impl MultimodalEmbedder {
    /// Builds a SigLIP (`base-patch16-224`) embedder on `device`.
    ///
    /// Downloads `model.safetensors` and `tokenizer.json` from the Hugging Face
    /// hub on first use (subsequent runs hit the local `hf-hub` cache).
    ///
    /// # Errors
    /// Fails if the hub is unreachable, the files cannot be fetched, or the
    /// weights do not match the SigLIP architecture.
    pub fn new(device: &Device) -> Result<Self> {
        let config = Config::base_patch16_224();

        // Fetch the model weights (blocking, sync hub API).
        let model_file = {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("google/siglip-base-patch16-224".to_string());
            api.get("model.safetensors")?
        };

        // SAFETY: mmap-ing the safetensors file is sound provided the file is
        // not truncated or modified while mapped; the hf-hub cache is treated
        // as immutable after download.
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? };

        let model = SiglipModel::new(&config, vb)?;
        let tokenizer = Self::get_tokenizer(None)?;

        Ok(Self {
            model,
            tokenizer,
            device: device.clone(),
            config,
        })
    }

    /// Loads the tokenizer from `tokenizer_path`, or downloads the default
    /// SigLIP tokenizer from the hub when `None`.
    fn get_tokenizer(tokenizer_path: Option<String>) -> Result<Tokenizer> {
        let tokenizer_path = match tokenizer_path {
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let api = api.model("google/siglip-base-patch16-224".to_string());
                api.get("tokenizer.json")?
            }
            Some(path) => path.into(),
        };

        // `tokenizers`' error type doesn't convert via `?` into anyhow here,
        // so map it explicitly.
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }

    /// Embeds an image and its OCR text in one forward pass.
    ///
    /// Returns `(text_embeddings, image_embeddings)` in the order produced by
    /// the underlying SigLIP `forward`.
    pub fn compute_embeddings(
        &self,
        image: &DynamicImage,
        ocr_text: &str,
    ) -> Result<(Tensor, Tensor)> {
        let image_tensor = self.preprocess_image(image)?;
        let text_tensor = self.tokenize_text(ocr_text)?;

        let (text_embeddings, image_embeddings) =
            self.model.forward(&image_tensor, &text_tensor)?;
        Ok((text_embeddings, image_embeddings))
    }

    /// Text↔image similarity: dot products of the embeddings, softmax-normalized
    /// along dim 1.
    ///
    /// NOTE(review): the SigLIP paper scores pairs with a sigmoid rather than a
    /// softmax; softmax is kept here to preserve existing behavior — confirm
    /// which is intended for ranking.
    pub fn compute_similarity(
        &self,
        text_embeddings: &Tensor,
        image_embeddings: &Tensor,
    ) -> anyhow::Result<Tensor> {
        // Pairwise dot products between every text row and every image row.
        let similarity = text_embeddings.matmul(&image_embeddings.transpose(0, 1)?)?;

        // Normalize each row into a probability distribution.
        let similarity = softmax(&similarity, 1)?;

        Ok(similarity)
    }

    /// Resizes/crops the image to the model's input size and converts it to a
    /// `(1, 3, H, W)` f32 tensor scaled to `[-1, 1]`.
    fn preprocess_image(&self, image: &DynamicImage) -> Result<Tensor> {
        let image_size = self.config.vision_config.image_size;
        // resize_to_fill preserves aspect ratio and center-crops to the target.
        let img = image.resize_to_fill(
            image_size as u32,
            image_size as u32,
            image::imageops::FilterType::Triangle,
        );
        let img = img.to_rgb8();
        let img = img.into_raw();
        // HWC u8 -> CHW f32; affine(2/255, -1) maps [0, 255] to [-1, 1].
        let img = Tensor::from_vec(img, (image_size, image_size, 3), &self.device)?
            .permute((2, 0, 1))?
            .to_dtype(DType::F32)?
            .affine(2. / 255., -1.)?
            .unsqueeze(0)?;
        Ok(img)
    }

    /// Tokenizes `text` into a `(1, max_position_embeddings)` id tensor,
    /// truncating long inputs and right-padding short ones with the pad token.
    fn tokenize_text(&self, text: &str) -> anyhow::Result<Tensor> {
        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!(e))?;
        let mut tokens = encoding.get_ids().to_vec();
        let max_len = self.config.text_config.max_position_embeddings;
        let pad_id = self.config.text_config.pad_token_id;

        // BUGFIX: the previous `max_len - tokens.len()` was a usize subtraction
        // that panicked (overflow) whenever the OCR text tokenized to more than
        // `max_len` ids — common for text-dense screenshots. Truncate first,
        // then pad to exactly `max_len` so position embeddings line up.
        tokens.truncate(max_len);
        if tokens.len() < max_len {
            let pad_count = max_len - tokens.len();
            tokens.extend(std::iter::repeat(pad_id).take(pad_count));
        }

        let input_ids = Tensor::new(vec![tokens], &self.device)?;
        Ok(input_ids)
    }
}
45 changes: 45 additions & 0 deletions screenpipe-vision/tests/embedding_benchmark.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use anyhow::Result;
use candle::Device;
use image::DynamicImage;
use screenpipe_core::get_device;
use screenpipe_vision::multimodal_embeddings::MultimodalEmbedder;
use std::time::Instant;

// Mock function to simulate screenshot capture
/// Mock screenshot source for the benchmark: returns a blank 224x224 RGB
/// image instead of grabbing the real screen.
fn capture_screenshot() -> Result<DynamicImage> {
    Ok(DynamicImage::new_rgb8(224, 224))
}

/// End-to-end timing smoke test: mock screenshot -> embeddings -> similarity.
/// Requires network access on first run (downloads SigLIP weights).
#[test]
fn test_screenshot_and_embedding_speed() -> Result<()> {
    let device = get_device();
    // Propagate setup failures through the Result-returning test instead of
    // panicking with `unwrap()` — gives a readable error (e.g. hub download
    // failure) rather than a backtrace.
    let embedder = MultimodalEmbedder::new(&device)?;

    let start = Instant::now();

    // Capture (mocked) screenshot.
    let screenshot = capture_screenshot()?;
    let screenshot_time = start.elapsed();

    // OCR is mocked for this test.
    let ocr_text = "This is a test OCR text";

    // Time only the embedding forward pass.
    let embedding_start = Instant::now();
    let (text_embeddings, image_embeddings) = embedder.compute_embeddings(&screenshot, ocr_text)?;
    let embedding_time = embedding_start.elapsed();

    // Similarity between the text and image embeddings.
    let similarity = embedder.compute_similarity(&text_embeddings, &image_embeddings)?;

    let total_time = start.elapsed();

    println!("Screenshot capture time: {:?}", screenshot_time);
    println!("Embedding computation time: {:?}", embedding_time);
    println!("Total processing time: {:?}", total_time);
    println!("Similarity shape: {:?}", similarity.shape());

    Ok(())
}
Loading