Skip to content

Commit

Permalink
Accept function pointer or closure for freq scaling (#2634)
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui authored Dec 20, 2024
1 parent 06fdb9f commit 7a19b5f
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions crates/burn-core/src/nn/rope_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl RotaryEncodingConfig {
/// Panics if the size of input embedding dimension is not even.
/// Panics if the theta parameter is not positive.
pub fn init<B: Backend>(&self, device: &B::Device) -> RotaryEncoding<B> {
self.initialize(None, device)
self.initialize(|x| x, device)
}

/// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function.
Expand All @@ -43,10 +43,10 @@ impl RotaryEncodingConfig {
/// Panics if the theta parameter is not positive.
pub fn init_with_frequency_scaling<B: Backend>(
&self,
scaling: fn(Tensor<B, 1>) -> Tensor<B, 1>,
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
device: &B::Device,
) -> RotaryEncoding<B> {
self.initialize(Some(scaling), device)
self.initialize(scaling, device)
}

/// Initialize a new [RotaryEncoding](RotaryEncoding) module.
Expand All @@ -57,7 +57,7 @@ impl RotaryEncodingConfig {
/// Panics if the theta parameter is not positive.
fn initialize<B: Backend>(
&self,
scaling: Option<fn(Tensor<B, 1>) -> Tensor<B, 1>>,
scaling: impl Fn(Tensor<B, 1>) -> Tensor<B, 1>,
device: &B::Device,
) -> RotaryEncoding<B> {
assert_eq!(
Expand All @@ -79,11 +79,9 @@ impl RotaryEncodingConfig {
// Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))`
// This is done since burn doesn't support exponentiation of scalar to tensor
let theta_i = exponent.mul_scalar(self.theta.ln()).exp();
let mut theta_i = theta_i.powf_scalar(-1.0);
let theta_i = theta_i.powf_scalar(-1.0);

if let Some(scaling) = scaling {
theta_i = scaling(theta_i)
}
let theta_i = scaling(theta_i);

// Generate frequency values for positional embeddings
let frequencies: Tensor<B, 2> =
Expand Down

0 comments on commit 7a19b5f

Please sign in to comment.