From ad5ca0e1cda1313a97be3889bf8fcc77fc9f5913 Mon Sep 17 00:00:00 2001
From: charlifu
Date: Fri, 11 Oct 2024 17:41:58 +0000
Subject: [PATCH] linting

---
 vllm/model_executor/models/dbrx.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 41343cd1b6e9d..77ebef8eda51d 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -18,7 +18,8 @@
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import (default_weight_loader, maybe_remap_kv_scale_name)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -106,9 +107,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                     loaded_weight,
                     [-1, self.intermediate_size * self.tp_size, self.d_model],
                 )
-                param_data[:,
-                           shard_size:2 * shard_size, :] = loaded_weight[:,
-                                                                         shard, :]
+                param_data[:, shard_size:2 *
+                           shard_size, :] = loaded_weight[:, shard, :]
             elif param_name.endswith("weight_scale"):
                 param_data[:, 1] = loaded_weight
             else: