support convert embeddings to bf16
irexyc committed Dec 8, 2023
1 parent 1753ead commit a6c4977
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lmdeploy/turbomind/turbomind.py
@@ -558,10 +558,14 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             embedding_ends = [embedding_ends]
             embeddings = [embeddings]
             # convert to lookup table type
-            # TODO bf16
             if self.tm_model.config.weight_type == 'fp32':
                 embeddings = [[x.astype(np.float32) for x in y]
                               for y in embeddings]
+            elif self.tm_model.config.weight_type == 'bf16':
+                embeddings = [[
+                    torch.from_numpy(x).bfloat16().view(torch.half).numpy()
+                    for x in y
+                ] for y in embeddings]
             else:
                 embeddings = [[x.astype(np.float16) for x in y]
                               for y in embeddings]
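
The new bf16 branch works around the fact that numpy has no bfloat16 dtype: torch performs the rounding to bf16, and `.view(torch.half)` reinterprets the same 16-bit payload as float16 so it can travel through numpy unchanged; the engine later reads those bits back as bf16. Below is a small illustrative sketch of that round trip, not part of the commit; the helper name `to_bf16_lookup` and the tolerance are assumptions for demonstration only.

import numpy as np
import torch

def to_bf16_lookup(x: np.ndarray) -> np.ndarray:
    """Round a float array to bfloat16 and return the raw 16-bit payload
    packed in a float16-typed numpy array (numpy has no bf16 dtype)."""
    # torch does the fp32 -> bf16 rounding; .view(torch.half) reinterprets
    # the same 16 bits as float16 so numpy can carry them.
    return torch.from_numpy(x).bfloat16().view(torch.half).numpy()

x = np.random.randn(4, 8).astype(np.float32)
packed = to_bf16_lookup(x)
print(packed.dtype, packed.shape)  # float16 (bit container only), (4, 8)

# Reinterpreting the payload back as bf16 recovers the rounded values.
recovered = torch.from_numpy(packed).view(torch.bfloat16).float().numpy()
print(np.allclose(x, recovered, atol=0.05))  # True up to bf16 rounding error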
