forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add bce with logit loss * add bce with logit loss * remove imports * fix tiny bug * add test documentation and refactor function * fix test cases and formatting * distilbet files * Apply various cleanups. * More cleanups. * More polish. --------- Co-authored-by: laurent <[email protected]>
- Loading branch information
1 parent
ca19a9a
commit 762e996
Showing
4 changed files
with
500 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# candle-distilbert | ||
|
||
DistilBert is a distiled version of the Bert model. | ||
|
||
## Sentence embeddings | ||
|
||
DistilBert is used to compute the sentence embeddings for a prompt. The model weights | ||
are downloaded from the hub on the first run. | ||
|
||
```bash | ||
cargo run --example distilbert --release -- --prompt "Here is a test sentence" | ||
|
||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441], | ||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244], | ||
> [ 0.0702, -0.1311, -0.4914, ..., 0.3483, -0.6194, 0.1829], | ||
> ... | ||
> [ 0.2993, -0.0106, -0.4640, ..., 0.2844, -0.6732, 0.0042], | ||
> [ 0.1066, -0.0081, -0.4299, ..., 0.3435, -0.7729, 0.0190], | ||
> [ 0.8903, 0.2055, -0.2541, ..., 0.3208, -0.6585, 0.0586]]] | ||
> Tensor[[1, 7, 768], f32] | ||
|
||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE}; | ||
|
||
use anyhow::{Error as E, Result}; | ||
use candle::{Device, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use clap::Parser; | ||
use hf_hub::{api::sync::Api, Repo, RepoType}; | ||
use tokenizers::Tokenizer; | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Enable tracing (generates a trace-timestamp.json file). | ||
#[arg(long)] | ||
tracing: bool, | ||
|
||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending | ||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long)] | ||
revision: Option<String>, | ||
|
||
/// When set, compute embeddings for this prompt. | ||
#[arg(long)] | ||
prompt: String, | ||
|
||
/// Use the pytorch weights rather than the safetensors ones | ||
#[arg(long)] | ||
use_pth: bool, | ||
|
||
/// The number of times to run the prompt. | ||
#[arg(long, default_value = "1")] | ||
n: usize, | ||
|
||
/// L2 normalization for embeddings. | ||
#[arg(long, default_value = "true")] | ||
normalize_embeddings: bool, | ||
} | ||
|
||
impl Args { | ||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> { | ||
let device = candle_examples::device(self.cpu)?; | ||
let default_model = "distilbert-base-uncased".to_string(); | ||
let default_revision = "main".to_string(); | ||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) { | ||
(Some(model_id), Some(revision)) => (model_id, revision), | ||
(Some(model_id), None) => (model_id, "main".to_string()), | ||
(None, Some(revision)) => (default_model, revision), | ||
(None, None) => (default_model, default_revision), | ||
}; | ||
|
||
let repo = Repo::with_revision(model_id, RepoType::Model, revision); | ||
let (config_filename, tokenizer_filename, weights_filename) = { | ||
let api = Api::new()?; | ||
let api = api.repo(repo); | ||
let config = api.get("config.json")?; | ||
let tokenizer = api.get("tokenizer.json")?; | ||
let weights = if self.use_pth { | ||
api.get("pytorch_model.bin")? | ||
} else { | ||
api.get("model.safetensors")? | ||
}; | ||
(config, tokenizer, weights) | ||
}; | ||
let config = std::fs::read_to_string(config_filename)?; | ||
let config: Config = serde_json::from_str(&config)?; | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
|
||
let vb = if self.use_pth { | ||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)? | ||
} else { | ||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? } | ||
}; | ||
let model = DistilBertModel::load(vb, &config)?; | ||
Ok((model, tokenizer)) | ||
} | ||
} | ||
|
||
fn get_mask(size: usize, device: &Device) -> Tensor { | ||
let mask: Vec<_> = (0..size) | ||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i))) | ||
.collect(); | ||
Tensor::from_slice(&mask, (size, size), device).unwrap() | ||
} | ||
|
||
fn main() -> Result<()> { | ||
use tracing_chrome::ChromeLayerBuilder; | ||
use tracing_subscriber::prelude::*; | ||
|
||
let args = Args::parse(); | ||
let _guard = if args.tracing { | ||
println!("tracing..."); | ||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); | ||
tracing_subscriber::registry().with(chrome_layer).init(); | ||
Some(guard) | ||
} else { | ||
None | ||
}; | ||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?; | ||
let device = &model.device; | ||
|
||
let tokenizer = tokenizer | ||
.with_padding(None) | ||
.with_truncation(None) | ||
.map_err(E::msg)?; | ||
let tokens = tokenizer | ||
.encode(args.prompt, true) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; | ||
let mask = get_mask(tokens.len(), device); | ||
|
||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()); | ||
println!("mask: {:?}", mask.to_vec2::<u8>()); | ||
|
||
let ys = model.forward(&token_ids, &mask)?; | ||
println!("{ys}"); | ||
|
||
Ok(()) | ||
} | ||
|
||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> { | ||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) | ||
} |
Oops, something went wrong.