@@ -17,7 +17,7 @@
 )

 from llama.model import ModelArgs, Transformer
-from llama.tokenizer import Dialog, Message, ChatFormat, Tokenizer
+from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer


 class CompletionPrediction(TypedDict, total=False):
@@ -43,7 +43,7 @@ def build(
         seed: int = 1,
     ) -> "Llama":
         """
-        Build a Llama instance by initializing and loading a pre-trained model.
+        Build a Llama instance by initializing and loading a model checkpoint.

         Args:
             ckpt_dir (str): Path to the directory containing checkpoint files.
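For context, `build` is the usual entry point for loading the model; a minimal invocation sketch, where the checkpoint paths and size limits below are placeholders rather than values from this commit:

    from llama import Llama

    # Placeholder paths and limits, for illustration only.
    generator = Llama.build(
        ckpt_dir="Meta-Llama-3-8B/",
        tokenizer_path="Meta-Llama-3-8B/tokenizer.model",
        max_seq_len=512,
        max_batch_size=4,
    )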
@@ -63,7 +63,6 @@ def build(
         Note:
             This method initializes the distributed process group, sets the device to CUDA,
             and loads the pre-trained model and tokenizer.
-
         """
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -99,7 +98,10 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        if torch.cuda.is_bf16_supported():
+            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        else:
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,8 +214,8 @@ def generate(
             for stop_token in self.tokenizer.stop_tokens:
                 try:
                     eos_idx = toks.index(stop_token)
-                    toks = toks[: eos_idx]
-                    probs = probs[: eos_idx] if logprobs else None
+                    toks = toks[:eos_idx]
+                    probs = probs[:eos_idx] if logprobs else None
                 except ValueError:
                     pass
             out_tokens.append(toks)
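The slicing change in this hunk is cosmetic, but the surrounding logic (cut each output at the first stop token found) can be exercised in isolation; the token ids below are made up for illustration:

    # Hypothetical ids, chosen only to illustrate the truncation.
    stop_tokens = [128001, 128009]
    toks = [9906, 11, 1917, 0, 128009, 220, 1234]

    for stop_token in stop_tokens:
        try:
            eos_idx = toks.index(stop_token)
            toks = toks[:eos_idx]  # drop the stop token and everything after
        except ValueError:
            pass                   # this stop token did not occur
    print(toks)  # [9906, 11, 1917, 0]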
@@ -293,22 +295,16 @@ def chat_completion(
         Returns:
             List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.

-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
         Note:
             This method generates assistant responses for the provided conversational dialogs.
             It employs nucleus sampling to introduce controlled randomness in text generation.
             If logprobs is True, token log probabilities are computed for each generated token.
-
         """
         if max_gen_len is None:
             max_gen_len = self.model.params.max_seq_len - 1

         prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog)
-            for dialog in dialogs
+            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
         ]
         generation_tokens, generation_logprobs = self.generate(
             prompt_tokens=prompt_tokens,
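For reference, each `Dialog` passed in here is a list of `Message` dicts with `role` and `content` keys; a hedged usage sketch, assuming a `generator` built as in the earlier example:

    # Assumes `generator` was created via Llama.build (see earlier sketch).
    dialogs = [
        [
            {"role": "system", "content": "Answer in one sentence."},
            {"role": "user", "content": "What is nucleus sampling?"},
        ],
    ]
    results = generator.chat_completion(dialogs, temperature=0.6, top_p=0.9)
    for result in results:
        print(result["generation"]["content"])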
@@ -354,7 +350,6 @@ def sample_top_p(probs, p):
     Note:
         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
-
     """
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
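The hunk cuts off before the rest of `sample_top_p`; a sketch of a complete top-p sampler along these lines (it may not match the file byte for byte):

    import torch

    def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
        """Nucleus sampling: draw from the smallest token set whose
        cumulative probability exceeds p, after renormalizing."""
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        # Mask tokens once the cumulative mass *before* them exceeds p.
        mask = probs_sum - probs_sort > p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # Map sorted positions back to original vocabulary ids.
        return torch.gather(probs_idx, -1, next_token)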