From ca07b3d153c5d5753ca491f66d1cc312c5736f66 Mon Sep 17 00:00:00 2001
From: Jithin James
Date: Mon, 22 May 2023 19:38:07 +0530
Subject: [PATCH] fix: lazy loading of models used in metrics to speed up
 import (#32)

* added init_model to baseline

* added init_model to everything

* fix lint issues

* added init model to qsquare

* ignore type issue

* fix linting

---------

Co-authored-by: Jithin James
---
 Makefile                      |  2 +-
 ragas/metrics/base.py         | 20 ++++++++++++++++++++
 ragas/metrics/factual.py      |  9 +++++----
 ragas/metrics/similarity.py   |  2 +-
 ragas/metrics/simple.py       |  8 +++++++-
 tests/benchmarks/benchmark.py | 17 +++++++----------
 6 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/Makefile b/Makefile
index fc5ab0fa8..52ce482a3 100644
--- a/Makefile
+++ b/Makefile
@@ -18,7 +18,7 @@ lint: ## Running lint checker: ruff
 	@ruff check ragas examples tests
 type: ## Running type checker: pyright
 	@echo "(pyright) Typechecking codebase..."
-	@pyright -p ragas
+	@pyright ragas
 clean: ## Clean all generated files
 	@echo "Cleaning all generated files..."
 	@cd $(GIT_ROOT)/docs && make clean
diff --git a/ragas/metrics/base.py b/ragas/metrics/base.py
index b2acdc9f5..21cde9b2e 100644
--- a/ragas/metrics/base.py
+++ b/ragas/metrics/base.py
@@ -13,17 +13,33 @@ class Metric(ABC):
     @property
     @abstractmethod
     def name(self: t.Self) -> str:
+        """
+        the metric name
+        """
         ...

     @property
     @abstractmethod
     def is_batchable(self: t.Self) -> bool:
+        """
+        Attribute to check if this metric is is_batchable
+        """
+        ...
+
+    @abstractmethod
+    def init_model():
+        """
+        This method will lazy initialize the model.
+        """
         ...

     @abstractmethod
     def score(
         self: t.Self, ground_truth: list[str], generated_text: list[str]
     ) -> list[float]:
+        """
+        Run the metric on the ground_truth and generated_text and return score.
+        """
         ...
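The hunk above makes init_model() part of the Metric interface: constructing a metric stays cheap at import time, and whatever is expensive (tokenizers, transformer weights) is deferred until an evaluation actually runs. As a rough sketch of that contract, not part of the patch itself, a concrete metric keeps only plain configuration fields and does its loading in init_model(). The class name ExampleBERTScore, its fields, and the cosine-similarity scoring are illustrative assumptions, and the sketch presumes sentence-transformers is installed; the remaining hunks below apply the same split to the real metrics and wire the initialization into Evaluation.eval.

# Illustrative only: a hypothetical metric written against the new Metric contract.
from dataclasses import dataclass

from sentence_transformers import SentenceTransformer

from ragas.metrics.base import Metric


@dataclass
class ExampleBERTScore(Metric):
    # plain configuration fields: nothing heavy happens when the object is created
    model_path: str = "all-MiniLM-L6-v2"
    batch_size: int = 1000

    @property
    def name(self) -> str:
        return "example_bert_score"

    @property
    def is_batchable(self) -> bool:
        return True

    def init_model(self):
        # deferred, expensive step: weights are loaded only when the evaluation
        # loop calls init_model() on its metrics
        self.model = SentenceTransformer(self.model_path)

    def score(self, ground_truth: list[str], generated_text: list[str]) -> list[float]:
        # assumes init_model() has already been called
        gt = self.model.encode(ground_truth, batch_size=self.batch_size, convert_to_numpy=True)
        gen = self.model.encode(generated_text, batch_size=self.batch_size, convert_to_numpy=True)
        # cosine similarity of each (ground_truth, generated_text) pair
        sims = (gt * gen).sum(axis=1) / (
            ((gt**2).sum(axis=1) ** 0.5) * ((gen**2).sum(axis=1) ** 0.5)
        )
        return sims.tolist()

Because nothing heavy happens in the dataclass itself, importing a module that instantiates such metrics no longer pulls model weights; the next hunk shows where Evaluation.eval triggers the initialization for every metric.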
@@ -37,6 +53,10 @@ def eval(self, ground_truth: list[list[str]], generated_text: list[str]) -> Resu
         ds = Dataset.from_dict(
             {"ground_truth": ground_truth, "generated_text": generated_text}
         )
+
+        # initialize all the models in the metrics
+        [m.init_model() for m in self.metrics]
+
         ds = ds.map(
             self._get_score,
             batched=self.batched,
diff --git a/ragas/metrics/factual.py b/ragas/metrics/factual.py
index d9ab22909..a936d4ff8 100644
--- a/ragas/metrics/factual.py
+++ b/ragas/metrics/factual.py
@@ -52,7 +52,7 @@ class EntailmentScore(Metric):
     batch_size: int = 4
     device: t.Literal["cpu", "cuda"] | Device = "cpu"

-    def __post_init__(self):
+    def init_model(self):
         self.device = device_check(self.device)
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
         self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
@@ -212,10 +212,11 @@ class Qsquare(Metric):
     include_nouns: bool = True
     save_results: bool = False

-    def __post_init__(self):
+    def init_model(self):
         self.qa = QAGQ.from_pretrained(self.qa_model_name)
         self.qg = QAGQ.from_pretrained(self.qg_model_name)
         self.nli = EntailmentScore()
+        self.nli.init_model()
         try:
             self.nlp = spacy.load(SPACY_MODEL)
         except OSError:
@@ -326,7 +327,7 @@ def score(self, ground_truth: list[str], generated_text: list[str], **kwargs):
             )
             gnd_qans[i] = [
                 {"question": qstn, "answer": ans}
-                for qstn, ans in zip(questions, candidates)
+                for qstn, ans in zip(questions, candidates)  # type: ignore
             ]

         for i, gen_text in enumerate(generated_text):
             questions = self.generate_questions(gen_text, **kwargs)
             gen_answers = self.generate_answers(questions, gen_text)
             _ = [
                 item.update({"predicted_answer": ans})
-                for item, ans in zip(gnd_qans[i], gen_answers)
+                for item, ans in zip(gnd_qans[i], gen_answers)  # type: ignore
             ]

         # del self.qa
diff --git a/ragas/metrics/similarity.py b/ragas/metrics/similarity.py
index 60b3d1d2e..a79468b69 100644
--- a/ragas/metrics/similarity.py
+++ b/ragas/metrics/similarity.py
@@ -18,7 +18,7 @@ class BERTScore(Metric):
     model_path: str = "all-MiniLM-L6-v2"
     batch_size: int = 1000

-    def __post_init__(self):
+    def init_model(self):
         self.model = SentenceTransformer(self.model_path)

     @property
diff --git a/ragas/metrics/simple.py b/ragas/metrics/simple.py
index 2ac8ee8b9..643d5d777 100644
--- a/ragas/metrics/simple.py
+++ b/ragas/metrics/simple.py
@@ -26,6 +26,9 @@ def name(self):
     def is_batchable(self):
         return True

+    def init_model(self):
+        ...
+
     def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         ground_truth_ = [[word_tokenize(text)] for text in ground_truth]
         generated_text_ = [word_tokenize(text) for text in generated_text]
@@ -45,7 +48,7 @@ class ROUGE(Metric):
     type: t.Literal[ROUGE_TYPES]
     use_stemmer: bool = False

-    def __post_init__(self):
+    def init_model(self):
         self.scorer = rouge_scorer.RougeScorer(
             [self.type], use_stemmer=self.use_stemmer
         )
@@ -80,6 +83,9 @@ def name(self) -> str:
     def is_batchable(self):
         return True

+    def init_model(self):
+        ...
+
     def score(self, ground_truth: t.List[str], generated_text: t.List[str]):
         if self.measure == "distance":
             score = [distance(s1, s2) for s1, s2 in zip(ground_truth, generated_text)]
diff --git a/tests/benchmarks/benchmark.py b/tests/benchmarks/benchmark.py
index 5bb413fc0..c868ceff8 100644
--- a/tests/benchmarks/benchmark.py
+++ b/tests/benchmarks/benchmark.py
@@ -7,26 +7,23 @@
 from ragas.metrics import (
     Evaluation,
-    edit_distance,
+    bert_score,
     edit_ratio,
-    q_square,
     rouge1,
-    rouge2,
-    rougeL,
 )

 DEVICE = "cuda" if is_available() else "cpu"

-BATCHES = [0, 1]
+BATCHES = [0, 1, 30, 60]
 METRICS = {
     "Rouge1": rouge1,
-    "Rouge2": rouge2,
-    "RougeL": rougeL,
+    # "Rouge2": rouge2,
+    # "RougeL": rougeL,
     "EditRatio": edit_ratio,
-    "EditDistance": edit_distance,
-    # "SBERTScore": bert_score,
+    # "EditDistance": edit_distance,
+    "SBERTScore": bert_score,
     # "EntailmentScore": entailment_score,
-    "Qsquare": q_square,
+    # "Qsquare": q_square,
 }

 DS = load_dataset("explodinggradients/eli5-test", split="test_eli5")
 assert isinstance(DS, arrow_dataset.Dataset), "Not an arrow_dataset"
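Taken together, the change means the module-level metric objects (rouge1, edit_ratio, bert_score, ...) are built without loading any model, and weights are materialized by the [m.init_model() for m in self.metrics] step inside Evaluation.eval, immediately before scoring. A minimal usage sketch of that behaviour follows; the Evaluation constructor argument shown (metrics=) is assumed from the attributes referenced in the base.py hunk, and the example texts are made up.

# Importing the metrics is now cheap: no tokenizer or transformer weights load here.
from ragas.metrics import Evaluation, edit_ratio, rouge1

ground_truth = [["Paris is the capital of France."]]
generated_text = ["The capital of France is Paris."]

# Assumption: Evaluation accepts the list of metrics to run as `metrics`.
e = Evaluation(metrics=[rouge1, edit_ratio])

# Models (if any) for each metric are initialized here, just before scoring,
# via the new [m.init_model() for m in self.metrics] step in eval().
result = e.eval(ground_truth, generated_text)
print(result)

This is also why the benchmark above can enable SBERTScore without slowing down the import itself: the SentenceTransformer model is only loaded for metrics that are actually passed to an evaluation run.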