Skip to content

Commit

Permalink
fix notebook
Browse files (browse the repository at this point in the history)
  • Loading branch information
mrchtr committed Feb 6, 2024
1 parent e8e4bf3 commit 70dbe30
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 2,660 deletions.
6 changes: 4 additions & 2 deletions src/components/aggregrate_eval_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@


@lightweight_component(
consumes={
"context_precision": pa.float32(),
"context_relevancy": pa.float32(),
},
produces={
"metric": pa.string(),
"score": pa.float32()
}
)
class AggregateResults(DaskTransformComponent):
def __init__(self, consumes: dict, **kwargs):
self.consumes = consumes

def transform(self, dataframe: dd.DataFrame) -> dd.DataFrame:
metrics = list(self.consumes.keys())
Expand Down
11 changes: 3 additions & 8 deletions src/components/retrieve_from_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ def retrieve_chunks_from_embeddings(self, vector_query: list):
)

result = query.do()
if "data" in result:
result_dict = result["data"]["Get"][self.class_name]
return [retrieved_chunk["passage"] for retrieved_chunk in result_dict]
result_dict = result["data"]["Get"][self.class_name]
return [retrieved_chunk["passage"] for retrieved_chunk in result_dict]

def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:

Expand All @@ -67,12 +66,8 @@ def transform(self, dataframe: pd.DataFrame) -> pd.DataFrame:
self.retrieve_chunks_from_embeddings,
)

elif "prompt" in dataframe.columns:
dataframe["retrieved_chunks"] = dataframe["prompt"].apply(
self.retrieve_chunks_from_prompts,
)
else:
msg = "Dataframe must contain either an 'embedding' column or a 'prompt' column."
msg = "Dataframe must contain an 'embedding' column"
raise ValueError(
msg,
)
Expand Down
Loading

0 comments on commit 70dbe30

Please sign in to comment.