
Commit 32b7ddc

rope theta + nits
1 parent 3afbe13 commit 32b7ddc

5 files changed: +59 -250 lines

example_chat_completion.py (+8 -12)
@@ -18,19 +18,15 @@ def main(
     max_gen_len: Optional[int] = None,
 ):
     """
-    Entry point of the program for generating text using a pretrained model.
+    Examples to run with the models finetuned for chat. Prompts correspond of chat
+    turns between the user and assistant with the final one always being the user.
 
-    Args:
-        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
-        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
-        temperature (float, optional): The temperature value for controlling randomness in generation.
-            Defaults to 0.6.
-        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
-            Defaults to 0.9.
-        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 512.
-        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 8.
-        max_gen_len (int, optional): The maximum length of generated sequences. If None, it will be
-            set to the model's max sequence length. Defaults to None.
+    An optional system prompt at the beginning to control how the model should respond
+    is also supported.
+
+    The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
+
+    `max_gen_len` is optional because finetuned models are able to stop generations naturally.
     """
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
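
For orientation, here is a minimal sketch of what the updated docstring describes: a dialog with an optional leading system turn, always ending on a user turn, passed through Llama.build and chat_completion. The checkpoint and tokenizer paths are placeholders, and the exact Dialog/Message shape is assumed from the repository's public API.

# Minimal usage sketch (paths are placeholders; API assumed from this repo).
from llama import Llama

generator = Llama.build(
    ckpt_dir="path/to/Meta-Llama-3-8B-Instruct/",   # placeholder checkpoint dir
    tokenizer_path="path/to/tokenizer.model",       # placeholder tokenizer path
    max_seq_len=8192,   # must stay <= the 8192-token context window
    max_batch_size=4,
)

dialogs = [
    [
        {"role": "system", "content": "Always answer in one sentence."},   # optional system prompt
        {"role": "user", "content": "What is the capital of France?"},     # final turn is the user's
    ]
]

# max_gen_len is omitted: chat-finetuned models stop generation naturally.
results = generator.chat_completion(dialogs, temperature=0.6, top_p=0.9)
for result in results:
    print(result["generation"]["content"])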

example_text_completion.py (+4 -11)
@@ -18,18 +18,11 @@ def main(
     max_batch_size: int = 4,
 ):
     """
-    Entry point of the program for generating text using a pretrained model.
+    Examples to run with the pre-trained models (no fine-tuning). Prompts are
+    usually in the form of an incomplete text prefix that the model can then try to complete.
 
-    Args:
-        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
-        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
-        temperature (float, optional): The temperature value for controlling randomness in generation.
-            Defaults to 0.6.
-        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
-            Defaults to 0.9.
-        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
-        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
-        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+    The context window of llama3 models is 8192 tokens, so `max_seq_len` needs to be <= 8192.
+    `max_gen_len` is needed because pre-trained models usually do not stop completions naturally.
     """
     generator = Llama.build(
         ckpt_dir=ckpt_dir,
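
And the corresponding prefix-completion usage the text-completion docstring describes, again as a hedged sketch: paths are placeholders, the API is assumed from the repository, and the prompts are just illustrative prefixes.

# Minimal usage sketch (paths are placeholders; API assumed from this repo).
from llama import Llama

generator = Llama.build(
    ckpt_dir="path/to/Meta-Llama-3-8B/",        # placeholder base-model checkpoint dir
    tokenizer_path="path/to/tokenizer.model",   # placeholder tokenizer path
    max_seq_len=128,
    max_batch_size=4,
)

# Prompts are incomplete text prefixes for the base model to continue.
prompts = [
    "I believe the meaning of life is",
    "Translate English to French:\n\n  sea otter => loutre de mer\n  cheese =>",
]

# max_gen_len is required: pre-trained models usually do not stop on their own.
results = generator.text_completion(prompts, max_gen_len=64, temperature=0.6, top_p=0.9)
for prompt, result in zip(prompts, results):
    print(prompt + result["generation"])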

llama/generation.py (+9 -14)
@@ -17,7 +17,7 @@
 )
 
 from llama.model import ModelArgs, Transformer
-from llama.tokenizer import Dialog, Message, ChatFormat, Tokenizer
+from llama.tokenizer import ChatFormat, Dialog, Message, Tokenizer
 
 
 class CompletionPrediction(TypedDict, total=False):
@@ -43,7 +43,7 @@ def build(
         seed: int = 1,
     ) -> "Llama":
         """
-        Build a Llama instance by initializing and loading a pre-trained model.
+        Build a Llama instance by initializing and loading a model checkpoint.
 
         Args:
             ckpt_dir (str): Path to the directory containing checkpoint files.
@@ -63,7 +63,6 @@ def build(
         Note:
            This method initializes the distributed process group, sets the device to CUDA,
            and loads the pre-trained model and tokenizer.
-
         """
         if not torch.distributed.is_initialized():
             torch.distributed.init_process_group("nccl")
@@ -99,7 +98,10 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         assert model_args.vocab_size == tokenizer.n_words
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        if torch.cuda.is_bf16_supported():
+            torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
+        else:
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -212,8 +214,8 @@ def generate(
             for stop_token in self.tokenizer.stop_tokens:
                 try:
                     eos_idx = toks.index(stop_token)
-                    toks = toks[: eos_idx]
-                    probs = probs[: eos_idx] if logprobs else None
+                    toks = toks[:eos_idx]
+                    probs = probs[:eos_idx] if logprobs else None
                 except ValueError:
                     pass
             out_tokens.append(toks)
@@ -293,22 +295,16 @@ def chat_completion(
         Returns:
             List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
 
-        Raises:
-            AssertionError: If the last message in a dialog is not from the user.
-            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
-
         Note:
             This method generates assistant responses for the provided conversational dialogs.
             It employs nucleus sampling to introduce controlled randomness in text generation.
             If logprobs is True, token log probabilities are computed for each generated token.
-
         """
         if max_gen_len is None:
             max_gen_len = self.model.params.max_seq_len - 1
 
         prompt_tokens = [
-            self.formatter.encode_dialog_prompt(dialog)
-            for dialog in dialogs
+            self.formatter.encode_dialog_prompt(dialog) for dialog in dialogs
         ]
         generation_tokens, generation_logprobs = self.generate(
             prompt_tokens=prompt_tokens,
@@ -354,7 +350,6 @@ def sample_top_p(probs, p):
     Note:
         Top-p sampling selects the smallest set of tokens whose cumulative probability mass
         exceeds the threshold p. The distribution is renormalized based on the selected tokens.
-
     """
     probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
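
The note in that last hunk fully determines the algorithm. For reference, a minimal sketch of how the rest of sample_top_p can be completed from that description (illustrative; the in-repo implementation may differ in detail):

import torch

def sample_top_p(probs, p):
    # Sort token probabilities in descending order and take the running sum.
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens once the cumulative mass *before* them already exceeds p,
    # i.e. keep the smallest prefix whose mass passes the threshold.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalize over the surviving tokens and sample one index per row.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled position back to the original vocabulary index.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token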
