
Commit f787c08

Alex Barron committed Jan 23, 2025
1 parent d5f49d6 commit f787c08
Showing 1 changed file with 50 additions and 44 deletions.
llms/mlx_lm/evaluate.py: 50 additions, 44 deletions
@@ -10,7 +10,7 @@
 import os
 from importlib.metadata import version
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional

 import lm_eval
 import mlx.core as mx
@@ -43,13 +43,13 @@ def _rstrip_until(s, untils):


 def _pad_inputs(inputs):
-    lengths = mx.array([len(x) for x in inputs])
+    lengths = np.array([len(x) for x in inputs])
     maxlen = lengths.max()
-    padded = mx.stack(
-        [mx.pad(mx.array(x), (0, maxlen - len(x))) for x in inputs],
+    padded = np.stack(
+        [np.pad(x, (0, maxlen - len(x))) for x in inputs],
         axis=0,
     )
-    return padded, lengths
+    return mx.array(padded), mx.array(lengths)


 @register_model("mlxlm")
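Note: padding now happens in NumPy on the host and the result is converted to MLX arrays once at the end, instead of building and padding MLX arrays list by list. A minimal standalone sketch of the same right-padding behavior (the `token_batches` input is a made-up example):

import mlx.core as mx
import numpy as np

def pad_inputs(inputs):
    # Right-pad every sequence with zeros to the length of the longest one.
    lengths = np.array([len(x) for x in inputs])
    maxlen = lengths.max()
    padded = np.stack([np.pad(x, (0, maxlen - len(x))) for x in inputs], axis=0)
    # Convert to MLX arrays once, after all host-side work is done.
    return mx.array(padded), mx.array(lengths)

token_batches = [[1, 2, 3], [4, 5], [6]]  # hypothetical token id lists
tokens, lengths = pad_inputs(token_batches)
print(tokens.shape, lengths)  # (3, 3) and lengths [3, 2, 1]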
@@ -65,26 +65,24 @@ def __init__(
         self._batch_size = batch_size
         self._model, self.tokenizer = load(path_or_hf_repo)
         self._max_tokens = max_tokens or self.tokenizer.model_max_length
-        self.use_chat_template = use_chat_template or (
+        self.use_chat_template = use_chat_template and (
             self.tokenizer.chat_template is not None
         )

-    def _score_fn(self, inputs, step_size=64):
+    def _score_fn(self, inputs, step_size: int = 64):
         inputs, lengths = _pad_inputs(inputs)
         inputs, targets = inputs[..., :-1], inputs[..., 1:]

         cache = make_prompt_cache(self._model)

-        # TODO: come up with a better way to get the dtype
-        dtype = self._model.model.embed_tokens(inputs).dtype
-
         scores, is_greedy = [], []
         for i in range(0, inputs.shape[1], step_size):
             inp = inputs[:, i : i + step_size]
             T = inp.shape[1]

             offset = cache[0].offset
-            mask = create_causal_mask(T, offset, lengths=lengths).astype(dtype)
+            mask = create_causal_mask(T, offset, lengths=lengths)
+            mask = mask == 0

             logits = self._model(inp, cache=cache, mask=mask)
             log_probs = nn.log_softmax(logits.astype(mx.float32))
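Note on the mask change: the removed dtype bookkeeping is no longer needed because the additive mask is turned into a boolean keep-mask with `mask == 0` (True where attention is allowed, including the per-sequence length masking of right-padded positions). A small sketch of that conversion, assuming the additive mask is 0 at allowed positions and a large negative value elsewhere, which is what the comparison relies on:

import mlx.core as mx

# Hypothetical tiny case: 3 new queries appended after an offset of 2 cached tokens.
T, offset = 3, 2
q = mx.arange(offset, offset + T)[:, None]  # query positions 2, 3, 4
k = mx.arange(offset + T)[None, :]          # key positions 0..4
additive = (k > q) * -1e9                   # future positions get a large negative bias

bool_mask = additive == 0                   # True where attention is allowed
print(bool_mask)  # lower-triangular pattern (shifted by the offset) of True values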
@@ -107,24 +105,29 @@ def _score_fn(self, inputs, step_size=64):
         return scores, lengths, is_greedy

     def _loglikelihood(self, texts, score_spans=None):
-        results = []
+        all_scores = mx.zeros(len(texts))
+        all_is_greedy = mx.zeros(len(texts), dtype=mx.bool_)
         for i in tqdm(range(0, len(texts), self._batch_size)):
             batch = texts[i : i + self._batch_size]
-            scores, length, is_greedy = self._score_fn(batch)
-            for j in range(len(batch)):
-                if score_spans is None:  # full sequence score
-                    l = length[j].item()
-                    score = scores[j][:l].astype(mx.float32).sum()
-                    ig = is_greedy[j][:l].astype(mx.int32).sum()
-                else:  # subsequence score
-                    start, end = score_spans[i + j]
-                    score = scores[j][start:end].astype(mx.float32).sum()
-                    ig = is_greedy[j][start:end].astype(mx.int32).sum()
-                    length = end - start
-
-                results.append((score.item(), ig.item(), length))
+            scores, lengths, is_greedy = self._score_fn(batch)
+
+            ind = np.arange(scores.shape[-1])
+            if score_spans is not None:
+                spans = score_spans[i : i + self._batch_size]
+                lengths = [end - start for start, end in spans]
+                masks = mx.array(
+                    np.array([(ind >= start) & (ind < end) for start, end in spans])
+                )
+            else:
+                masks = ind[None] < lengths[:, None]
+
+            scores = (masks * scores).sum(axis=-1)
+            is_greedy = (masks * is_greedy).sum(axis=-1)
+
+            all_scores[i : i + self._batch_size] = scores
+            all_is_greedy[i : i + self._batch_size] = is_greedy == lengths

-        return results
+        return all_scores, all_is_greedy

     def _tokenize(self, texts):
         return [
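The per-example Python loop over the batch is replaced by a vectorized reduction: a 0/1 mask per request selects either the whole (unpadded) sequence or the completion span, and `is_greedy == lengths` then means every scored token was the argmax choice. A toy illustration with made-up numbers (not real model outputs):

import mlx.core as mx
import numpy as np

scores = mx.array([[-0.5, -1.0, -2.0, 0.0],      # per-token log-probs, right-padded
                   [-0.1, -0.2, -0.3, -0.4]])
is_greedy = mx.array([[1, 1, 0, 0],
                      [1, 1, 1, 1]])
spans = [(1, 3), (0, 4)]                         # (start, end) of the completion tokens

ind = np.arange(scores.shape[-1])
masks = mx.array(np.array([(ind >= s) & (ind < e) for s, e in spans]))
lengths = mx.array([e - s for s, e in spans])

span_scores = (masks * scores).sum(axis=-1)                 # [-3.0, -1.0]
span_greedy = (masks * is_greedy).sum(axis=-1) == lengths   # [False, True]
print(span_scores, span_greedy)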
@@ -203,23 +206,20 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         shortened = [shortened[i] for i in sorted_indices]
         completion_spans = [completion_spans[i] for i in sorted_indices]

-        group = mx.distributed.init() if mx.distributed.is_available() else None
-        if group is not None:
-            # split strided so we have approximately the same lengths on each node
-            shortened = shortened[group.rank() :: group.size()]
-            completion_spans = completion_spans[group.rank() :: group.size()]
+        group = mx.distributed.init()
+
+        # split strided so we have approximately the same lengths on each node
+        shortened = shortened[group.rank() :: group.size()]
+        completion_spans = completion_spans[group.rank() :: group.size()]

         # model scoring, returns num_requests x (logp, is_greedy, length).
-        results = self._loglikelihood(
+        scores, is_greedy = self._loglikelihood(
             shortened,
             score_spans=completion_spans,
         )

-        scores = mx.array([r[0] for r in results])
-        is_greedy = mx.array([r[1] == r[2] for r in results])
-
         # all gather the results across groups
-        if group is not None:
+        if group.size() > 1:
             per_group = int(np.ceil(num_results / group.size()))
             scores = mx.pad(scores, ((0, per_group - len(scores)),))
             is_greedy = mx.pad(is_greedy, ((0, per_group - len(is_greedy))))
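Here `mx.distributed.init()` is called unconditionally (on a single process it presumably yields a group of size 1, which is why the `is_available()` guard could go), the strided `[rank::size]` split keeps per-node work balanced because the requests were just sorted by length, and the gather branch only runs for `group.size() > 1`, padding each rank's shard to a common `per_group` length first. A toy sketch of why the strided split balances lengths (made-up numbers, no distributed backend involved):

# Requests were sorted by length just above, so rank r taking every size-th
# element gets a mix of long and short sequences.
lengths_sorted = [512, 384, 256, 200, 128, 64]   # hypothetical, already sorted
size = 2
shards = [lengths_sorted[rank::size] for rank in range(size)]
print(shards)                    # [[512, 256, 128], [384, 200, 64]]
print([sum(s) for s in shards])  # 896 vs 648, far closer than a contiguous split (1152 vs 392)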
@@ -237,7 +237,15 @@ def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         return results

     tokenizer_name = lm_eval.models.huggingface.HFLM.tokenizer_name
-    apply_chat_template = lm_eval.models.huggingface.HFLM.apply_chat_template
+
+    def apply_chat_template(
+        self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
+    ) -> str:
+        if len(chat_history) == 0:
+            return ""
+        return lm_eval.models.huggingface.HFLM.apply_chat_template(
+            chat_history, add_generation_prompt
+        )

     def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
@@ -275,7 +283,8 @@ def loglikelihood_rolling(self, requests) -> list[float]:
             "Estimating loglikelihood rolling for %d sequences." % len(requests)
         )
         inputs = self._tokenize([req.args[0] for req in requests])
-        return [t[0] for t in self._loglikelihood(inputs)]
+        scores, _ = self._loglikelihood(inputs)
+        return scores.tolist()

     def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
@@ -338,7 +347,7 @@ def main():
     )
     parser.add_argument(
         "--limit",
-        default=1.0,
+        default=None,
         help="Limit the number of examples per task.",
         type=float,
     )
@@ -352,11 +361,8 @@ def main():
     )
     parser.add_argument(
         "--apply-chat-template",
-        action=argparse.BooleanOptionalAction,
-        help="Specifies whether to apply a chat template to the prompt. If "
-        "the model has a chat template, this defaults to `True`, "
-        "otherwise `False`.",
-        default=None,
+        action="store_true",
+        help="Specifies whether to apply a chat template to the prompt.",
     )
     args = parser.parse_args()
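The flag change pairs with the `or` to `and` change in `__init__` above: previously the option defaulted to `None` and the chat template was applied whenever the tokenizer had one; now the template is applied only when `--apply-chat-template` is passed explicitly (and the tokenizer has a template). A small standalone illustration of the two argparse styles (hypothetical parsers, not the script's full CLI):

import argparse

old = argparse.ArgumentParser()
old.add_argument("--apply-chat-template", action=argparse.BooleanOptionalAction, default=None)
print(old.parse_args([]).apply_chat_template)  # None -> previously resolved from the tokenizer

new = argparse.ArgumentParser()
new.add_argument("--apply-chat-template", action="store_true")
print(new.parse_args([]).apply_chat_template)  # False unless the flag is given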

