Skip to content

Commit 5de796b

Browse files
committed
update mscale calculation to keep backward compatibility with previous Phi models
1 parent 86f1682 commit 5de796b

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

vllm/model_executor/layers/rotary_embedding.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,8 @@ def __init__(
503503
dtype: torch.dtype,
504504
short_factor: List[float],
505505
long_factor: List[float],
506-
short_mscale: float = 1.0,
507-
long_mscale: float = 1.0,
506+
short_mscale: float = None,
507+
long_mscale: float = None,
508508
):
509509
super().__init__()
510510

@@ -523,6 +523,17 @@ def __init__(
523523
self.base = base
524524
self.short_factor = short_factor
525525
self.long_factor = long_factor
526+
527+
scale = self.max_position_embeddings / self.original_max_position_embeddings
528+
if scale <= 1.0:
529+
scaling_factor = 1.0
530+
else:
531+
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
532+
if short_mscale is None:
533+
short_mscale = scaling_factor
534+
if long_mscale is None:
535+
long_mscale = scaling_factor
536+
526537
self.short_mscale = short_mscale
527538
self.long_mscale = long_mscale
528539

0 commit comments

Comments
 (0)