
Evaluation for long_context_tasks failed with a KeyError: 'continuation_indices' #1073

Closed · songkq opened this issue Mar 29, 2024 · 3 comments
Labels: bug (Something isn't working)

songkq commented Mar 29, 2024

@maxisawesome @bmosaicml Hi, could you please give some advice on this issue? Thanks~

Environment

CentOS
python 3.10.13
llm-foundry==0.7.0

git clone https://github.com/mosaicml/llm-foundry.git
cd llm-foundry
pip install cmake packaging
pip install -e ".[gpu]"
pip install flash-attn==2.5.6

To reproduce

Steps to reproduce the behavior:

  1. Edit eval/yamls/openai_eval.yaml:
seed: 1
max_seq_len: 131072
device_eval_batch_size: 4
models:
-
  model_name: openai/gpt-3.5-turbo
  model:
    name: openai_chat
    version: gpt-3.5-turbo
  tokenizer:
    name: tiktoken
    kwargs:
      model_name: gpt-3.5-turbo

icl_tasks: "eval/yamls/long_context_tasks.yaml"
eval_gauntlet: "eval/yamls/eval_gauntlet_long_context_length.yaml"
  2. Run composer eval/eval.py eval/yamls/openai_eval.yaml

  3. The evaluation fails with a KeyError:

Traceback (most recent call last):

/benchmark/public/llm-foundry/scripts/eval/eval.py:452 in <module>
  ❱ 452   main(cfg)

/benchmark/public/llm-foundry/scripts/eval/eval.py:294 in main
    293   (trainer, logger_keys, eval_gauntlet_callback,
  ❱ 294    eval_gauntlet_df) = evaluate_model(
    295        model_cfg=model_cfg,
    296        dist_timeout=dist_timeout,
    297        run_name=run_name,

/benchmark/public/llm-foundry/scripts/eval/eval.py:149 in evaluate_model
  ❱ 149   trainer.eval(eval_dataloader=evaluators,
    150                subset_num_batches=eval_subset_num_batches)

/env/miniconda3/envs/llm_factory/lib/python3.10/site-packages/composer/trainer/trainer.py:3164 in eval
  ❱ 3164   self._eval_loop(
    3165       evaluator=evaluator,
    3166       metrics=self.state.eval_metrics[evaluator.label],
    3167       subset_num_batches=eval_subset_num_batches,

/env/miniconda3/envs/llm_factory/lib/python3.10/site-packages/composer/trainer/trainer.py:3301 in _eval_loop
  ❱ 3301   self.state.outputs = self._original_m…

/benchmark/public/llm-foundry/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py:200 in eval_forward
    199   output_logits_batch = []
  ❱ 200   batch = self.rebatch(batch)
    201   for tokens, cont_idxs in zip(batch['input_ids'],
    202                                batch['continuation_indices']):

/benchmark/public/llm-foundry/llmfoundry/models/inference_api_wrapper/openai_causal_lm.py:173 in rebatch
    172   for tokens, cont_idxs in zip(batch['input_ids'],
  ❱ 173                                batch['continuation_indices']):
    174       tokens, cont_idxs = self.retokenize(tokens.tolist(),
    175                                           cont_idxs.tolist())
KeyError: 'continuation_indices'

Expected behavior

Evaluation of long_context_tasks completes successfully.

Additional context
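
A minimal sketch (hypothetical batch contents, not taken from the repo) of where the lookup fails: the OpenAI wrapper's rebatch() and eval_forward() iterate zip(batch['input_ids'], batch['continuation_indices']), and the long-context batches presumably arrive without a 'continuation_indices' entry, so the plain dict lookup raises.

import torch

# Language-modeling-style ICL batches carry continuation indices...
lm_style_batch = {
    'input_ids': torch.zeros(2, 8, dtype=torch.long),
    'continuation_indices': [torch.tensor([5, 6, 7]), torch.tensor([6, 7])],
}
# ...while a QA/generation-style batch (hypothetical shape) only has inputs and labels.
qa_style_batch = {
    'input_ids': torch.zeros(2, 8, dtype=torch.long),
    'labels': torch.zeros(2, 8, dtype=torch.long),
}

for name, b in [('lm', lm_style_batch), ('qa', qa_style_batch)]:
    try:
        # Same access pattern as rebatch() in openai_causal_lm.py.
        for tokens, cont_idxs in zip(b['input_ids'], b['continuation_indices']):
            pass
        print(name, 'ok')
    except KeyError as err:
        print(name, 'raises KeyError:', err)  # only the qa-style batch fails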

songkq added the bug (Something isn't working) label on Mar 29, 2024
maxisawesome (Contributor) commented

Howdy! Right now, the openai client is only compatible with language modeling tasks, and the tasks in long_context_tasks.yaml are all question_answering tasks. That being said, we're currently working on supporting the entire gauntlet for the openai client.
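
If you only need the subset the client can handle today, one rough, untested workaround is to filter the task file down to its language_modeling entries before pointing icl_tasks at it. The sketch below assumes the YAML is (or contains under an icl_tasks key) a list of task entries with an icl_task_type field, which may not match every version of the file.

import yaml

def keep_language_modeling_tasks(in_path: str, out_path: str) -> None:
    """Write a copy of an ICL task YAML containing only language_modeling tasks."""
    with open(in_path) as f:
        data = yaml.safe_load(f)

    # Task files are assumed to be either a bare list or {'icl_tasks': [...]}.
    entries = data.get('icl_tasks', data) if isinstance(data, dict) else data
    kept = [t for t in entries if t.get('icl_task_type') == 'language_modeling']

    with open(out_path, 'w') as f:
        yaml.safe_dump({'icl_tasks': kept}, f)

keep_language_modeling_tasks(
    'eval/yamls/long_context_tasks.yaml',
    'eval/yamls/long_context_tasks_lm_only.yaml',  # hypothetical output path
)

Then point icl_tasks in your eval config at the filtered file.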

songkq (Author) commented Apr 7, 2024

@maxisawesome Thanks. I'm wondering how I can reproduce the results reported in the blog. Could you please give some advice?
From: https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
[attached screenshot of long-context results from the blog post]

maxisawesome (Contributor) commented

We used this branch. hf_eval.yaml, which I've included below, is an example of how we performed long_context evals. Please note this branch is under active development and might change without warning.

seed: 1
max_seq_len: 16000
device_eval_batch_size: 1
models:
  -
    model_name: openai/gpt-3.5-turbo
    model:
      name: openai_chat
      version: gpt-3.5-turbo
    tokenizer:
      name: tiktoken
      kwargs:
        model_name: gpt-3.5-turbo

icl_tasks: "eval/yamls/long_context_tasks.yaml"
eval_gauntlet: "eval/yamls/eval_gauntlet_long_context.yaml"
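
(Presumably this is launched the same way as the config above, i.e. composer eval/eval.py pointed at this YAML while on that branch.)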

dakinggg closed this as completed on May 3, 2024