revert to english factual consistency prompt
yosukehigashi committed Feb 21, 2024
1 parent 7a9bbb6 commit 6835c4f
Showing 1 changed file with 17 additions and 115 deletions.
132 changes: 17 additions & 115 deletions src/langcheck/metrics/ja/source_based_text_quality.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Dict, List, Optional, Tuple, cast
from typing import Dict, List, Optional, cast

from openai import OpenAI
from transformers.pipelines import pipeline
@@ -80,25 +80,15 @@ def factual_consistency(
], ('Unsupported model type. '
'The supported ones are ["local", "openai", "azure_openai"]')

if model_type == 'local':
scores = _factual_consistency_local(generated_outputs, sources)
explanations = None
else: # openai or azure_openai
scores, explanations = _factual_consistency_openai(
generated_outputs, sources, model_type, openai_client, openai_args)
# The English prompt works well enough for Japanese
    # TODO: Investigate the performance improvement with a Japanese prompt
if model_type == 'openai' or model_type == 'azure_openai':
metric_value = en_factual_consistency(generated_outputs, sources,
prompts, model_type,
openai_client, openai_args)
metric_value.language = 'ja'
return metric_value

return MetricValue(metric_name='factual_consistency',
prompts=prompts,
generated_outputs=generated_outputs,
reference_outputs=None,
sources=sources,
explanations=explanations,
metric_values=scores,
language='ja')


def _factual_consistency_local(generated_outputs: List[str],
sources: List[str]) -> List[float]:
global _factual_consistency_translation_pipeline
if _factual_consistency_translation_pipeline is None:
_factual_consistency_translation_pipeline = pipeline(
@@ -122,103 +112,15 @@ def _factual_consistency_local(generated_outputs: List[str],
# Compute the factual consistency scores in English.
factual_consistency_scores = en_factual_consistency(
generated_outputs=en_generated_outputs, sources=en_source).metric_values
# Local factual consistency scores are of type List[float]
return factual_consistency_scores # type: ignore
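
For context, the local path above machine-translates the Japanese inputs to English and reuses the English scorer. A minimal sketch of that translate-then-score step (the translation checkpoint is elided in this diff, so 'Helsinki-NLP/opus-mt-ja-en' below is an assumption, not necessarily the model langcheck loads):

from transformers.pipelines import pipeline

# Assumed checkpoint -- the model actually passed to pipeline() is elided in
# the diff above.
translator = pipeline('translation', model='Helsinki-NLP/opus-mt-ja-en')

ja_texts = ['日本の首都は東京である。']
# Each pipeline result is a dict with a 'translation_text' key.
en_texts = [r['translation_text'] for r in translator(ja_texts)]
print(en_texts)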


def _factual_consistency_openai(
generated_outputs: List[str], sources: List[str], client_type: str,
client: Optional[OpenAI], openai_args: Optional[Dict[str, str]]
) -> Tuple[List[Optional[float]], List[Optional[str]]]:
'''Calculates the factual consistency scores and their associated
explanations between each generated output and its corresponding source text. We leverage
the function calling API to make sure that the output is structured such
that we can compute a score. If a score could not be computed, `None` is
inserted to the score and explanation lists.
Ref:
https://platform.openai.com/docs/guides/gpt/function-calling
Args:
generated_outputs: The model generated output(s) to evaluate
sources: The source text(s), one string per generated output
client_type: The type of OpenAI client ('openai' or 'azure_openai')
client: (Optional) OpenAI or AzureOpenAI client. If this is None, we
will attempt to create a default client depending on the
``client_type``.
openai_args: (Optional) Dict of additional args to pass in to the
``client.chat.completions.create`` function
Returns:
score_list: a list of scores
explanation_list: a list of explanations for the scores
'''

    # TODO: The prompt formulation and the scoring system could do with some
    # improvement. There are some cases where consistent outputs get incorrectly
# assessed as "Partially Consistent", and there's no differentiation
# between an output that is unrelated to the source and an output that is
# straight up contradictory.
def _prompt(src: str, gen_output: str) -> str:
return f'''
提出されたテキストの論理的な一貫性を評価してください。データは以下の通りです:
[BEGIN DATA]
************
[ソース]: {src}
************
[提出内容]: {gen_output}
************
[END DATA]
ソーステキストと提出されたテキストの関係を判断してください。利用可能な評価は以下の通りです:
`Fully Consistent` - ソーステキストが提出されたテキストを含意する。
`Partially Consistent` - ソーステキストが提出されたテキストを含意するか、
または二つのテキストが矛盾するか、判断できない。
`Not Consistent` - ソーステキストと提出されたテキストが矛盾する。

深呼吸をして、この問題をステップバイステップで取り組んでください。
'''

def _function_call_prompt(long_assessment: str) -> str:
return f'''
以下はテキストの事実上の一貫性に関する評価です:
************
[評価]: {long_assessment}
************
結果として出た評価を保存してください。利用可能な評価は以下の通りです:
`Fully Consistent`
`Partially Consistent`
`Not Consistent`
'''

factuality_assessment_to_score = {
'Fully Consistent': 1.0,
'Partially Consistent': 0.5,
'Not Consistent': 0.0
}
oai_evaluator = OpenAIBasedEvaluator(
assessment_to_score_mapping=factuality_assessment_to_score,
function_name='save_factual_consistency_assessment',
function_description=(
"Saves a submitted claim's factual consistency assessment."),
argument_name='factuality',
argument_description='The factual consistency assessment of the claim',
client_type=client_type,
client=client,
openai_args=openai_args)

score_list = []

explanation_list = []
for src, gen in tqdm_wrapper(zip(sources, generated_outputs),
desc='Calculating scores',
total=len(generated_outputs)):
score, explanation = oai_evaluator.get_score(
_prompt(src=src, gen_output=gen), _function_call_prompt)
score_list.append(score)
explanation_list.append(explanation)
return score_list, explanation_list
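
The removed `_factual_consistency_openai` delegated the structured-output step to `OpenAIBasedEvaluator`. A minimal sketch of the underlying function-calling request that approach relies on; the function name, argument name, assessments, and score mapping come from the removed code above, while the model name and JSON schema shape are illustrative assumptions, not langcheck's exact internals:

import json
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment

# Illustrative schema: constrain the model to one of the three assessments.
functions = [{
    'name': 'save_factual_consistency_assessment',
    'description': "Saves a submitted claim's factual consistency assessment.",
    'parameters': {
        'type': 'object',
        'properties': {
            'factuality': {
                'type': 'string',
                'enum': ['Fully Consistent', 'Partially Consistent',
                         'Not Consistent'],
            },
        },
        'required': ['factuality'],
    },
}]

response = client.chat.completions.create(
    model='gpt-3.5-turbo',  # assumed model; the evaluator's default may differ
    messages=[{'role': 'user', 'content': '...the assessment prompt...'}],
    functions=functions,
    function_call={'name': 'save_factual_consistency_assessment'})

arguments = json.loads(response.choices[0].message.function_call.arguments)
score = {'Fully Consistent': 1.0, 'Partially Consistent': 0.5,
         'Not Consistent': 0.0}[arguments['factuality']]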
return MetricValue(metric_name='factual_consistency',
prompts=prompts,
generated_outputs=generated_outputs,
reference_outputs=None,
sources=sources,
explanations=None,
metric_values=factual_consistency_scores,
language='ja')


def context_relevance(

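For reference, a minimal usage sketch of the reverted metric, based on the parameters visible in this diff (the example strings are illustrative, and the OpenAI path additionally needs credentials configured):

from langcheck.metrics.ja import factual_consistency

generated_outputs = ['東京は日本の首都です。']
sources = ['日本の首都は東京である。']

# Local path: inputs are machine-translated to English, then scored with the
# English factual consistency model.
local_result = factual_consistency(generated_outputs,
                                   sources,
                                   model_type='local')

# OpenAI path: after this commit, the English prompt is reused directly and
# the returned MetricValue's language is set to 'ja'.
openai_result = factual_consistency(generated_outputs,
                                    sources,
                                    model_type='openai')

print(local_result.metric_values)
print(openai_result.metric_values, openai_result.explanations)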