Skip to content

Commit

Permalink
Distibert (huggingface#1366)
Browse files Browse the repository at this point in the history
* add bce with logit loss

* add bce with logit loss

* remove imports

* fix tiny bug

* add test documentation and refactor function

* fix test cases and formatting

* distilbet files

* Apply various cleanups.

* More cleanups.

* More polish.

---------

Co-authored-by: laurent <[email protected]>
  • Loading branch information
ToluClassics and LaurentMazare authored Nov 24, 2023
1 parent ca19a9a commit 762e996
Show file tree
Hide file tree
Showing 4 changed files with 500 additions and 0 deletions.
22 changes: 22 additions & 0 deletions candle-examples/examples/distilbert/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# candle-distilbert

DistilBert is a distiled version of the Bert model.

## Sentence embeddings

DistilBert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.

```bash
cargo run --example distilbert --release -- --prompt "Here is a test sentence"

> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829],
> ...
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042],
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190],
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]]
> Tensor[[1, 7, 768], f32]

```
135 changes: 135 additions & 0 deletions candle-examples/examples/distilbert/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
#[arg(long)]
model_id: Option<String>,

#[arg(long)]
revision: Option<String>,

/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: String,

/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,

/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,

/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
}

impl Args {
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
let device = candle_examples::device(self.cpu)?;
let default_model = "distilbert-base-uncased".to_string();
let default_revision = "main".to_string();
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};

let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let model = DistilBertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
}

fn get_mask(size: usize, device: &Device) -> Tensor {
let mask: Vec<_> = (0..size)
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
.collect();
Tensor::from_slice(&mask, (size, size), device).unwrap()
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;

let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(args.prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
let mask = get_mask(tokens.len(), device);

println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
println!("mask: {:?}", mask.to_vec2::<u8>());

let ys = model.forward(&token_ids, &mask)?;
println!("{ys}");

Ok(())
}

pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
}
Loading

0 comments on commit 762e996

Please sign in to comment.