fix: solved bug from ArgillaTrainer with extractive QA (#4204)
<!-- Thanks for your contribution! As part of our Community Growers
initiative 🌱, we're donating Justdiggit bunds in your name to reforest
sub-Saharan Africa. To claim your Community Growers certificate, please
contact David Berenstein in our Slack community or fill in this form
https://tally.so/r/n9XrxK once your PR has been merged. -->

# Description

Fixed the error that appeared when training for extractive question
answering.

To compute the metrics on the evaluation data, the SQuAD-style format
expects an extra `id` field in the dataset so that the answers can be
located during the validation pass:


```diff
feature_dict = {
    "question": datasets.Value("string"),
    "context": datasets.Value("string"),
    "answer": datasets.Sequence(
        feature={
            "text": datasets.Value(dtype="string", id=None),
            "answer_start": datasets.Value(dtype="int32", id=None),
        },
        length=-1,
        id=None,
    ),
+ "id": datasets.Value(dtype="int32"),
}
```
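
For context, here is a minimal, illustrative sketch (not part of the PR diff) of a SQuAD-style dataset built with this feature dict; the column names mirror the snippet above, while the example record and its values are made up:

```python
# Illustrative sketch only: a tiny SQuAD-style dataset with the extra "id" column.
import datasets

feature_dict = {
    "question": datasets.Value("string"),
    "context": datasets.Value("string"),
    "answer": datasets.Sequence(
        feature={
            "text": datasets.Value(dtype="string", id=None),
            "answer_start": datasets.Value(dtype="int32", id=None),
        },
        length=-1,
        id=None,
    ),
    "id": datasets.Value(dtype="int32"),
}

data = {
    "question": ["Where is the Eiffel Tower located?"],
    "context": ["The Eiffel Tower is located in Paris."],
    "answer": [{"text": ["Paris"], "answer_start": [31]}],
    "id": [0],  # the extra field: lets the validation pass match predictions to examples
}

ds = datasets.Dataset.from_dict(data, features=datasets.Features(feature_dict))
print(ds[0])  # {'question': ..., 'context': ..., 'answer': {...}, 'id': 0}
```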

Closes #4158

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] Breaking change (fix or feature that would cause existing
functionality to not work as expected)

**How Has This Been Tested**

(Please describe the tests that you ran to verify your changes. And
ideally, reference `tests`)

The previous tests now include a test sample to force the metrics
computation.
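
As an illustration only (the real change is in the test diff further down), a minimal sketch of how a validation split forces the metrics path, assuming a `FeedbackDataset` named `dataset` and a question-answering `task` already built as in the test:

```python
# Illustrative sketch: `dataset` and `task` are assumed to exist, as in the test below.
from argilla.feedback import ArgillaTrainer

# train_size < 1.0 keeps part of the records as an evaluation split,
# which is what triggers the SQuAD-style metrics computation during training.
trainer = ArgillaTrainer(dataset=dataset, task=task, framework="transformers", train_size=0.8)
trainer.update_config(num_iterations=1)  # keep the run short, as the test does
trainer.train("tmp_qa_output")  # output directory name is illustrative
```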

**Checklist**

- [x] I followed the style guidelines of this project
- [x] I did a self-review of my code
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
- [ ] I filled out [the contributor form](https://tally.so/r/n9XrxK)
(see text above)
- [x] I have added relevant notes to the `CHANGELOG.md` file (See
https://keepachangelog.com/)
plaguss authored Nov 10, 2023
1 parent da043b1 commit 876806f
Showing 4 changed files with 98 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ These are the section headers that we use:
### Fixed

- Fixed error in `ArgillaTrainer`, with numerical labels use `RatingQuestion` instead of `RankingQuestion` ([#4171](https://github.com/argilla-io/argilla/pull/4171))
- Fixed error in `ArgillaTrainer`, now we can train for `extractive_question_answering` using a validation sample ([#4204](https://github.com/argilla-io/argilla/pull/4204))

## [1.19.0]()

4 changes: 4 additions & 0 deletions src/argilla/client/feedback/training/schemas.py
@@ -1302,6 +1302,7 @@ def _prepare_for_training_with_transformers(
"question": [],
"context": [],
"answer": [],
"id": [],
}
for entry in data:
if any([entry.get("question") is None, entry.get("context") is None, entry.get("answer") is None]):
@@ -1316,6 +1317,8 @@
datasets_dict["context"].append(entry["context"])
datasets_dict["answer"].append({"answer_start": [answer_start], "text": [entry["answer"]]})

datasets_dict["id"] = list(range(len(data)))

feature_dict = {
"question": datasets.Value("string"),
"context": datasets.Value("string"),
@@ -1327,6 +1330,7 @@
length=-1,
id=None,
),
"id": datasets.Value(dtype="int32"),
}

ds = datasets.Dataset.from_dict(datasets_dict, features=datasets.Features(feature_dict))
103 changes: 91 additions & 12 deletions src/argilla/training/transformers.py
@@ -294,6 +294,31 @@ def question_answering_preprocess_function(examples):
inputs["end_positions"] = end_positions
return inputs

def question_answering_preprocess_function_validation(examples):
questions = [q.strip() for q in examples["question"]]
inputs = self._transformers_tokenizer(
questions,
examples["context"],
truncation="only_second",
return_overflowing_tokens=True,
return_offsets_mapping=True,
padding="max_length",
)

sample_map = inputs.pop("overflow_to_sample_mapping")
example_ids = []

for i in range(len(inputs["input_ids"])):
sample_idx = sample_map[i]
example_ids.append(examples["id"][sample_idx])

sequence_ids = inputs.sequence_ids(i)
offset = inputs["offset_mapping"][i]
inputs["offset_mapping"][i] = [o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)]

inputs["example_id"] = example_ids
return inputs

# set correct tokenization
if self._record_class == TextClassificationRecord:
preprocess_function = text_classification_preprocess_function
@@ -323,15 +348,26 @@ def question_answering_preprocess_function(examples):
self._tokenized_train_dataset = self._tokenized_train_dataset.rename_column("label", "labels")

if self._eval_dataset is not None:
self._tokenized_eval_dataset = self._eval_dataset.map(
preprocess_function, batched=True, remove_columns=remove_columns
)
if self._model_class == AutoModelForQuestionAnswering:
# We need to preprocess the validation dataset separately, because we need to return the example_id
self._tokenized_eval_dataset = self._eval_dataset.map(
question_answering_preprocess_function_validation,
batched=True,
remove_columns=remove_columns,
)
else:
self._tokenized_eval_dataset = self._eval_dataset.map(
preprocess_function, batched=True, remove_columns=remove_columns
)

if replace_labels:
self._tokenized_eval_dataset = self._tokenized_eval_dataset.rename_column("label", "labels")
else:
self._tokenized_eval_dataset = None

def compute_metrics(self):
import collections

import evaluate
import numpy as np
from transformers import AutoModelForQuestionAnswering
@@ -391,19 +427,62 @@ def compute_metrics(p):

func = compute_metrics
elif AutoModelForQuestionAnswering:
f1 = evaluate.load("f1")
squad = evaluate.load("squad")

def compute_metrics_question_answering(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
# Copy from https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt#fine-tuning-the-model-with-the-trainer-api
n_best = 20

# Calculate Exact Match (EM)
em = sum([int(p == l) for p, l in zip(preds, labels)]) / len(labels)
def compute_metrics_question_answering(pred):
start_logits, end_logits = pred.predictions
features = self._tokenized_eval_dataset
examples = self._eval_dataset

example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(features):
example_to_features[feature["example_id"]].append(idx)

predicted_answers = []
for example in examples:
example_id = example["id"]
context = example["context"]
answers = []

# Loop through all features associated with that example
for feature_index in example_to_features[example_id]:
start_logit = start_logits[feature_index]
end_logit = end_logits[feature_index]
offsets = features[feature_index]["offset_mapping"]

start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
for start_index in start_indexes:
for end_index in end_indexes:
# Skip answers that are not fully in the context
if offsets[start_index] is None or offsets[end_index] is None:
continue
# Skip answers with a length that is either < 0 or > max_answer_length
if (
end_index < start_index
or end_index - start_index + 1 > self._transformers_tokenizer.model_max_length
):
continue

answer = {
"text": context[offsets[start_index][0] : offsets[end_index][1]],
"logit_score": start_logit[start_index] + end_logit[end_index],
}
answers.append(answer)

# Select the answer with the best score
if len(answers) > 0:
best_answer = max(answers, key=lambda x: x["logit_score"])
predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
else:
predicted_answers.append({"id": example_id, "prediction_text": ""})

# Calculate F1-score
f1_score = f1(labels, preds, average="macro")
theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]

return {"exact_match": em, "f1": f1_score}
return squad.compute(predictions=predicted_answers, references=theoretical_answers)

func = compute_metrics_question_answering
else:
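
Stepping outside the diff for a moment: the new `compute_metrics_question_answering` hands its predictions and references to the `squad` metric from `evaluate`. A minimal sketch of the payload format that metric expects (the example values are made up):

```python
# Illustrative only: the squad metric compares prediction text against reference answers by id.
import evaluate

squad = evaluate.load("squad")

predicted_answers = [
    {"id": "0", "prediction_text": "Paris"},
]
theoretical_answers = [
    {"id": "0", "answers": {"text": ["Paris"], "answer_start": [31]}},
]

# Returns a dict with "exact_match" and "f1"; both are 100.0 for this perfect match.
print(squad.compute(predictions=predicted_answers, references=theoretical_answers))
```
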
4 changes: 2 additions & 2 deletions tests/integration/client/feedback/training/test_trainer.py
@@ -312,9 +312,9 @@ def test_question_answering_without_formatting_func(
context=dataset.field_by_name("text"),
answer=dataset.question_by_name("question-1"),
)
trainer = ArgillaTrainer(dataset=dataset, task=task, framework="transformers")
trainer = ArgillaTrainer(dataset=dataset, task=task, framework="transformers", train_size=0.8)
trainer.update_config(num_iterations=1)
trainer.train(__OUTPUT_DIR__)
train_with_cleanup(trainer, __OUTPUT_DIR__)


@pytest.mark.parametrize(
