
Commit a60b7d2

[sharktank] Fix numerics for perplexity with vmfb (#436)
Fix the cache update to resolve the numerics issue: update the cache directly via IREE's DeviceArray.
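
The heart of the change is to materialize the KV cache as an IREE DeviceArray once per prompt and pass that same device buffer to every prefill/decode invocation, rather than handing the runtime a freshly converted torch tensor on each call. Below is a minimal sketch of that pattern using the public `iree.runtime` API; the driver name, cache shape, and dtype are illustrative placeholders, not the values sharktank uses:

```python
import iree.runtime as ireert
import torch

# Pick a HAL device; sharktank takes this from its vmfb runner
# (self.haldevice = self.runner.config.device in the diff below).
config = ireert.Config("local-task")  # e.g. a HIP driver on AMD GPUs
haldevice = config.device

# Build the cache on the host, then copy it to the device exactly once.
cache_state = torch.zeros(256, 4096, dtype=torch.float16)
cache_device = ireert.asdevicearray(haldevice, cache_state.to("cpu").numpy())

# `cache_device` is the single buffer handed to prefill_bs*/decode_bs*,
# so cache writes made by one invocation are visible to the next.
```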
1 parent: 445511a

File tree: 4 files changed, +143 −134 lines

sharktank/sharktank/evaluate/perplexity_vmfb.py (+40 −31)
```diff
@@ -73,8 +73,8 @@ def __init__(
         self.iree_hip_target = iree_hip_target
         self.iree_hal_target_backends = iree_hal_target_backends
         self.kv_cache_type = kv_cache_type
-        self.activation_dtype = torch.float32
-        self.attention_dtype = torch.float32
+        self.activation_dtype = torch.float16
+        self.attention_dtype = torch.float16
         self.tensor_parallelism_size = tensor_parallelism_size
         self.attention_kernel = attention_kernel

@@ -166,6 +166,8 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
             external_weight_path=self.weight_path_str,
         )

+        self.haldevice = self.runner.config.device
+
     @timeit
     def get_prompts(self):
         test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
@@ -189,40 +191,19 @@ def get_prompts(self):

     def prefill_vmfb(self, token_batch, i):

-        logger.debug(f"Prefill:")
-
-        logger.debug("Input:")
-        logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
-
-        token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
-            token_ids=token_batch.tolist(),
-            pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
-        )
-
-        logger.debug(f"{token_batch}")
-
-        token_batch = torch.tensor(token_batch, device=self.torch_device)
-        self.seq_lens_batch = torch.tensor(seq_lens_batch, device=self.torch_device)
-
-        self.batch = self.generator.begin_eval_batch(
-            token_batch=token_batch,
-            seq_lens_batch=self.seq_lens_batch,
-            bs=self.bs,
-        )
-
         seq_block_ids = self.batch.pad_block_ids()
         prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.cache_state,
         )

         prefill_logits = torch.tensor(prefill_logits[:, :, :])

         tokens = torch.tensor(
             self.generator.model.extract_tokens_from_logits(
-                prefill_logits, seq_lens_batch
+                prefill_logits, self.batch.seq_lens
             )
         ).unsqueeze(1)
         self.batch.add_result_token(tokens)
@@ -237,17 +218,17 @@ def decode_vmfb(self, token_batch, i):
         logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
         logger.debug(f"{token_batch.tolist()}")

-        start_positions = self.seq_lens_batch.clone()
-        self.seq_lens_batch.add_(1)
+        start_positions = self.batch.seq_lens.clone()
+        self.batch.seq_lens.add_(1)
         self.batch.allocate_seq_block_ids()
         seq_block_ids = self.batch.pad_block_ids()

         decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             start_positions,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.cache_state,
         )

         decode_logits = torch.tensor(decode_logits[:, :, :])
@@ -287,6 +268,7 @@ def get_logits(self):
         start = 0
         for i in tqdm(
             range(start, self.max_prompt_length - 1),
+            mininterval=300,
             desc="eval: Calculating logits",
         ):
             logger.debug(f"Iteration: {i}")
@@ -295,8 +277,35 @@ def get_logits(self):

                 token_batch = self.token_ids[:, : i + 1]

+                logger.debug(f"Prefill:")
+
+                logger.debug("Input:")
+                logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+                token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+                    token_ids=token_batch.tolist(),
+                    pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+                )
+
+                logger.debug(f"{token_batch}")
+
+                token_batch = torch.tensor(token_batch, device=self.torch_device)
+                self.seq_lens_batch = torch.tensor(
+                    seq_lens_batch, device=self.torch_device
+                )
+
+                self.batch = self.generator.begin_eval_batch(
+                    token_batch=token_batch,
+                    seq_lens_batch=self.seq_lens_batch,
+                    bs=self.bs,
+                )
+
+                self.cache_state = ireert.asdevicearray(
+                    self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
+                )
+
                 prefill_logits = self.prefill_vmfb(token_batch, i)
-                self.out_logits = prefill_logits[:, 0:1, :]
+                self.out_logits = prefill_logits[:, -1:, :]

                 is_first_token = False
```
sharktank/sharktank/utils/load_llm.py (+1 −1)
```diff
@@ -31,9 +31,9 @@ def __init__(
         self.tokenizer = tokenizer
         if model.cache.is_paged:
             self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+            self.free_pages = list(range(1, page_cache_size))
         else:
             self.shared_cache_state = None
-        self.free_pages = list(range(1, 8192))
         self.end_token = end_token

     @property
```
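
The `load_llm.py` fix sizes the free-page list to the cache that was actually allocated instead of a hardcoded 8192, so a paged KV cache can never hand out page ids beyond `page_cache_size`. A minimal sketch of the free-list idea; `allocate_page`/`release_page` are illustrative names, not the sharktank API:

```python
page_cache_size = 128

# Page ids start at 1; id 0 is excluded (presumably reserved, e.g. for padding).
free_pages = list(range(1, page_cache_size))

def allocate_page() -> int:
    # Hand out an unused cache page, failing loudly when the pool is empty.
    if not free_pages:
        raise RuntimeError("paged KV cache exhausted")
    return free_pages.pop()

def release_page(page_id: int) -> None:
    # Return a page to the pool once its sequence is done.
    free_pages.append(page_id)
```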

sharktank/tests/evaluate/baseline_perplexity_scores.json (+101 −101)
```diff
@@ -212,107 +212,107 @@
     },
     "llama3_8B_f16_decomposed_vmfb": {
         "perplexities": [
-            21194.505859,
-            19049.068359,
-            14214.751953,
-            15752.748047,
-            8948.568359,
-            9867.280273,
-            16664.880859,
-            10607.53125,
-            9715.395508,
-            14289.220703,
-            25121.929688,
-            8545.292969,
-            21990.28125,
-            8150.422363,
-            4658.82666,
-            13440.376953,
-            11978.756836,
-            9100.139648,
-            7168.022949,
-            14279.970703,
-            19406.207031,
-            13816.291016,
-            14942.27832,
-            20922.1875,
-            17307.214844,
-            10634.068359,
-            10968.188477,
-            11322.012695,
-            7898.733887,
-            7532.914062,
-            10352.375,
-            16628.289062,
-            5661.084473,
-            6998.464355,
-            7167.906738,
-            7252.662598,
-            7832.401367,
-            5824.921875,
-            12029.311523,
-            13104.125,
-            6688.567871,
-            7917.172852,
-            13455.291992,
-            7466.178223,
-            8360.422852,
-            5765.317383,
-            21530.652344,
-            13371.045898,
-            41826.242188,
-            13620.586914,
-            13886.725586,
-            13105.150391,
-            27155.019531,
-            8066.837402,
-            6860.444824,
-            9858.532227,
-            7352.963867,
-            15839.926758,
-            4746.95459,
-            8539.133789,
-            12957.833008,
-            10096.874023,
-            6436.333496,
-            6488.447754,
-            12649.62793,
-            9575.267578,
-            2897.279785,
-            12649.941406,
-            14139.443359,
-            12061.751953,
-            10646.621094,
-            15703.19043,
-            13080.764648,
-            9124.349609,
-            14409.989258,
-            10726.665039,
-            6444.680664,
-            10168.352539,
-            5474.356934,
-            10729.345703,
-            4240.486328,
-            11856.861328,
-            6184.834473,
-            16671.128906,
-            9840.30957,
-            39691.976562,
-            21551.833984,
-            6072.709961,
-            18333.572266,
-            6635.820801,
-            8460.941406,
-            14243.955078,
-            34157.90625,
-            9565.474609,
-            5573.206055,
-            9139.364258,
-            6077.837402,
-            13941.31543,
-            10590.963867,
-            12113.441406
+            6.651368,
+            22.059452,
+            15.392176,
+            17.418619,
+            15.206824,
+            7.907998,
+            8.829535,
+            22.355659,
+            8.29262,
+            20.958277,
+            7.167404,
+            14.592677,
+            9.060788,
+            7.274667,
+            16.238981,
+            6.666115,
+            6.535679,
+            7.086256,
+            10.676177,
+            8.979206,
+            10.597121,
+            42.038162,
+            11.70071,
+            65.731316,
+            47.42622,
+            20.109543,
+            18.897541,
+            13.781085,
+            9.99165,
+            5.955308,
+            10.175659,
+            23.628405,
+            14.306578,
+            9.719462,
+            5.594786,
+            14.198979,
+            5.711433,
+            17.381332,
+            9.058512,
+            8.286205,
+            8.016202,
+            18.4515,
+            11.600831,
+            3.945074,
+            13.000222,
+            10.373363,
+            12.237907,
+            21.408463,
+            37.858665,
+            25.794065,
+            15.489001,
+            14.004895,
+            7.625473,
+            10.993184,
+            14.698832,
+            11.062652,
+            5.855446,
+            15.625135,
+            8.052419,
+            14.365479,
+            5.927001,
+            6.931933,
+            2.3014,
+            15.769623,
+            40.843319,
+            8.022024,
+            12.544907,
+            10.090073,
+            9.304819,
+            10.679907,
+            8.136175,
+            21.540607,
+            3.736973,
+            15.381804,
+            24.21562,
+            14.385005,
+            17.791706,
+            16.498833,
+            8.753955,
+            12.941816,
+            12.887664,
+            13.725715,
+            13.994792,
+            10.769128,
+            14.734674,
+            26.970015,
+            17.811842,
+            9.847188,
+            15.124973,
+            15.623392,
+            29.147844,
+            12.309229,
+            32.15152,
+            33.225769,
+            14.426914,
+            17.496277,
+            14.7356,
+            15.503921,
+            12.336852,
+            16.469248
         ],
-        "mean_perplexity": 12191.57833
+        "mean_perplexity": 14.991893
     }
 }
```
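
For context on the baseline change: perplexity here is, assuming the standard definition, the exponentiated mean negative log-likelihood of the evaluated tokens,

$$\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N}\log p_\theta(x_i \mid x_{<i})\right)$$

so old values in the tens of thousands meant the model was assigning near-random probability to wikitext, consistent with decode reading a stale or mangled cache. The new baselines in the roughly 3–66 range (mean ≈ 15) are plausible for an fp16 8B llama on this dataset.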

sharktank/tests/evaluate/perplexity_vmfb_test.py (+1 −1)
```diff
@@ -22,7 +22,7 @@
 class PerplexityTest(unittest.TestCase):
     def setUp(self):
         self.current_perplexity_all = {}
-        self.delta = 10
+        self.delta = 5e-1
         self.tensor_parallelism_size = 8
         with open(self.baseline_perplexity_scores, "r") as f:
             self.baseline_perplexity = json.load(f)
```
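
With credible baselines in place, the test now allows the mean perplexity to drift by at most 0.5 from the recorded value. A minimal sketch of how such a check can look, assuming a plain `unittest` tolerance comparison; the actual assertion in `perplexity_vmfb_test.py` may differ:

```python
import unittest

class ToleranceExample(unittest.TestCase):
    def test_mean_perplexity_within_delta(self):
        baseline_mean = 14.991893  # from baseline_perplexity_scores.json
        current_mean = 15.1        # hypothetical result of a fresh run
        # Passes iff abs(baseline_mean - current_mean) <= 0.5.
        self.assertAlmostEqual(baseline_mean, current_mean, delta=5e-1)

if __name__ == "__main__":
    unittest.main()
```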
