From 6835c4fa1f75f9d91b6fd35b52bdb1f55ef8d0dd Mon Sep 17 00:00:00 2001
From: Yosuke Higashi
Date: Wed, 21 Feb 2024 07:35:30 +0000
Subject: [PATCH] revert to english factual consistency prompt

---
 .../metrics/ja/source_based_text_quality.py   | 132 +++---------------
 1 file changed, 17 insertions(+), 115 deletions(-)

diff --git a/src/langcheck/metrics/ja/source_based_text_quality.py b/src/langcheck/metrics/ja/source_based_text_quality.py
index 8afb3a5b..977ce38b 100644
--- a/src/langcheck/metrics/ja/source_based_text_quality.py
+++ b/src/langcheck/metrics/ja/source_based_text_quality.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Dict, List, Optional, Tuple, cast
+from typing import Dict, List, Optional, cast
 
 from openai import OpenAI
 from transformers.pipelines import pipeline
@@ -80,25 +80,15 @@ def factual_consistency(
     ], ('Unsupported model type. '
         'The supported ones are ["local", "openai", "azure_openai"]')
 
-    if model_type == 'local':
-        scores = _factual_consistency_local(generated_outputs, sources)
-        explanations = None
-    else:  # openai or azure_openai
-        scores, explanations = _factual_consistency_openai(
-            generated_outputs, sources, model_type, openai_client, openai_args)
+    # The English prompt works well enough for Japanese
+    # TODO: Investigate the performance improvement with Japanese prompt
+    if model_type == 'openai' or model_type == 'azure_openai':
+        metric_value = en_factual_consistency(generated_outputs, sources,
+                                              prompts, model_type,
+                                              openai_client, openai_args)
+        metric_value.language = 'ja'
+        return metric_value
 
-    return MetricValue(metric_name='factual_consistency',
-                       prompts=prompts,
-                       generated_outputs=generated_outputs,
-                       reference_outputs=None,
-                       sources=sources,
-                       explanations=explanations,
-                       metric_values=scores,
-                       language='ja')
-
-
-def _factual_consistency_local(generated_outputs: List[str],
-                               sources: List[str]) -> List[float]:
     global _factual_consistency_translation_pipeline
     if _factual_consistency_translation_pipeline is None:
         _factual_consistency_translation_pipeline = pipeline(
@@ -122,103 +112,15 @@ def _factual_consistency_local(generated_outputs: List[str],
     # Compute the factual consistency scores in English.
     factual_consistency_scores = en_factual_consistency(
         generated_outputs=en_generated_outputs, sources=en_source).metric_values
-    # Local factual consistency scores are of type List[float]
-    return factual_consistency_scores  # type: ignore
-
-
-def _factual_consistency_openai(
-        generated_outputs: List[str], sources: List[str], client_type: str,
-        client: Optional[OpenAI], openai_args: Optional[Dict[str, str]]
-) -> Tuple[List[Optional[float]], List[Optional[str]]]:
-    '''Calculates the factual consistency and their associated explanations
-    between each generated output and its corresponding source text. We leverage
-    the function calling API to make sure that the output is structured such
-    that we can compute a score. If a score could not be computed, `None` is
-    inserted to the score and explanation lists.
-
-    Ref:
-        https://platform.openai.com/docs/guides/gpt/function-calling
-
-    Args:
-        generated_outputs: The model generated output(s) to evaluate
-        sources: The source text(s), one string per generated output
-        client_type: The type of OpenAI client ('openai' or 'azure_openai')
-        client: (Optional) OpenAI or AzureOpenAI client. If this is None, we
-            will attempt to create a default client depending on the
-            ``client_type``.
-        openai_args: (Optional) Dict of additional args to pass in to the
-            ``client.chat.completions.create`` function
-
-    Returns:
-        score_list: a list of scores
-        explanation_list: a list of explanations for the scores
-    '''
-
-    # TODO: The prompt formation, and the scoring system, can do with some
-    # improvement. There are some cases where consistent outputs get incorrectly
-    # assessed as "Partially Consistent", and there's no differentiation
-    # between an output that is unrelated to the source and an output that is
-    # straight up contradictory.
-    def _prompt(src: str, gen_output: str) -> str:
-        return f'''
-        提出されたテキストの論理的な一貫性を評価してください。データは以下の通りです:
-        [BEGIN DATA]
-        ************
-        [ソース]: {src}
-        ************
-        [提出内容]: {gen_output}
-        ************
-        [END DATA]
-
-        ソーステキストと提出されたテキストの関係を判断してください。利用可能な評価は以下の通りです:
-        `Fully Consistent` - ソーステキストが提出されたテキストを含意する。
-        `Partially Consistent` - ソーステキストが提出されたテキストを含意するか、
-        または二つのテキストが矛盾するか、判断できない。
-        `Not Consistent` - ソーステキストと提出されたテキストが矛盾する。
-        深呼吸をして、この問題をステップバイステップで取り組んでください。
-        '''
-
-    def _function_call_prompt(long_assessment: str) -> str:
-        return f'''
-        以下はテキストの事実上の一貫性に関する評価です:
-        ************
-        [評価]: {long_assessment}
-        ************
-
-        結果として出た評価を保存してください。利用可能な評価は以下の通りです:
-        `Fully Consistent`
-        `Partially Consistent`
-        `Not Consistent`
-        '''
-
-    factuality_assessment_to_score = {
-        'Fully Consistent': 1.0,
-        'Partially Consistent': 0.5,
-        'Not Consistent': 0.0
-    }
-    oai_evaluator = OpenAIBasedEvaluator(
-        assessment_to_score_mapping=factuality_assessment_to_score,
-        function_name='save_factual_consistency_assessment',
-        function_description=(
-            "Saves a submitted claim's factual consistency assessment."),
-        argument_name='factuality',
-        argument_description='The factual consistency assessment of the claim',
-        client_type=client_type,
-        client=client,
-        openai_args=openai_args)
-
-    score_list = []
-
-    explanation_list = []
-    for src, gen in tqdm_wrapper(zip(sources, generated_outputs),
-                                 desc='Calculating scores',
-                                 total=len(generated_outputs)):
-        score, explanation = oai_evaluator.get_score(
-            _prompt(src=src, gen_output=gen), _function_call_prompt)
-        score_list.append(score)
-        explanation_list.append(explanation)
-    return score_list, explanation_list
+    return MetricValue(metric_name='factual_consistency',
+                       prompts=prompts,
+                       generated_outputs=generated_outputs,
+                       reference_outputs=None,
+                       sources=sources,
+                       explanations=None,
+                       metric_values=factual_consistency_scores,
+                       language='ja')
 
 
 def context_relevance(
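
Usage sketch: the snippet below shows how the reverted Japanese metric might be called after this change. It is a minimal sketch, assuming factual_consistency is re-exported from langcheck.metrics.ja and that model_type defaults to 'local'; the example strings are invented, and only the parameters visible in this diff (generated_outputs, sources, prompts, model_type, openai_client, openai_args) are relied on.

    # Minimal usage sketch; import path and default model_type are assumptions.
    from langcheck.metrics.ja import factual_consistency

    generated_outputs = ['東京は日本の首都です。']
    sources = ['日本の首都は東京である。']

    # Local path (assumed default): the Japanese texts are machine-translated
    # to English and scored with the English factual consistency model.
    result = factual_consistency(generated_outputs, sources)
    print(result.metric_values)

    # OpenAI path: with this patch, the English prompt is reused directly and
    # the returned MetricValue is relabeled with language='ja'.
    # result = factual_consistency(generated_outputs,
    #                              sources,
    #                              model_type='openai')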