Commit 5de796b (1 parent: 86f1682)
1 file changed: +13 −2 lines
vllm/model_executor/layers

@@ -503,8 +503,8 @@ def __init__(
         dtype: torch.dtype,
         short_factor: List[float],
         long_factor: List[float],
-        short_mscale: float = 1.0,
-        long_mscale: float = 1.0,
+        short_mscale: float = None,
+        long_mscale: float = None,
     ):
         super().__init__()

@@ -523,6 +523,17 @@ def __init__(
         self.base = base
         self.short_factor = short_factor
         self.long_factor = long_factor
+
+        scale = self.max_position_embeddings / self.original_max_position_embeddings
+        if scale <= 1.0:
+            scaling_factor = 1.0
+        else:
+            scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
+        if short_mscale is None:
+            short_mscale = scaling_factor
+        if long_mscale is None:
+            long_mscale = scaling_factor
+

         self.short_mscale = short_mscale
         self.long_mscale = long_mscale
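With this change, short_mscale and long_mscale default to None and, when left unset, are derived from the ratio of the extended to the original context length; passing an explicit value still overrides the computed one. Below is a minimal sketch of the same computation outside the class, using assumed Phi-3-style config values (4096 original context, 131072 extended context) that are not taken from this commit:

import math

# Assumed example values; a real model would read these from its config.
original_max_position_embeddings = 4096
max_position_embeddings = 131072

scale = max_position_embeddings / original_max_position_embeddings  # 32.0
if scale <= 1.0:
    scaling_factor = 1.0
else:
    # Same formula the commit adds: sqrt(1 + log(scale) / log(original context length))
    scaling_factor = math.sqrt(
        1 + math.log(scale) / math.log(original_max_position_embeddings))

print(scaling_factor)  # ~1.19, used for both short_mscale and long_mscale when they are None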