Commit
Showing 12 changed files with 2,100 additions and 314 deletions.
@@ -1,3 +1,4 @@
-fondant==0.10.0
+fondant[component,aws,azure,gcp,docker]==0.10.1
 notebook==7.0.6
 weaviate-client==3.25.3
+langchain==0.0.329
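The new fondant pin adds extras, presumably pulling in the optional dependencies for the component runner, the cloud integrations, and Docker. As a quick sanity check (a minimal sketch, not part of the commit), the installed versions can be compared against these pins:

from importlib.metadata import version

# Distribution names as they appear in the requirements above.
for package in ("fondant", "notebook", "weaviate-client", "langchain"):
    print(package, version(package))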
Empty file.
This file was deleted.
src/components/aggregate_eval_results/fondant_component.yaml (16 changes: 0 additions & 16 deletions)
This file was deleted.
This file was deleted.
...onents/aggregate_eval_results/src/main.py → src/components/aggregrate_eval_results.py (8 changes: 8 additions & 0 deletions)
@@ -0,0 +1,64 @@
import typing as t

import pandas as pd
import pyarrow as pa
from fondant.component import PandasTransformComponent
from fondant.pipeline import lightweight_component


@lightweight_component(
    consumes={"text": pa.string()},
    produces={"text": pa.string(), "original_document_id": pa.string()},
    extra_requires=["langchain==0.0.329"],
)
class ChunkTextComponent(PandasTransformComponent):
    """Component that chunks text into smaller segments.

    More information about the different chunking strategies can be found here:
    - https://python.langchain.com/docs/modules/data_connection/document_transformers/
    - https://www.pinecone.io/learn/chunking-strategies/
    """

    def __init__(
        self,
        *,
        chunk_size: int,
        chunk_overlap: int,
    ):
        """
        Args:
            chunk_size: the maximum number of characters per chunk
            chunk_overlap: the number of characters shared between consecutive chunks
        """
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        self.chunker = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )

    def chunk_text(self, row) -> t.List[t.Tuple]:
        # The multi-index dataframe carries the document id in the name attribute.
        doc_id = row.name
        text_data = row["text"]
        docs = self.chunker.create_documents([text_data])

        return [
            (doc_id, f"{doc_id}_{chunk_id}", chunk.page_content)
            for chunk_id, chunk in enumerate(docs)
        ]

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        import itertools

        results = dataframe.apply(
            self.chunk_text,
            axis=1,
        ).to_list()

        # Flatten the per-row lists of (doc_id, chunk_id, text) tuples
        results = list(itertools.chain.from_iterable(results))

        # Turn into a dataframe indexed by the new chunk id
        results_df = pd.DataFrame(
            results,
            columns=["original_document_id", "id", "text"],
        )
        results_df = results_df.set_index("id")

        return results_df
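The splitting itself is delegated to langchain's RecursiveCharacterTextSplitter; the following standalone sketch (sample text and parameters made up) shows what chunk_text does for a single row:

from langchain.text_splitter import RecursiveCharacterTextSplitter

# Mirrors the component: one document in, a list of overlapping chunks out.
splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
docs = splitter.create_documents(
    ["A long document that will be split into overlapping chunks. " * 10],
)

for chunk_id, chunk in enumerate(docs):
    # Chunk ids follow the component's "<doc_id>_<chunk_index>" convention.
    print(f"doc_1_{chunk_id}", repr(chunk.page_content[:40]))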
@@ -0,0 +1,96 @@
import pandas as pd
import pyarrow as pa
from fondant.component import PandasTransformComponent
from fondant.pipeline import lightweight_component


@lightweight_component(
    consumes={
        "question": pa.string(),
        "retrieved_chunks": pa.list_(pa.string()),
    },
    produces={
        "context_precision": pa.float32(),
        "context_relevancy": pa.float32(),
    },
    extra_requires=["ragas==0.0.21"],
)
class RagasEvaluator(PandasTransformComponent):
    def __init__(
        self,
        *,
        llm_module_name: str,
        llm_class_name: str,
        llm_kwargs: dict,
    ) -> None:
        """
        Args:
            llm_module_name: Module from which the LLM is imported, e.g. langchain.chat_models
            llm_class_name: Name of the selected llm, e.g. ChatOpenAI
            llm_kwargs: Arguments for the selected llm
        """
        self.llm = self.extract_llm(
            llm_module_name=llm_module_name,
            llm_class_name=llm_class_name,
            llm_kwargs=llm_kwargs,
        )

        from ragas.llms import LangchainLLM

        self.gpt_wrapper = LangchainLLM(llm=self.llm)
        self.metric_functions = self.extract_metric_functions(
            metrics=["context_precision", "context_relevancy"],
        )
        self.set_llm(self.metric_functions)

    # Import an element from a module by name
    @staticmethod
    def import_from(module_name: str, element_name: str):
        module = __import__(module_name, fromlist=[element_name])
        return getattr(module, element_name)

    def extract_llm(self, llm_module_name: str, llm_class_name: str, llm_kwargs: dict):
        module = self.import_from(
            module_name=llm_module_name,
            element_name=llm_class_name,
        )
        return module(**llm_kwargs)

    # Import the selected metric functions
    def extract_metric_functions(self, metrics: list):
        functions = []
        for metric in metrics:
            functions.append(self.import_from("ragas.metrics", metric))
        return functions

    def set_llm(self, metric_functions: list):
        for metric_function in metric_functions:
            metric_function.llm = self.gpt_wrapper

    # Convert the pandas dataframe into the Hugging Face dataset ragas expects
    @staticmethod
    def create_hf_ds(dataframe: pd.DataFrame):
        dataframe = dataframe.rename(
            columns={"retrieved_chunks": "contexts"},
        )

        from datasets import Dataset

        return Dataset.from_pandas(dataframe)

    # Evaluate the retriever
    def ragas_eval(self, dataset):
        from ragas import evaluate

        return evaluate(dataset=dataset, metrics=self.metric_functions)

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        hf_dataset = self.create_hf_ds(
            dataframe=dataframe[["question", "retrieved_chunks"]],
        )
        if "id" in hf_dataset.column_names:
            hf_dataset = hf_dataset.remove_columns("id")

        result = self.ragas_eval(dataset=hf_dataset)
        results_df = result.to_pandas()
        results_df = results_df.set_index(dataframe.index)

        return results_df
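The import_from helper resolves the LLM class at runtime, so the component is not hard-wired to one provider. A minimal sketch of the same pattern (the ChatOpenAI kwargs are illustrative, and OPENAI_API_KEY must be set for the resulting client to be usable):

def import_from(module_name: str, element_name: str):
    # __import__ with fromlist returns the submodule itself rather than
    # the top-level package, so getattr can pick the class off it.
    module = __import__(module_name, fromlist=[element_name])
    return getattr(module, element_name)

ChatOpenAI = import_from("langchain.chat_models", "ChatOpenAI")
llm = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)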
@@ -0,0 +1,80 @@
import pandas as pd
import pyarrow as pa
from fondant.component import PandasTransformComponent
from fondant.pipeline import lightweight_component


@lightweight_component(
    produces={"retrieved_chunks": pa.list_(pa.string())},
    extra_requires=["weaviate-client==3.24.1"],
)
class RetrieveFromWeaviateComponent(PandasTransformComponent):
    def __init__(
        self,
        *,
        weaviate_url: str,
        class_name: str,
        top_k: int,
    ) -> None:
        """
        Args:
            weaviate_url: URL of the Weaviate instance.
            class_name: Name of the class to query.
            top_k: Number of chunks to retrieve per query.
        """
        import weaviate

        # Initialize the client based on the arguments
        self.client = weaviate.Client(
            url=weaviate_url,
            additional_config=None,
            additional_headers=None,
        )
        self.class_name = class_name
        self.k = top_k

    def teardown(self) -> None:
        del self.client

    def retrieve_chunks_from_embeddings(self, vector_query: list):
        """Get results from the weaviate database via a near-vector search."""
        query = (
            self.client.query.get(self.class_name, ["passage"])
            .with_near_vector({"vector": vector_query})
            .with_limit(self.k)
            .with_additional(["distance"])
        )

        result = query.do()
        if "data" in result:
            result_dict = result["data"]["Get"][self.class_name]
            return [retrieved_chunk["passage"] for retrieved_chunk in result_dict]

    def retrieve_chunks_from_prompts(self, prompt: str):
        """Get results from the weaviate database via a BM25 keyword search.

        Note: transform() below calls this method, but the original commit never
        defined it; this BM25-based implementation is an assumed reconstruction.
        """
        query = (
            self.client.query.get(self.class_name, ["passage"])
            .with_bm25(query=prompt)
            .with_limit(self.k)
        )

        result = query.do()
        if "data" in result:
            result_dict = result["data"]["Get"][self.class_name]
            return [retrieved_chunk["passage"] for retrieved_chunk in result_dict]

    def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
        if "embedding" in dataframe.columns:
            dataframe["retrieved_chunks"] = dataframe["embedding"].apply(
                self.retrieve_chunks_from_embeddings,
            )
        elif "prompt" in dataframe.columns:
            dataframe["retrieved_chunks"] = dataframe["prompt"].apply(
                self.retrieve_chunks_from_prompts,
            )
        else:
            msg = "Dataframe must contain either an 'embedding' column or a 'prompt' column."
            raise ValueError(msg)

        return dataframe
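Per row, the embedding branch boils down to a single weaviate-client v3 GraphQL query. A standalone sketch under assumptions: a Weaviate instance reachable at localhost:8080, a class named "Passage" with a "passage" text property, and a query vector whose dimension matches that class's schema:

import weaviate

client = weaviate.Client(url="http://localhost:8080")
result = (
    client.query.get("Passage", ["passage"])
    .with_near_vector({"vector": [0.1] * 768})  # made-up query vector
    .with_limit(3)
    .with_additional(["distance"])
    .do()
)

# Extract the retrieved passages, as retrieve_chunks_from_embeddings does.
chunks = [hit["passage"] for hit in result["data"]["Get"]["Passage"]]
print(chunks)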