Skip to content

Commit 43fe3b0

Browse files
committed
Fix offsets dtype bugs
1 parent ad36d47 commit 43fe3b0

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

benchmarks/benchmark_training_throughput.py

+15 -4
Original file line number | Diff line number | Diff line change
@@ -29,17 +29,23 @@ def sizeof_fmt(num, suffix='B'):
2929
def prepare_inputs(
3030
batch_size: int,
3131
seq_len: int,
32+
context_len: int,
3233
varlen: bool,
3334
vocab_size: int,
3435
device: torch.device
3536
):
3637
if varlen:
3738
tokens = torch.randint(high=vocab_size, size=(1, batch_size * seq_len), device=device)
3839
offsets = torch.cat([
39-
torch.tensor([0], dtype=torch.long, device=device),
40-
torch.randperm(batch_size * seq_len - 16, device=device)[:batch_size-1] + 16,
41-
torch.tensor([batch_size * seq_len], dtype=torch.long, device=device)
42-
], 0).sort()[0]
40+
torch.tensor([0]),
41+
torch.randperm(batch_size * seq_len - 16)[:torch.randint(8, 64, size=(1,))] + 16,
42+
torch.tensor([batch_size * seq_len])
43+
], 0).sort()[0].to(dtype=torch.int32, device=device)
44+
if context_len is not None:
45+
offsets = torch.cat(
46+
[torch.arange(i, j, context_len) for i, j in zip(offsets[:-1].tolist(), offsets[1:].tolist())] +
47+
[torch.tensor([len(tokens[0])])]
48+
).to(dtype=torch.int32, device=device)
4349
else:
4450
tokens = torch.randint(high=vocab_size, size=(batch_size, seq_len), device=device)
4551
offsets = None
@@ -50,6 +56,7 @@ def profile(
5056
name: str,
5157
batch_size: int = 8,
5258
seq_len: int = 2048,
59+
context_len: int = 2048,
5360
varlen: bool = False,
5461
warmup_steps: int = 16,
5562
steps: int = 32,
@@ -87,6 +94,7 @@ def profile(
8794
tokens, offsets = prepare_inputs(
8895
batch_size=batch_size,
8996
seq_len=seq_len,
97+
context_len=context_len,
9098
varlen=varlen,
9199
vocab_size=config.vocab_size,
92100
device=device
@@ -107,6 +115,7 @@ def profile(
107115
tokens, offsets = prepare_inputs(
108116
batch_size=batch_size,
109117
seq_len=seq_len,
118+
context_len=context_len,
110119
varlen=varlen,
111120
vocab_size=config.vocab_size,
112121
device=device
@@ -128,6 +137,7 @@ def profile(
128137
parser.add_argument("--name", default='retnet')
129138
parser.add_argument("--batch_size", default=8, type=int)
130139
parser.add_argument("--seq_len", default=2048, type=int)
140+
parser.add_argument("--context_len", default=None, type=int)
131141
parser.add_argument("--varlen", action='store_true')
132142
parser.add_argument("--warmup_steps", default=16, type=int)
133143
parser.add_argument("--steps", default=32, type=int)
@@ -136,6 +146,7 @@ def profile(
136146
name=args.name,
137147
batch_size=args.batch_size,
138148
seq_len=args.seq_len,
149+
context_len=args.context_len,
139150
varlen=args.varlen,
140151
warmup_steps=args.warmup_steps,
141152
steps=args.steps

0 commit comments

Comments
 (0)