|
106 | 106 | "source": [
|
107 | 107 | "import os\n",
|
108 | 108 | "import openai\n",
|
| 109 | + "\n", |
109 | 110 | "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
|
110 | 111 | "\n",
|
111 | 112 | "completion = openai.ChatCompletion.create(\n",
|
112 |
| - " model=\"gpt-3.5-turbo\",\n", |
113 |
| - " messages=[\n", |
114 |
| - " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", |
115 |
| - " ]\n", |
| 113 | + " model=\"gpt-3.5-turbo\",\n", |
| 114 | + " messages=[\n", |
| 115 | + " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n", |
| 116 | + " ],\n", |
116 | 117 | ")\n",
|
117 | 118 | "\n",
|
118 |
| - "print(completion.choices[0].message)\n" |
| 119 | + "print(completion.choices[0].message)" |
119 | 120 | ]
|
120 | 121 | },
|
121 | 122 | {
|
|
125 | 126 | "metadata": {},
|
126 | 127 | "outputs": [],
|
127 | 128 | "source": [
|
128 |
| - "\n", |
129 | 129 | "def llm2(prompt, **kwargs):\n",
|
130 | 130 | " response = openai.ChatCompletion.create(\n",
|
131 |
| - " model=kwargs.get(\"model\",\"gpt-3.5-turbo-16k\"),\n", |
132 |
| - " messages=[{\"role\": \"system\", \"content\":prompt}],\n", |
| 131 | + " model=kwargs.get(\"model\", \"gpt-3.5-turbo-16k\"),\n", |
| 132 | + " messages=[{\"role\": \"system\", \"content\": prompt}],\n", |
133 | 133 | " temperature=kwargs.get(\"temperature\", 0),\n",
|
134 | 134 | " top_p=kwargs.get(\"top_p\", 1),\n",
|
135 | 135 | " frequency_penalty=kwargs.get(\"frequency_penalty\", 0.0),\n",
|
|
139 | 139 | " )\n",
|
140 | 140 | " return response\n",
|
141 | 141 | "\n",
|
| 142 | + "\n", |
142 | 143 | "def llm(prompt, **kwargs):\n",
|
143 | 144 | " response = openai.Completion.create(\n",
|
144 | 145 | " model=kwargs.get(\"model\", \"text-davinci-003\"),\n",
|
|
375 | 376 | }
|
376 | 377 | ],
|
377 | 378 | "source": [
|
378 |
| - "llm2([Question_generation.format(2,answer)])" |
| 379 | + "llm2([Question_generation.format(2, answer)])" |
379 | 380 | ]
|
380 | 381 | },
|
381 | 382 | {
|
|
1039 | 1040 | ],
|
1040 | 1041 | "source": [
|
1041 | 1042 | "def get_all_facts(item):\n",
|
1042 |
| - " all_facts = item['context']['sentences']\n", |
| 1043 | + " all_facts = item[\"context\"][\"sentences\"]\n", |
1043 | 1044 | " all_facts = [sent for para in all_facts for sent in para]\n",
|
1044 |
| - " return {\"full_context\":''.join(all_facts)}\n", |
1045 |
| - "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False) " |
| 1045 | + " return {\"full_context\": \"\".join(all_facts)}\n", |
| 1046 | + "\n", |
| 1047 | + "\n", |
| 1048 | + "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False)" |
1046 | 1049 | ]
|
1047 | 1050 | },
|
1048 | 1051 | {
|
|
1090 | 1093 | "metadata": {},
|
1091 | 1094 | "outputs": [],
|
1092 | 1095 | "source": [
|
1093 |
| - "i=15\n", |
1094 |
| - "q,c = hotpot_qa[i]['question'],hotpot_qa[i]['full_context']" |
| 1096 | + "i = 15\n", |
| 1097 | + "q, c = hotpot_qa[i][\"question\"], hotpot_qa[i][\"full_context\"]" |
1095 | 1098 | ]
|
1096 | 1099 | },
|
1097 | 1100 | {
|
|
1112 | 1115 | "outputs": [],
|
1113 | 1116 | "source": [
|
1114 | 1117 | "q = \"what is general relativity?\"\n",
|
1115 |
| - "n=2" |
| 1118 | + "n = 2" |
1116 | 1119 | ]
|
1117 | 1120 | },
|
1118 | 1121 | {
|
|
1123 | 1126 | "outputs": [],
|
1124 | 1127 | "source": [
|
1125 | 1128 | "import wikipediaapi\n",
|
| 1129 | + "\n", |
1126 | 1130 | "wiki_wiki = wikipediaapi.Wikipedia(\n",
|
1127 |
| - " language='en',\n", |
1128 |
| - " extract_format=wikipediaapi.ExtractFormat.WIKI\n", |
| 1131 | + " language=\"en\", extract_format=wikipediaapi.ExtractFormat.WIKI\n", |
1129 | 1132 | ")\n",
|
1130 | 1133 | "\n",
|
1131 | 1134 | "p_wiki = wiki_wiki.page(\"Black hole\")\n",
|
1132 | 1135 | "\n",
|
| 1136 | + "\n", |
1133 | 1137 | "def get_page_section(page, section):\n",
|
1134 | 1138 | " all_text = \"\"\n",
|
1135 | 1139 | " p_wiki = wiki_wiki.page(page)\n",
|
1136 | 1140 | " sections = p_wiki.sections_by_title(section)\n",
|
1137 | 1141 | " for s in sections:\n",
|
1138 | 1142 | " all_text += s.full_text()\n",
|
1139 |
| - " return all_text\n" |
| 1143 | + " return all_text" |
1140 | 1144 | ]
|
1141 | 1145 | },
|
1142 | 1146 | {
|
|
1152 | 1156 | "\n",
|
1153 | 1157 | "cross_encoder = CrossEncoder(\"cross-encoder/stsb-TinyBERT-L-4\")\n",
|
1154 | 1158 | "\n",
|
1155 |
| - " \n", |
| 1159 | + "\n", |
1156 | 1160 | "def sent_tokenize(sent):\n",
|
1157 |
| - " return [s[:-1] if s.endswith('.') else s for s in sent.strip().split('. ')]\n", |
| 1161 | + " return [s[:-1] if s.endswith(\".\") else s for s in sent.strip().split(\". \")]\n", |
| 1162 | + "\n", |
1158 | 1163 | "\n",
|
1159 | 1164 | "class SentenceAgreement:\n",
|
1160 |
| - " \n", |
1161 | 1165 | " def __init__(self, scoring=\"bert_score\"):\n",
|
1162 |
| - " \n", |
1163 | 1166 | " self.scoring = scoring\n",
|
1164 | 1167 | "\n",
|
1165 |
| - " \n", |
1166 | 1168 | " @staticmethod\n",
|
1167 | 1169 | " def bert_score(para1, para2):\n",
|
1168 |
| - " \n", |
1169 | 1170 | " sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n",
|
1170 | 1171 | " scores = cross_encoder.predict(list(itertools.product(sentences1, sentences2)))\n",
|
1171 | 1172 | " scores = scores.reshape(len(sentences1), len(sentences2))\n",
|
1172 | 1173 | " return scores.max(axis=1).mean()\n",
|
1173 | 1174 | "\n",
|
1174 | 1175 | " @staticmethod\n",
|
1175 | 1176 | " def jaccard_score(para1, para2):\n",
|
1176 |
| - " \n", |
1177 | 1177 | " sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n",
|
1178 | 1178 | " intersect = len(np.intersect1d(sentences1, sentences2))\n",
|
1179 | 1179 | " union = len(np.union1d(sentences1, sentences2))\n",
|
1180 |
| - " return intersect/union\n", |
1181 |
| - " \n", |
1182 |
| - " def evaluate(self,answers:List[List[str]]):\n", |
1183 |
| - " \n", |
| 1180 | + " return intersect / union\n", |
| 1181 | + "\n", |
| 1182 | + " def evaluate(self, answers: List[List[str]]):\n", |
1184 | 1183 | " \"\"\"\n",
|
1185 | 1184 | " eval nC2 combinations\n",
|
1186 | 1185 | " \"\"\"\n",
|
1187 | 1186 | " scores = []\n",
|
1188 |
| - " groups = combinations(answers,2)\n", |
| 1187 | + " groups = combinations(answers, 2)\n", |
1189 | 1188 | " for group in groups:\n",
|
1190 | 1189 | " if self.scoring == \"jaccard\":\n",
|
1191 | 1190 | " score = self.jaccard_score(*group)\n",
|
1192 | 1191 | " elif self.scoring == \"bert_score\":\n",
|
1193 | 1192 | " score = self.bert_score(*group)\n",
|
1194 | 1193 | " scores.append(score)\n",
|
1195 |
| - " return np.mean(scores)\n", |
1196 |
| - " " |
| 1194 | + " return np.mean(scores)" |
1197 | 1195 | ]
|
1198 | 1196 | },
|
1199 | 1197 | {
|
|
1204 | 1202 | "outputs": [],
|
1205 | 1203 | "source": [
|
1206 | 1204 | "class ContextRelevacy:\n",
|
1207 |
| - " \n", |
1208 |
| - " def __init__(self, strictness = 2, agreement_metric=\"bert_score\"):\n", |
1209 |
| - " \n", |
| 1205 | + " def __init__(self, strictness=2, agreement_metric=\"bert_score\"):\n", |
1210 | 1206 | " self.strictness = strictness\n",
|
1211 | 1207 | " self.sent_agreement = SentenceAgreement(agreement_metric)\n",
|
1212 |
| - " \n", |
1213 |
| - " def score(self,question,context):\n", |
| 1208 | + "\n", |
| 1209 | + " def score(self, question, context):\n", |
1214 | 1210 | " scores = []\n",
|
1215 |
| - " outputs = llm(Context_relevency.format(q,c),n=self.strictness,temperature=1)\n", |
1216 |
| - " outputs = [outputs['choices'][i]['text'].strip() for i in range(self.strictness)]\n", |
| 1211 | + " outputs = llm(Context_relevency.format(q, c), n=self.strictness, temperature=1)\n", |
| 1212 | + " outputs = [\n", |
| 1213 | + " outputs[\"choices\"][i][\"text\"].strip() for i in range(self.strictness)\n", |
| 1214 | + " ]\n", |
1217 | 1215 | " context_sents = sent_tokenize(context)\n",
|
1218 | 1216 | " for output in outputs:\n",
|
1219 |
| - " indices = [context.find(sent) for sent in sent_tokenize(output) if context.find(sent)!=-1]\n", |
1220 |
| - " scores.append(len(indices)/len(context_sents))\n", |
1221 |
| - " \n", |
| 1217 | + " indices = [\n", |
| 1218 | + " context.find(sent)\n", |
| 1219 | + " for sent in sent_tokenize(output)\n", |
| 1220 | + " if context.find(sent) != -1\n", |
| 1221 | + " ]\n", |
| 1222 | + " scores.append(len(indices) / len(context_sents))\n", |
| 1223 | + "\n", |
1222 | 1224 | " if self.strictness > 1:\n",
|
1223 | 1225 | " agr_score = self.sent_agreement.evaluate(outputs)\n",
|
1224 | 1226 | " else:\n",
|
1225 |
| - " agr_score =1 \n", |
1226 |
| - " return agr_score * np.mean(scores)\n" |
| 1227 | + " agr_score = 1\n", |
| 1228 | + " return agr_score * np.mean(scores)" |
1227 | 1229 | ]
|
1228 | 1230 | },
|
1229 | 1231 | {
|
|
1234 | 1236 | "outputs": [],
|
1235 | 1237 | "source": [
|
1236 | 1238 | "c = get_page_section(\"HIV/AIDS\", \"Prevention\")\n",
|
1237 |
| - "c = ' '.join(c.split(' ')[:500])\n", |
| 1239 | + "c = \" \".join(c.split(\" \")[:500])\n", |
1238 | 1240 | "q = \"When was the first HIV case detected?\""
|
1239 | 1241 | ]
|
1240 | 1242 | },
|
|
1245 | 1247 | "metadata": {},
|
1246 | 1248 | "outputs": [],
|
1247 | 1249 | "source": [
|
1248 |
| - "output = llm([Context_relevency.format(q,c), Context_relevency.format(\"How to prevent AIDS?\",c)],n=n,temperature=1)" |
| 1250 | + "output = llm(\n", |
| 1251 | + " [\n", |
| 1252 | + " Context_relevency.format(q, c),\n", |
| 1253 | + " Context_relevency.format(\"How to prevent AIDS?\", c),\n", |
| 1254 | + " ],\n", |
| 1255 | + " n=n,\n", |
| 1256 | + " temperature=1,\n", |
| 1257 | + ")" |
1249 | 1258 | ]
|
1250 | 1259 | },
|
1251 | 1260 | {
|
|
1397 | 1406 | }
|
1398 | 1407 | ],
|
1399 | 1408 | "source": [
|
1400 |
| - "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))" |
| 1409 | + "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))" |
1401 | 1410 | ]
|
1402 | 1411 | },
|
1403 | 1412 | {
|
|
1491 | 1500 | }
|
1492 | 1501 | ],
|
1493 | 1502 | "source": [
|
1494 |
| - "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))" |
| 1503 | + "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))" |
1495 | 1504 | ]
|
1496 | 1505 | },
|
1497 | 1506 | {
|
|
0 commit comments