@@ -29,17 +29,23 @@ def sizeof_fmt(num, suffix='B'):
29
29
def prepare_inputs (
30
30
batch_size : int ,
31
31
seq_len : int ,
32
+ context_len : int ,
32
33
varlen : bool ,
33
34
vocab_size : int ,
34
35
device : torch .device
35
36
):
36
37
if varlen :
37
38
tokens = torch .randint (high = vocab_size , size = (1 , batch_size * seq_len ), device = device )
38
39
offsets = torch .cat ([
39
- torch .tensor ([0 ], dtype = torch .long , device = device ),
40
- torch .randperm (batch_size * seq_len - 16 , device = device )[:batch_size - 1 ] + 16 ,
41
- torch .tensor ([batch_size * seq_len ], dtype = torch .long , device = device )
42
- ], 0 ).sort ()[0 ]
40
+ torch .tensor ([0 ]),
41
+ torch .randperm (batch_size * seq_len - 16 )[:torch .randint (8 , 64 , size = (1 ,))] + 16 ,
42
+ torch .tensor ([batch_size * seq_len ])
43
+ ], 0 ).sort ()[0 ].to (dtype = torch .int32 , device = device )
44
+ if context_len is not None :
45
+ offsets = torch .cat (
46
+ [torch .arange (i , j , context_len ) for i , j in zip (offsets [:- 1 ].tolist (), offsets [1 :].tolist ())] +
47
+ [torch .tensor ([len (tokens [0 ])])]
48
+ ).to (dtype = torch .int32 , device = device )
43
49
else :
44
50
tokens = torch .randint (high = vocab_size , size = (batch_size , seq_len ), device = device )
45
51
offsets = None
@@ -50,6 +56,7 @@ def profile(
50
56
name : str ,
51
57
batch_size : int = 8 ,
52
58
seq_len : int = 2048 ,
59
+ context_len : int = 2048 ,
53
60
varlen : bool = False ,
54
61
warmup_steps : int = 16 ,
55
62
steps : int = 32 ,
@@ -87,6 +94,7 @@ def profile(
87
94
tokens , offsets = prepare_inputs (
88
95
batch_size = batch_size ,
89
96
seq_len = seq_len ,
97
+ context_len = context_len ,
90
98
varlen = varlen ,
91
99
vocab_size = config .vocab_size ,
92
100
device = device
@@ -107,6 +115,7 @@ def profile(
107
115
tokens , offsets = prepare_inputs (
108
116
batch_size = batch_size ,
109
117
seq_len = seq_len ,
118
+ context_len = context_len ,
110
119
varlen = varlen ,
111
120
vocab_size = config .vocab_size ,
112
121
device = device
@@ -128,6 +137,7 @@ def profile(
128
137
parser .add_argument ("--name" , default = 'retnet' )
129
138
parser .add_argument ("--batch_size" , default = 8 , type = int )
130
139
parser .add_argument ("--seq_len" , default = 2048 , type = int )
140
+ parser .add_argument ("--context_len" , default = None , type = int )
131
141
parser .add_argument ("--varlen" , action = 'store_true' )
132
142
parser .add_argument ("--warmup_steps" , default = 16 , type = int )
133
143
parser .add_argument ("--steps" , default = 32 , type = int )
@@ -136,6 +146,7 @@ def profile(
136
146
name = args .name ,
137
147
batch_size = args .batch_size ,
138
148
seq_len = args .seq_len ,
149
+ context_len = args .context_len ,
139
150
varlen = args .varlen ,
140
151
warmup_steps = args .warmup_steps ,
141
152
steps = args .steps
0 commit comments