Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Scope input_length and session_len to BuildContext
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Sep 26, 2023
1 parent 6ba5126 commit 78b0e25
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 33 deletions.
8 changes: 7 additions & 1 deletion crates/llm-base/src/inference_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ pub struct InferenceSession {

/// How many tokens have been fed into the model's working memory so far.
#[doc(hidden)]
pub n_past: usize,
n_past: usize,

/// How much memory is required per token for the temporary context used
/// during inference.
Expand Down Expand Up @@ -98,6 +98,12 @@ pub struct BuildContext<'session> {
pub n_past: usize,
}

impl<'session> BuildContext<'session> {
    /// Number of input tokens in the current evaluation batch.
    ///
    /// Derived from the element count of the `embd` tensor, which holds
    /// one entry per input token being fed through the model.
    pub fn input_length(&self) -> usize {
        let embd = &self.embd;
        embd.nelements()
    }
}

unsafe impl Send for InferenceSession {}
impl InferenceSession {
/// Create a new InferenceSession
Expand Down
20 changes: 15 additions & 5 deletions crates/models/bloom/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ impl KnownModel for Bloom {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let input_len = input_tokens.len();
let session_len = session.n_past;
let ctx_size = self.params.context_size;

let Hyperparameters {
Expand All @@ -133,6 +131,8 @@ impl KnownModel for Bloom {
} = self.hyperparameters;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let session_len = builder.n_past;
let input_len = builder.input_length();
let ctx0 = builder.ctx0.borrow();
let (memory_k_size, memory_v_size) = (
builder.memory_k.element_size(),
Expand Down Expand Up @@ -337,9 +337,19 @@ impl KnownModel for Bloom {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, input_len);
common::extract_logits(output_request, &outputs.result, n_vocab, input_len);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down
23 changes: 17 additions & 6 deletions crates/models/falcon/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,6 @@ impl KnownModel for Falcon {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let input_len = input_tokens.len();
let session_len = session.n_past;
let ctx_size = self.params.context_size;

let Hyperparameters {
Expand All @@ -170,9 +168,12 @@ impl KnownModel for Falcon {
} = self.hyperparameters;

let head_dim = n_embd / n_head;
let n = input_len;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let input_len = builder.input_length();
let n = input_len;
let session_len = builder.n_past;

let mut ctx0 = builder.ctx0.borrow_mut();
let embd = builder.embd;
let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, embd);
Expand Down Expand Up @@ -358,9 +359,19 @@ impl KnownModel for Falcon {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, input_len);
common::extract_logits(output_request, &outputs.result, n_vocab, input_len);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down
20 changes: 15 additions & 5 deletions crates/models/gpt2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,6 @@ impl KnownModel for Gpt2 {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let input_len = input_tokens.len();
let session_len = session.n_past;
let ctx_size = self.params.context_size;

let Hyperparameters {
Expand All @@ -154,6 +152,8 @@ impl KnownModel for Gpt2 {
} = self.hyperparameters;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let input_len = builder.input_length();
let session_len = builder.n_past;
let mut ctx0 = builder.ctx0.borrow_mut();
let (memory_k_size, memory_v_size) = (
builder.memory_k.element_size(),
Expand Down Expand Up @@ -325,9 +325,19 @@ impl KnownModel for Gpt2 {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, input_len);
common::extract_logits(output_request, &outputs.result, n_vocab, input_len);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down
21 changes: 16 additions & 5 deletions crates/models/gptj/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,6 @@ impl KnownModel for GptJ {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let input_len = input_tokens.len();
let session_len = session.n_past;
let ctx_size = self.params.context_size;

let Hyperparameters {
Expand All @@ -151,6 +149,9 @@ impl KnownModel for GptJ {
} = self.hyperparameters;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let input_len = builder.input_length();
let session_len = builder.n_past;

let mut ctx0 = builder.ctx0.borrow_mut();
let (memory_k_size, memory_v_size) = (
builder.memory_k.element_size(),
Expand Down Expand Up @@ -306,9 +307,19 @@ impl KnownModel for GptJ {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, input_len);
common::extract_logits(output_request, &outputs.result, n_vocab, input_len);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, input_len);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down
21 changes: 16 additions & 5 deletions crates/models/gptneox/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,6 @@ impl KnownModel for GptNeoX {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let n = input_tokens.len();
let n_past = session.n_past;
let n_ctx = self.params.context_size;

let Hyperparameters {
Expand All @@ -174,6 +172,9 @@ impl KnownModel for GptNeoX {
} = self.hyperparameters;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let n = builder.input_length();
let n_past = builder.n_past;

let mut ctx0 = builder.ctx0.borrow_mut();
let embd = builder.embd;
let mut input_layer = ctx0.op_get_rows(&self.wte, embd);
Expand Down Expand Up @@ -343,9 +344,19 @@ impl KnownModel for GptNeoX {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, n);
common::extract_logits(output_request, &outputs.result, n_vocab, n);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down
2 changes: 1 addition & 1 deletion crates/models/llama/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ impl KnownModel for Llama {

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let session_len = builder.n_past;
let input_len = builder.embd.nelements();
let input_len = builder.input_length();

let mut ctx0 = builder.ctx0.borrow_mut();

Expand Down
20 changes: 15 additions & 5 deletions crates/models/mpt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,6 @@ impl KnownModel for Mpt {
input_tokens: &[TokenId],
output_request: &mut OutputRequest,
) {
let n = input_tokens.len();
let session_len = session.n_past;
let ctx_size = self.params.context_size;

let Hyperparameters {
Expand All @@ -110,6 +108,8 @@ impl KnownModel for Mpt {
} = self.hyperparameters;

let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
let n = builder.input_length();
let session_len = builder.n_past;
let ctx0 = builder.ctx0.borrow();
let (memory_k_size, memory_v_size) = (
builder.memory_k.element_size(),
Expand Down Expand Up @@ -243,9 +243,19 @@ impl KnownModel for Mpt {
});

// finish evaluation
common::read_last_token(session, &outputs.result, n_vocab, n);
common::extract_logits(output_request, &outputs.result, n_vocab, n);
common::extract_embeddings(output_request, &outputs.embedding_result, n_embd, n);
common::read_last_token(session, &outputs.result, n_vocab, outputs.output_length);
common::extract_logits(
output_request,
&outputs.result,
n_vocab,
outputs.output_length,
);
common::extract_embeddings(
output_request,
&outputs.embedding_result,
n_embd,
outputs.output_length,
);
}

fn hyperparameters(&self) -> &Self::Hyperparameters {
Expand Down

0 comments on commit 78b0e25

Please sign in to comment.