From 1bc7bd03c456e65d210971ae32ca243165927fc8 Mon Sep 17 00:00:00 2001
From: Yifan Mai <yifan@cs.stanford.edu>
Date: Mon, 13 Jan 2025 20:36:54 -0800
Subject: [PATCH] Allow running recipes from the Unitxt catalog (#3267)
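
Previously, `unitxt` run entries had to spell out individual keyword
arguments such as `card` (which was required), and all of them were
joined into a comma-separated Unitxt dataset name. A run entry may now
instead pass a single `recipe` argument naming a recipe from the Unitxt
catalog; the recipe string is forwarded to `load_dataset` verbatim.
Dataset splits that a recipe does not produce are now skipped instead
of raising a KeyError.

A recipe-based run entry might look like the following (the recipe and
model names are placeholders, not taken from this patch):

    helm-run --run-entries "unitxt:recipe=recipes.my_recipe,model=openai/gpt2" --suite my-suite --max-eval-instances 10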

---
 src/helm/benchmark/metrics/unitxt_metrics.py     |  5 ++++-
 src/helm/benchmark/run_specs/unitxt_run_specs.py |  7 ++++---
 src/helm/benchmark/scenarios/unitxt_scenario.py  | 10 ++++++++--
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/helm/benchmark/metrics/unitxt_metrics.py b/src/helm/benchmark/metrics/unitxt_metrics.py
index 2fc684b9c24..95ea4325bca 100644
--- a/src/helm/benchmark/metrics/unitxt_metrics.py
+++ b/src/helm/benchmark/metrics/unitxt_metrics.py
@@ -18,7 +18,10 @@ class UnitxtMetric(MetricInterface):
 
     def __init__(self, **kwargs):
         super().__init__()
-        dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items())
+        if len(kwargs) == 1 and "recipe" in kwargs:
+            dataset_name = kwargs["recipe"]
+        else:
+            dataset_name = ",".join(f"{key}={value}" for key, value in kwargs.items())
         self.dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True)
 
     def evaluate(
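
The same name-building logic is duplicated in the scenario below. As a
standalone illustration (a minimal sketch; the helper name and example
kwargs are made up, not part of this patch):

    def build_dataset_name(kwargs: dict) -> str:
        # A lone "recipe" kwarg names a Unitxt catalog recipe and is
        # passed through to load_dataset unchanged.
        if len(kwargs) == 1 and "recipe" in kwargs:
            return kwargs["recipe"]
        # Otherwise keep the original behavior: join all kwargs into a
        # comma-separated Unitxt dataset string.
        return ",".join(f"{key}={value}" for key, value in kwargs.items())

    assert build_dataset_name({"recipe": "recipes.my_recipe"}) == "recipes.my_recipe"
    assert build_dataset_name({"card": "cards.wnli", "template_card_index": 0}) == "card=cards.wnli,template_card_index=0"
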
diff --git a/src/helm/benchmark/run_specs/unitxt_run_specs.py b/src/helm/benchmark/run_specs/unitxt_run_specs.py
index 3b2ee2e201a..0a4b51a9d59 100644
--- a/src/helm/benchmark/run_specs/unitxt_run_specs.py
+++ b/src/helm/benchmark/run_specs/unitxt_run_specs.py
@@ -10,8 +10,9 @@
 @run_spec_function("unitxt")
 def get_unitxt_spec(**kwargs) -> RunSpec:
     card = kwargs.get("card")
-    if not card:
-        raise Exception("Unitxt card must be specified")
+    recipe = kwargs.get("recipe")
+    if not card and not recipe:
+        raise Exception("Unitxt card or recipe must be specified")
     if os.environ.get("HELM_UNITXT_SHORTEN_RUN_SPEC_NAMES", "").lower() == "true":
         name_suffix = ",".join(
             [f"{key}={value}" for key, value in kwargs.items() if key not in ["template_card_index", "loader_limit"]]
@@ -46,5 +47,5 @@ def get_unitxt_spec(**kwargs) -> RunSpec:
             MetricSpec(class_name="helm.benchmark.metrics.unitxt_metrics.UnitxtMetric", args=kwargs),
         ]
         + get_basic_metric_specs([]),
-        groups=[f"unitxt_{card}"],
+        groups=[f"unitxt_{card or recipe}"],
     )
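
With the `or` fallback above, card-based runs keep their existing group
name (`unitxt_<card>`) while recipe-only runs are grouped as
`unitxt_<recipe>`. A minimal sketch (the recipe name is illustrative):

    card = None
    recipe = "recipes.my_recipe"
    group = f"unitxt_{card or recipe}"  # "unitxt_recipes.my_recipe"
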
diff --git a/src/helm/benchmark/scenarios/unitxt_scenario.py b/src/helm/benchmark/scenarios/unitxt_scenario.py
index 95e0d125464..e321c066a47 100644
--- a/src/helm/benchmark/scenarios/unitxt_scenario.py
+++ b/src/helm/benchmark/scenarios/unitxt_scenario.py
@@ -32,13 +32,19 @@ def __init__(self, **kwargs):
         self.kwargs = kwargs
 
     def get_instances(self, output_path: str) -> List[Instance]:
-        dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items())
+        if len(self.kwargs) == 1 and "recipe" in self.kwargs:
+            dataset_name = self.kwargs["recipe"]
+        else:
+            dataset_name = ",".join(f"{key}={value}" for key, value in self.kwargs.items())
         dataset = load_dataset("unitxt/data", dataset_name, trust_remote_code=True)
 
         instances: List[Instance] = []
 
         for unitxt_split_name, helm_split_name in UnitxtScenario.UNITXT_SPLIT_NAME_TO_HELM_SPLIT_NAME.items():
-            for index, row in enumerate(dataset[unitxt_split_name]):
+            dataset_split = dataset.get(unitxt_split_name)
+            if dataset_split is None:
+                continue
+            for index, row in enumerate(dataset_split):
                 references = [
                     Reference(
                         output=Output(text=reference_text),