Commit b5770f0
feat: added langchains LLM (#66)
fixes #46 fixes #53

You can now use langchain's LLM abstraction to access all the LLM endpoints langchain supports, e.g.:

```python
from langchain.chat_models import ChatOpenAI

gpt4 = ChatOpenAI(model_name="gpt-4")
gpt4.generate_prompt(prompts=[prompts])

# init a new Metric with the llm
cr = ContextRelevancy(llm=gpt4)
cr.init_model()
result = cr.score(ds.select(range(4)))
result["context_relavency"]
# [0.46687018871307373, 0.1532887363433838, 0.17359847468989234, 0.17340516530234237]
```

We're also now using OpenAI's chat models as the default, which brings a 10x decrease in cost, and `gpt-3.5-turbo-16k` as the default model for an even bigger context size.
1 parent eefb0ca commit b5770f0
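For context, the default path now looks like this. A minimal sketch, assuming `ContextRelevancy` is importable from `ragas.metrics` and `ds` is a `datasets.Dataset` with the columns the metric expects; when no `llm` is passed, `Metric.__post_init__` (see the `src/ragas/metrics/base.py` diff below) falls back to the 16k chat model:

```python
from ragas.metrics import ContextRelevancy  # import path assumed

# no llm argument: __post_init__ builds the default
# ChatOpenAI(model_name="gpt-3.5-turbo-16k"), so OPENAI_API_KEY must be set
cr = ContextRelevancy()
cr.init_model()
result = cr.score(ds.select(range(4)))  # ds is a hypothetical datasets.Dataset
```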

File tree

9 files changed: +208 −158 lines

.github/workflows/ci.yaml

Lines changed: 3 additions & 3 deletions
```diff
@@ -38,7 +38,6 @@ jobs:
       ragas:
         - "src/ragas/**"
         - "tests/**"
-        - "examples/**"
       docs:
         - *related
         - requirements/docs-requirements.txt
@@ -52,7 +51,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
-        python-version: ["3.7", "3.8", "3.9", "3.10"]
+        python-version: ["3.8", "3.9", "3.10"]
 
     if: ${{ (github.event_name == 'pull_request' && needs.diff.outputs.ragas == 'true') || github.event_name == 'push' }}
     name: python${{ matrix.python-version }}_unit_tests (${{ matrix.os }})
@@ -86,6 +85,7 @@ jobs:
           pip install "."
           pip install -r requirements/test.txt
 
+
       - name: Run unit tests
         run: |
           # OPTS=(--cov-config pyproject.toml --cov=src/bentoml --cov-append)
@@ -94,7 +94,7 @@ jobs:
             OPTS=(--dist loadfile -n auto)
           fi
           # Now run the unit tests
-          pytest tests/unit "${OPTS[@]}"
+          OPENAI_API_KEY="test" pytest tests/unit "${OPTS[@]}"
 
   codestyle_check:
     runs-on: ubuntu-latest
```
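The `OPENAI_API_KEY="test"` prefix on the pytest invocation is needed because metrics now construct a `ChatOpenAI` by default, and langchain checks that a key is *present* when the model object is created; the key is only used if a request is actually sent. A minimal sketch of why a dummy value suffices, assuming the unit tests never make a real API call:

```python
import os

# any non-empty placeholder satisfies construction-time validation;
# no request is made, so the value is never checked against the API
os.environ["OPENAI_API_KEY"] = "test"

from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k")  # constructs without error
```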

experiments/assesments/metrics_assesments.ipynb

Lines changed: 57 additions & 48 deletions
```diff
@@ -106,16 +106,17 @@
    "source": [
     "import os\n",
     "import openai\n",
+    "\n",
     "openai.api_key = os.getenv(\"OPENAI_API_KEY\")\n",
     "\n",
     "completion = openai.ChatCompletion.create(\n",
-    "  model=\"gpt-3.5-turbo\",\n",
-    "  messages=[\n",
-    "    {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
-    "  ]\n",
+    "    model=\"gpt-3.5-turbo\",\n",
+    "    messages=[\n",
+    "        {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"},\n",
+    "    ],\n",
     ")\n",
     "\n",
-    "print(completion.choices[0].message)\n"
+    "print(completion.choices[0].message)"
    ]
   },
   {
@@ -125,11 +126,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "\n",
     "def llm2(prompt, **kwargs):\n",
     "    response = openai.ChatCompletion.create(\n",
-    "        model=kwargs.get(\"model\",\"gpt-3.5-turbo-16k\"),\n",
-    "        messages=[{\"role\": \"system\", \"content\":prompt}],\n",
+    "        model=kwargs.get(\"model\", \"gpt-3.5-turbo-16k\"),\n",
+    "        messages=[{\"role\": \"system\", \"content\": prompt}],\n",
     "        temperature=kwargs.get(\"temperature\", 0),\n",
     "        top_p=kwargs.get(\"top_p\", 1),\n",
     "        frequency_penalty=kwargs.get(\"frequency_penalty\", 0.0),\n",
@@ -139,6 +139,7 @@
     "    )\n",
     "    return response\n",
     "\n",
+    "\n",
     "def llm(prompt, **kwargs):\n",
     "    response = openai.Completion.create(\n",
     "        model=kwargs.get(\"model\", \"text-davinci-003\"),\n",
@@ -375,7 +376,7 @@
     }
    ],
    "source": [
-    "llm2([Question_generation.format(2,answer)])"
+    "llm2([Question_generation.format(2, answer)])"
    ]
   },
   {
@@ -1039,10 +1040,12 @@
    ],
    "source": [
     "def get_all_facts(item):\n",
-    "    all_facts = item['context']['sentences']\n",
+    "    all_facts = item[\"context\"][\"sentences\"]\n",
     "    all_facts = [sent for para in all_facts for sent in para]\n",
-    "    return {\"full_context\":''.join(all_facts)}\n",
-    "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False) "
+    "    return {\"full_context\": \"\".join(all_facts)}\n",
+    "\n",
+    "\n",
+    "hotpot_qa = hotpot_qa.map(get_all_facts, batched=False)"
    ]
   },
   {
@@ -1090,8 +1093,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "i=15\n",
-    "q,c = hotpot_qa[i]['question'],hotpot_qa[i]['full_context']"
+    "i = 15\n",
+    "q, c = hotpot_qa[i][\"question\"], hotpot_qa[i][\"full_context\"]"
    ]
   },
   {
@@ -1112,7 +1115,7 @@
    "outputs": [],
    "source": [
     "q = \"what is general relativity?\"\n",
-    "n=2"
+    "n = 2"
    ]
   },
   {
@@ -1123,20 +1126,21 @@
    "outputs": [],
    "source": [
     "import wikipediaapi\n",
+    "\n",
     "wiki_wiki = wikipediaapi.Wikipedia(\n",
-    "    language='en',\n",
-    "    extract_format=wikipediaapi.ExtractFormat.WIKI\n",
+    "    language=\"en\", extract_format=wikipediaapi.ExtractFormat.WIKI\n",
     ")\n",
     "\n",
     "p_wiki = wiki_wiki.page(\"Black hole\")\n",
     "\n",
+    "\n",
     "def get_page_section(page, section):\n",
     "    all_text = \"\"\n",
     "    p_wiki = wiki_wiki.page(page)\n",
     "    sections = p_wiki.sections_by_title(section)\n",
     "    for s in sections:\n",
     "        all_text += s.full_text()\n",
-    "    return all_text\n"
+    "    return all_text"
    ]
   },
   {
@@ -1152,48 +1156,42 @@
     "\n",
     "cross_encoder = CrossEncoder(\"cross-encoder/stsb-TinyBERT-L-4\")\n",
     "\n",
-    "  \n",
+    "\n",
     "def sent_tokenize(sent):\n",
-    "    return [s[:-1] if s.endswith('.') else s for s in sent.strip().split('. ')]\n",
+    "    return [s[:-1] if s.endswith(\".\") else s for s in sent.strip().split(\". \")]\n",
+    "\n",
     "\n",
     "class SentenceAgreement:\n",
-    "    \n",
     "    def __init__(self, scoring=\"bert_score\"):\n",
-    "        \n",
     "        self.scoring = scoring\n",
     "\n",
-    "    \n",
     "    @staticmethod\n",
     "    def bert_score(para1, para2):\n",
-    "        \n",
     "        sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n",
     "        scores = cross_encoder.predict(list(itertools.product(sentences1, sentences2)))\n",
     "        scores = scores.reshape(len(sentences1), len(sentences2))\n",
     "        return scores.max(axis=1).mean()\n",
     "\n",
     "    @staticmethod\n",
     "    def jaccard_score(para1, para2):\n",
-    "        \n",
     "        sentences1, sentences2 = sent_tokenize(para1), sent_tokenize(para2)\n",
     "        intersect = len(np.intersect1d(sentences1, sentences2))\n",
     "        union = len(np.union1d(sentences1, sentences2))\n",
-    "        return intersect/union\n",
-    "    \n",
-    "    def evaluate(self,answers:List[List[str]]):\n",
-    "        \n",
+    "        return intersect / union\n",
+    "\n",
+    "    def evaluate(self, answers: List[List[str]]):\n",
     "        \"\"\"\n",
     "        eval nC2 combinations\n",
     "        \"\"\"\n",
     "        scores = []\n",
-    "        groups = combinations(answers,2)\n",
+    "        groups = combinations(answers, 2)\n",
     "        for group in groups:\n",
     "            if self.scoring == \"jaccard\":\n",
     "                score = self.jaccard_score(*group)\n",
     "            elif self.scoring == \"bert_score\":\n",
     "                score = self.bert_score(*group)\n",
     "            scores.append(score)\n",
-    "        return np.mean(scores)\n",
-    "        "
+    "        return np.mean(scores)"
    ]
   },
   {
@@ -1204,26 +1202,30 @@
    "outputs": [],
    "source": [
     "class ContextRelevacy:\n",
-    "    \n",
-    "    def __init__(self, strictness = 2, agreement_metric=\"bert_score\"):\n",
-    "        \n",
+    "    def __init__(self, strictness=2, agreement_metric=\"bert_score\"):\n",
     "        self.strictness = strictness\n",
     "        self.sent_agreement = SentenceAgreement(agreement_metric)\n",
-    "        \n",
-    "    def score(self,question,context):\n",
+    "\n",
+    "    def score(self, question, context):\n",
     "        scores = []\n",
-    "        outputs = llm(Context_relevency.format(q,c),n=self.strictness,temperature=1)\n",
-    "        outputs = [outputs['choices'][i]['text'].strip() for i in range(self.strictness)]\n",
+    "        outputs = llm(Context_relevency.format(q, c), n=self.strictness, temperature=1)\n",
+    "        outputs = [\n",
+    "            outputs[\"choices\"][i][\"text\"].strip() for i in range(self.strictness)\n",
+    "        ]\n",
     "        context_sents = sent_tokenize(context)\n",
     "        for output in outputs:\n",
-    "            indices = [context.find(sent) for sent in sent_tokenize(output) if context.find(sent)!=-1]\n",
-    "            scores.append(len(indices)/len(context_sents))\n",
-    "        \n",
+    "            indices = [\n",
+    "                context.find(sent)\n",
+    "                for sent in sent_tokenize(output)\n",
+    "                if context.find(sent) != -1\n",
+    "            ]\n",
+    "            scores.append(len(indices) / len(context_sents))\n",
+    "\n",
     "        if self.strictness > 1:\n",
     "            agr_score = self.sent_agreement.evaluate(outputs)\n",
     "        else:\n",
-    "            agr_score =1 \n",
-    "        return agr_score * np.mean(scores)\n"
+    "            agr_score = 1\n",
+    "        return agr_score * np.mean(scores)"
    ]
   },
   {
@@ -1234,7 +1236,7 @@
    "outputs": [],
    "source": [
     "c = get_page_section(\"HIV/AIDS\", \"Prevention\")\n",
-    "c = ' '.join(c.split(' ')[:500])\n",
+    "c = \" \".join(c.split(\" \")[:500])\n",
     "q = \"When was the first HIV case detected?\""
    ]
   },
   {
@@ -1245,7 +1247,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "output = llm([Context_relevency.format(q,c), Context_relevency.format(\"How to prevent AIDS?\",c)],n=n,temperature=1)"
+    "output = llm(\n",
+    "    [\n",
+    "        Context_relevency.format(q, c),\n",
+    "        Context_relevency.format(\"How to prevent AIDS?\", c),\n",
+    "    ],\n",
+    "    n=n,\n",
+    "    temperature=1,\n",
+    ")"
    ]
   },
   {
@@ -1397,7 +1406,7 @@
     }
    ],
    "source": [
-    "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))"
+    "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))"
    ]
   },
   {
@@ -1491,7 +1500,7 @@
     }
    ],
    "source": [
-    "context_relevancy.score(dataset[\"baseline\"].select(range(0,3)))"
+    "context_relevancy.score(dataset[\"baseline\"].select(range(0, 3)))"
    ]
   },
   {
```
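Stripped of the LLM call, the `ContextRelevacy` scoring being reformatted above is: ask the model for `strictness` candidate extractions, score each as the fraction of context sentences it recovers verbatim, then damp the mean by the agreement between candidates. A self-contained sketch of just that arithmetic, with hypothetical toy inputs:

```python
import numpy as np


def sent_tokenize(text):
    # the same naive splitter the notebook uses
    return [s[:-1] if s.endswith(".") else s for s in text.strip().split(". ")]


def relevancy_score(outputs, context):
    context_sents = sent_tokenize(context)
    per_output = []
    for output in outputs:
        # count extracted sentences that appear verbatim in the context
        found = [s for s in sent_tokenize(output) if context.find(s) != -1]
        per_output.append(len(found) / len(context_sents))
    return float(np.mean(per_output))


context = "Black holes bend light. They have event horizons."
outputs = ["Black holes bend light"]  # one candidate recovers 1 of 2 sentences
print(relevancy_score(outputs, context))  # 0.5
```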

pyproject.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -6,8 +6,9 @@ dependencies = [
     "sentence-transformers",
     "datasets",
     "protobuf<=3.20.0",
-    "backoff",
+    "langchain>=0.0.218",
     "openai",
+    "pydantic<2.0"
 ]
 dynamic = ["version", "readme"]
 
```

src/ragas/async_utils.py

Lines changed: 39 additions & 0 deletions
```diff
@@ -0,0 +1,39 @@
+"""Async utils."""
+import asyncio
+from typing import Any, Coroutine, List
+
+
+def run_async_tasks(
+    tasks: List[Coroutine],
+    show_progress: bool = False,
+    progress_bar_desc: str = "Running async tasks",
+) -> List[Any]:
+    """Run a list of async tasks."""
+
+    tasks_to_execute: List[Any] = tasks
+    if show_progress:
+        try:
+            import nest_asyncio
+            from tqdm.asyncio import tqdm
+
+            # jupyter notebooks already have an event loop running
+            # we need to reuse it instead of creating a new one
+            nest_asyncio.apply()
+            loop = asyncio.get_event_loop()
+
+            async def _tqdm_gather() -> List[Any]:
+                return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
+
+            tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
+            return tqdm_outputs
+        # run the operation w/o tqdm on hitting a fatal
+        # may occur in some environments where tqdm.asyncio
+        # is not supported
+        except Exception:
+            pass
+
+    async def _gather() -> List[Any]:
+        return await asyncio.gather(*tasks_to_execute)
+
+    outputs: List[Any] = asyncio.run(_gather())
+    return outputs
```
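A hedged usage sketch for the new helper — the coroutine below is a hypothetical stand-in for the async LLM calls ragas batches through it:

```python
import asyncio

from ragas.async_utils import run_async_tasks


async def fake_llm_call(i: int) -> str:
    # hypothetical stand-in for an async completion request
    await asyncio.sleep(0.01)
    return f"result-{i}"


tasks = [fake_llm_call(i) for i in range(5)]
# with show_progress=True and tqdm/nest_asyncio available, a progress bar is
# shown; on any failure the helper silently falls back to plain asyncio.gather
results = run_async_tasks(tasks, show_progress=False)
print(results)  # ['result-0', ..., 'result-4'] — gather preserves task order
```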

src/ragas/metrics/base.py

Lines changed: 11 additions & 8 deletions
```diff
@@ -12,6 +12,8 @@
 from math import floor
 
 from datasets import Dataset
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms.base import BaseLLM
 
 
 def make_batches(total_size: int, batch_size: int) -> list[range]:
@@ -31,17 +33,18 @@ def make_batches(total_size: int, batch_size: int) -> list[range]:
 
 @dataclass
 class Metric(ABC):
-    @property
-    @abstractmethod
-    def batch_size(self: t.Self) -> int:
-        ...
+    batch_size: int
+    llm: t.Optional[BaseLLM | BaseChatModel] = None
+
+    def __post_init__(self: t.Self):
+        if self.llm is None:
+            from langchain.chat_models import ChatOpenAI
+
+            self.llm = ChatOpenAI(model_name="gpt-3.5-turbo-16k")  # type: ignore
 
     @property
     @abstractmethod
-    def name(self: t.Self) -> str:
-        """
-        the metric name
-        """
+    def name(self) -> str:
         ...
 
     @abstractmethod
```
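With `Metric` now a plain dataclass carrying `batch_size` and an optional `llm`, a concrete metric only has to fill in the remaining abstract members. A minimal sketch under the assumption (suggested by the commit message) that those are `name`, `init_model`, and `score`; `ToyMetric` and its trivial `score` are hypothetical:

```python
from dataclasses import dataclass

from datasets import Dataset
from langchain.chat_models import ChatOpenAI

from ragas.metrics.base import Metric


@dataclass
class ToyMetric(Metric):
    batch_size: int = 2  # give the inherited dataclass field a default

    @property
    def name(self) -> str:
        return "toy_metric"

    def init_model(self):
        ...

    def score(self, dataset: Dataset) -> Dataset:
        # a real metric would prompt self.llm here
        return dataset


m = ToyMetric()  # no llm passed: falls back to ChatOpenAI("gpt-3.5-turbo-16k")
m4 = ToyMetric(llm=ChatOpenAI(model_name="gpt-4"))  # any BaseLLM/BaseChatModel works
```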
