Add some KV cache to blip. (huggingface#1150)
* Add some KV cache to blip.

* Mention BLIP in the readme.
LaurentMazare authored Oct 22, 2023
1 parent 62fc965 commit df2f89b
Showing 4 changed files with 74 additions and 26 deletions.
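In short, the change gives BLIP's text decoder a key/value cache: each self-attention layer keeps the key and value tensors it already computed for earlier tokens, so every new decoding step only attends from the single newest token instead of re-running attention over the whole caption so far. Below is a minimal sketch of the per-layer append-and-store pattern implemented in blip_text.rs; the struct and method names are illustrative only, and the code is written against the candle Tensor API under the crate name used inside this repository (`candle`).

```rust
use candle::{Result, Tensor};

/// Minimal sketch of a per-layer KV cache. Keys and values are laid out as
/// (batch, heads, seq, head_dim) after transpose_for_scores, so new entries
/// are concatenated with the cached ones along dim 2 (the sequence axis),
/// and the concatenated tensors become the new cache.
struct KvCache {
    cache: Option<(Tensor, Tensor)>,
}

impl KvCache {
    fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
        let (key, value) = match &self.cache {
            None => (key.clone(), value.clone()),
            Some((prev_k, prev_v)) => (
                Tensor::cat(&[prev_k, key], 2)?,
                Tensor::cat(&[prev_v, value], 2)?,
            ),
        };
        self.cache = Some((key.clone(), value.clone()));
        Ok((key, value))
    }

    fn reset(&mut self) {
        self.cache = None
    }
}
```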
9 changes: 7 additions & 2 deletions README.md
@@ -99,6 +99,8 @@ We also provide some command line based examples using state of the art models
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.

Run them using commands like:
```
@@ -163,8 +165,11 @@ If you have an addition to this list, please submit a pull request.
- T5.
- Bert.
- Whisper (multi-lingual support).
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
14 changes: 7 additions & 7 deletions candle-examples/examples/blip/main.rs
@@ -86,17 +86,17 @@ pub fn main() -> anyhow::Result<()> {

let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let config = blip::Config::image_captioning_large();
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
let vision_model = model.vision_model();
let text_decoder = model.text_decoder();
let mut model = blip::BlipForConditionalGeneration::new(&config, vb)?;
println!("model built");
// TODO: Maybe add support for the conditional prompt.
let image_embeds = image.unsqueeze(0)?.apply(vision_model)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;

let mut token_ids = vec![30522u32];
for _index in 0..1000 {
let input_ids = Tensor::new(token_ids.as_slice(), &device)?.broadcast_left(1)?;
let logits = text_decoder.forward(&input_ids, &image_embeds)?;
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.text_decoder().forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
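For clarity, here is a tiny dependency-free sketch of the decoding schedule used in the updated main.rs loop: the first iteration feeds the whole prompt (here just the BOS token), and every later iteration feeds only the newest sampled token, since the cache already holds keys and values for everything before it. Token values and the step count are made up.

```rust
// Standalone illustration of the incremental decoding schedule.
// In the real example, the fed slice becomes `input_ids` for the BLIP text decoder.
fn main() {
    let mut token_ids: Vec<u32> = vec![30522]; // the BOS token used in the example
    for index in 0..4 {
        let context_size = if index > 0 { 1 } else { token_ids.len() };
        let start_pos = token_ids.len().saturating_sub(context_size);
        println!("step {index}: feed {:?}", &token_ids[start_pos..]);
        // Pretend we sampled a new token and appended it.
        token_ids.push(1000 + index as u32);
    }
}
```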
4 changes: 2 additions & 2 deletions candle-transformers/src/models/blip.rs
@@ -296,7 +296,7 @@ impl BlipForConditionalGeneration {
&self.vision_model
}

pub fn text_decoder(&self) -> &blip_text::TextLMHeadModel {
&self.text_decoder
pub fn text_decoder(&mut self) -> &mut blip_text::TextLMHeadModel {
&mut self.text_decoder
}
}
73 changes: 58 additions & 15 deletions candle-transformers/src/models/blip_text.rs
@@ -44,13 +44,10 @@ impl TextEmbeddings {
position_ids,
})
}
}

impl Module for TextEmbeddings {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
fn forward(&self, xs: &Tensor, past_kv_len: usize) -> Result<Tensor> {
let seq_len = xs.dim(1)?;
// Use past_key_values_length if we add a kv cache.
let position_ids = self.position_ids.narrow(1, 0, seq_len)?;
let position_ids = self.position_ids.narrow(1, past_kv_len, seq_len)?;
let embeddings = self.word_embedddings.forward(xs)?;
let position_embeddings = self.position_embeddings.forward(&position_ids)?;
(embeddings + position_embeddings)?.apply(&self.layer_norm)
@@ -65,6 +62,7 @@ struct TextSelfAttention {
attention_head_size: usize,
num_attention_heads: usize,
attention_scale: f64,
kv_cache: Option<(Tensor, Tensor)>,
}

impl TextSelfAttention {
@@ -88,6 +86,7 @@ impl TextSelfAttention {
attention_head_size,
num_attention_heads,
attention_scale,
kv_cache: None,
})
}

@@ -102,8 +101,12 @@
.permute((0, 2, 1, 3))
}

fn reset_kv_cache(&mut self) {
self.kv_cache = None
}

fn forward(
&self,
&mut self,
xs: &Tensor,
encoder_hidden_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
@@ -115,7 +118,15 @@
None => {
let key = self.transpose_for_scores(&self.key.forward(xs)?)?;
let value = self.transpose_for_scores(&self.value.forward(xs)?)?;
// TODO: kv cache
let (key, value) = match &self.kv_cache {
None => (key, value),
Some((prev_key, prev_value)) => {
let key = Tensor::cat(&[prev_key, &key], 2)?;
let value = Tensor::cat(&[prev_value, &value], 2)?;
(key, value)
}
};
self.kv_cache = Some((key.clone(), value.clone()));
(key, value)
}
Some(xs) => {
@@ -172,8 +183,12 @@ impl TextAttention {
Ok(Self { self_, output })
}

fn reset_kv_cache(&mut self) {
self.self_.reset_kv_cache()
}

fn forward(
&self,
&mut self,
xs: &Tensor,
encoder_hidden_states: Option<&Tensor>,
attention_mask: Option<&Tensor>,
@@ -251,14 +266,21 @@ impl TextLayer {
})
}

fn reset_kv_cache(&mut self) {
self.attention.reset_kv_cache();
if let Some(ca) = &mut self.cross_attention {
ca.reset_kv_cache()
}
}

fn forward(
&self,
&mut self,
xs: &Tensor,
encoder_hidden_states: &Tensor,
attention_mask: &Tensor,
) -> Result<Tensor> {
let attention_output = self.attention.forward(xs, None, Some(attention_mask))?;
let attention_output = match &self.cross_attention {
let attention_output = match &mut self.cross_attention {
Some(ca) => ca.forward(&attention_output, Some(encoder_hidden_states), None)?,
None => candle::bail!("expected some cross-attn"),
};
@@ -283,14 +305,18 @@ impl TextEncoder {
Ok(Self { layers })
}

fn reset_kv_cache(&mut self) {
self.layers.iter_mut().for_each(|l| l.reset_kv_cache())
}

fn forward(
&self,
&mut self,
xs: &Tensor,
encoder_hidden_states: &Tensor,
attention_mask: &Tensor,
) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, encoder_hidden_states, attention_mask)?
}
Ok(xs)
@@ -389,6 +415,7 @@ impl Module for TextOnlyMLMHead {
struct TextModel {
embeddings: TextEmbeddings,
encoder: TextEncoder,
past_kv_len: usize,
// We do not need the pooler for caption generation
}

@@ -399,22 +426,30 @@ impl TextModel {
Ok(Self {
embeddings,
encoder,
past_kv_len: 0,
})
}

fn forward(
&self,
&mut self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
attention_mask: &Tensor,
) -> Result<Tensor> {
let embedding_output = self.embeddings.forward(input_ids)?;
let (_b_sz, seq_len) = input_ids.dims2()?;
let embedding_output = self.embeddings.forward(input_ids, self.past_kv_len)?;
let sequence_output =
self.encoder
.forward(&embedding_output, encoder_hidden_states, attention_mask)?;
self.past_kv_len += seq_len;
// We're interested in the sequence-output rather than the pooled-output.
Ok(sequence_output)
}

fn reset_kv_cache(&mut self) {
self.past_kv_len = 0;
self.encoder.reset_kv_cache();
}
}

#[derive(Debug, Clone)]
@@ -430,7 +465,11 @@ impl TextLMHeadModel {
Ok(Self { bert, cls })
}

pub fn forward(&self, input_ids: &Tensor, encoder_hidden_states: &Tensor) -> Result<Tensor> {
pub fn forward(
&mut self,
input_ids: &Tensor,
encoder_hidden_states: &Tensor,
) -> Result<Tensor> {
let seq_len = input_ids.dim(1)?;
let mask: Vec<_> = (0..seq_len)
.flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
@@ -441,4 +480,8 @@ impl TextLMHeadModel {
// return_logits is false so we don't discard the last sequence element.
Ok(prediction_scores)
}

pub fn reset_kv_cache(&mut self) {
self.bert.reset_kv_cache()
}
}
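One caller-facing consequence of the cache is the new reset_kv_cache entry point: the cached keys/values and the past_kv_len counter belong to a single caption, so they must be cleared before generating for another image. The dependency-free sketch below (names are illustrative, not the repository's API) shows the position bookkeeping that past_kv_len provides and what resetting it means.

```rust
// Illustrative bookkeeping only; in the real model the counter lives in
// TextModel and the per-layer caches are cleared by TextEncoder::reset_kv_cache.
struct CacheBookkeeping {
    past_kv_len: usize,
}

impl CacheBookkeeping {
    /// Position range for `seq_len` freshly fed tokens, then advance the counter.
    fn positions(&mut self, seq_len: usize) -> std::ops::Range<usize> {
        let start = self.past_kv_len;
        self.past_kv_len += seq_len;
        start..start + seq_len
    }

    fn reset(&mut self) {
        self.past_kv_len = 0;
    }
}

fn main() {
    let mut bk = CacheBookkeeping { past_kv_len: 0 };
    assert_eq!(bk.positions(5), 0..5); // first step: the whole prompt
    assert_eq!(bk.positions(1), 5..6); // later steps: one new token each
    bk.reset(); // start a fresh caption for the next image
    assert_eq!(bk.positions(3), 0..3);
}
```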
