Skip to content

Commit

Permalink
fix index error when computing ppl on long-text prompt (#2697)
Browse files Browse the repository at this point in the history
* fix index error when computing ppl on long-text prompt

* update user guide
  • Loading branch information
lvhan028 authored Nov 1, 2024
1 parent 654c457 commit 993aa14
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
4 changes: 4 additions & 0 deletions docs/en/llm/pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ logits = pipe.get_logits(input_ids)
ppl = pipe.get_ppl(input_ids)
```

```{note}
`get_ppl` returns the cross-entropy loss and does not apply the exponential operation afterwards. To obtain the perplexity, apply `exp` to the returned value.
```

- **Below is an example for pytorch backend. Please install triton first.**

```shell
Expand Down
4 changes: 4 additions & 0 deletions docs/zh_cn/llm/pipeline.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ logits = pipe.get_logits(input_ids)
ppl = pipe.get_ppl(input_ids)
```

```{note}
get_ppl 返回的是 cross entropy loss，没有在之后加 exp 操作。如需得到 perplexity，请对返回值再做一次 exp 运算。
```

- **使用 pytorch 后端**

需要先安装 triton
Expand Down
15 changes: 8 additions & 7 deletions lmdeploy/serve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,15 +212,16 @@ def get_ppl(self, input_ids: Union[List[int],
logger.info(f'sorted indices: {indices}')
for (start, end) in self._batch_iterator(sizes, max_input_len):
logger.info(f'start: {start}, end: {end}')
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
if start == end:
_input_ids = input_ids[indices[start]]
loss, target_count = self._get_long_text_ppl(
generator=generator,
input_ids=_input_ids,
max_input_len=max_input_len)
losses.append(loss)
target_counts.append(target_count)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
loss, target_count = self._get_ppl(
generator=generator,
input_ids=_input_ids,
Expand Down Expand Up @@ -261,24 +262,24 @@ def _batch_iterator(self, sizes, max_value):
i += 1

def _get_long_text_ppl(self, generator, input_ids, max_input_len):
assert isinstance(input_ids, List) and len(input_ids) == 1
seq_len = len(input_ids[0])
assert all(isinstance(_, int) for _ in input_ids)
seq_len = len(input_ids)
assert seq_len > max_input_len
logger.info(f'get long text ppl: seq_len {seq_len}')

losses = []
target_counts = []
for i in range(0, seq_len, max_input_len):
token_ids = input_ids[:, i:i + max_input_len]
token_ids = input_ids[i:i + max_input_len]
step = [i]
# shift token_ids by 1 to the left
target_ids = input_ids[:, i + 1:i + 1 + max_input_len]
target_ids = input_ids[i + 1:i + 1 + max_input_len]

loss, target_count = self._get_ppl(
generator=generator,
input_ids=token_ids,
input_ids=[token_ids],
max_input_len=max_input_len,
target_ids=target_ids,
target_ids=[target_ids],
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
Expand Down

0 comments on commit 993aa14

Please sign in to comment.