From 7a19b5f0da04a8c1c436e877e1bb873acf03d45c Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 20 Dec 2024 16:06:22 -0500 Subject: [PATCH] Accept function pointer or closure for freq scaling (#2634) --- crates/burn-core/src/nn/rope_encoding.rs | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/crates/burn-core/src/nn/rope_encoding.rs b/crates/burn-core/src/nn/rope_encoding.rs index 65e0f8605c..578bdb003a 100644 --- a/crates/burn-core/src/nn/rope_encoding.rs +++ b/crates/burn-core/src/nn/rope_encoding.rs @@ -31,7 +31,7 @@ impl RotaryEncodingConfig { /// Panics if the size of input embedding dimension is not even. /// Panics if the theta parameter is not positive. pub fn init(&self, device: &B::Device) -> RotaryEncoding { - self.initialize(None, device) + self.initialize(|x| x, device) } /// Initialize a new [RotaryEncoding](RotaryEncoding) module with a custom frequency scaling function. @@ -43,10 +43,10 @@ impl RotaryEncodingConfig { /// Panics if the theta parameter is not positive. pub fn init_with_frequency_scaling( &self, - scaling: fn(Tensor) -> Tensor, + scaling: impl Fn(Tensor) -> Tensor, device: &B::Device, ) -> RotaryEncoding { - self.initialize(Some(scaling), device) + self.initialize(scaling, device) } /// Initialize a new [RotaryEncoding](RotaryEncoding) module. @@ -57,7 +57,7 @@ impl RotaryEncodingConfig { /// Panics if the theta parameter is not positive. fn initialize( &self, - scaling: Option) -> Tensor>, + scaling: impl Fn(Tensor) -> Tensor, device: &B::Device, ) -> RotaryEncoding { assert_eq!( @@ -79,11 +79,9 @@ impl RotaryEncodingConfig { // Calculate (10000 ^ (2i / d_model)) by using the log base property `exp(log(10000) * (2i / d_model))` // This is done since burn doesn't support exponentiation of scalar to tensor let theta_i = exponent.mul_scalar(self.theta.ln()).exp(); - let mut theta_i = theta_i.powf_scalar(-1.0); + let theta_i = theta_i.powf_scalar(-1.0); - if let Some(scaling) = scaling { - theta_i = scaling(theta_i) - } + let theta_i = scaling(theta_i); // Generate frequency values for positional embeddings let frequencies: Tensor =