diff --git a/scripts/run_senteval.py b/scripts/run_senteval.py index f6b23f4..daaa2a6 100644 --- a/scripts/run_senteval.py +++ b/scripts/run_senteval.py @@ -22,20 +22,10 @@ logger = logging.getLogger(__name__) AGGREGATE_SCORES_KEY = "aggregate_scores" +# subset of downstream tasks only for Semantic Text Similarity DOWNSTREAM_TASKS = [ - "CR", - "MR", - "MPQA", - "SUBJ", - "SST2", - "SST5", - "TREC", - "MRPC", - "SNLI", - "SICKEntailment", "SICKRelatedness", "STSBenchmark", - "ImageCaptionRetrieval", "STS12", "STS13", "STS14", @@ -43,16 +33,6 @@ "STS16", ] PROBING_TASKS = [ - "Length", - "WordContent", - "Depth", - "TopConstituents", - "BigramShift", - "Tense", - "SubjNumber", - "ObjNumber", - "OddManOut", - "CoordinationInversion", ] TRANSFER_TASKS = DOWNSTREAM_TASKS + PROBING_TASKS @@ -168,9 +148,10 @@ def _compute_aggregate_scores(results: Dict, ignore_tasks: List[str] = None) -> # Aggregate scores for "downstream" tasks aggregate_scores["downstream"]["dev"] /= num_downstream_tasks aggregate_scores["downstream"]["test"] /= num_downstream_tasks + # remove probing task from division to prevent zero division errors # Aggregate scores for "probing" tasks - aggregate_scores["probing"]["dev"] /= num_probing_tasks - aggregate_scores["probing"]["test"] /= num_probing_tasks + # aggregate_scores["probing"]["dev"] /= num_probing_tasks + # aggregate_scores["probing"]["test"] /= num_probing_tasks return aggregate_scores