From 646f297fd3789597a58c22d0774e3a1a3b32b73b Mon Sep 17 00:00:00 2001 From: "zehan.song" Date: Wed, 31 Jul 2024 16:36:57 +0800 Subject: [PATCH] support musa backend for MooreThreads --- fastchat/model/model_adapter.py | 9 ++++++++- fastchat/serve/model_worker.py | 15 ++++++++++++--- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py index 135c108854..ca6cd0ae4e 100644 --- a/fastchat/model/model_adapter.py +++ b/fastchat/model/model_adapter.py @@ -271,6 +271,12 @@ def load_model( import torch_npu except ImportError: warnings.warn("Ascend Extension for PyTorch is not installed.") + elif device == "musa": + kwargs = {"torch_dtype": torch.float16} + try: + import torch_musa + except ImportError: + warnings.warn("Musa Extension for PyTorch is not installed.") else: raise ValueError(f"Invalid device: {device}") @@ -377,6 +383,7 @@ def load_model( "mps", "xpu", "npu", + "musa", ): model.to(device) @@ -495,7 +502,7 @@ def add_model_args(parser): parser.add_argument( "--device", type=str, - choices=["cpu", "cuda", "mps", "xpu", "npu"], + choices=["cpu", "cuda", "mps", "xpu", "npu", "musa"], default="cuda", help="The device type", ) diff --git a/fastchat/serve/model_worker.py b/fastchat/serve/model_worker.py index 683a78556d..8c70e412aa 100644 --- a/fastchat/serve/model_worker.py +++ b/fastchat/serve/model_worker.py @@ -89,6 +89,7 @@ def __init__( xft_config=xft_config, debug=debug, ) + self.model_path = model_path self.device = device if self.tokenizer.pad_token == None: self.tokenizer.pad_token = self.tokenizer.eos_token @@ -165,6 +166,9 @@ def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict): else: data = model_output.hidden_states[-1] + #import pdb + #pdb.set_trace() + #sum_embeddings = data[:, 0] if hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling: sum_embeddings = data[:, 0] else: @@ -245,9 +249,10 @@ def get_embeddings(self, params): ) + tokenizer.cls_token_id ) - chunk_input_ids = torch.cat( - [cls_tokens, chunk_input_ids], dim=-1 - ) + if "bge-m3" not in self.model_path.lower(): + chunk_input_ids = torch.cat( + [cls_tokens, chunk_input_ids], dim=-1 + ) mask = torch.ones( (chunk_attention_mask.size(0), 1), dtype=chunk_attention_mask.dtype, @@ -260,6 +265,8 @@ def get_embeddings(self, params): chunk_embeddings, token_num = self.__process_embed_chunk( chunk_input_ids, chunk_attention_mask, **model_type_dict ) + import pdb + pdb.set_trace() if ( hasattr(self.model, "use_cls_pooling") and self.model.use_cls_pooling @@ -271,6 +278,8 @@ def get_embeddings(self, params): all_embeddings_tensor = torch.stack(all_embeddings) embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num + import pdb + pdb.set_trace() normalized_embeddings = F.normalize(embedding, p=2, dim=1) ret["token_num"] = all_token_num