Commit fa085d6

Merge branch 'main' into nltk-download
yosukehigashi authored Mar 1, 2024
2 parents e903129 + 21fc283 commit fa085d6
Showing 5 changed files with 79 additions and 31 deletions.
17 changes: 13 additions & 4 deletions src/langcheck/metrics/de/reference_free_text_quality.py
@@ -113,12 +113,21 @@ def sentiment(
         return_tensors='pt',
         padding=True)

+    batch_size = 8
+    scores = []
     with torch.no_grad():
-        # Probabilities of [negative, neutral, positive]
-        probs = torch.nn.functional.softmax(
-            _sentiment_model(**input_tokens).logits, dim=1)
+        for i in tqdm_wrapper(
+                range(0, len(generated_outputs), batch_size),
+                total=(len(generated_outputs) + batch_size - 1) //
+                batch_size):
+            batch_input_tokens = {
+                k: v[i:i + batch_size] for k, v in input_tokens.items()
+            }
+            # Probabilities of [negative, neutral, positive]
+            probs = torch.nn.functional.softmax(
+                _sentiment_model(**batch_input_tokens).logits, dim=1)
+            scores.extend((probs[:, 1] / 2 + probs[:, 2]).tolist())

-    scores = (probs[:, 1] / 2 + probs[:, 2]).tolist()
     explanations = None

     return MetricValue(metric_name='sentiment',
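
The three sentiment hunks in this commit share the same pattern: tokenize everything once, then run the model over slices of the tokenized tensors instead of one large forward pass. Below is a minimal, self-contained sketch of that pattern, not langcheck's actual code: the Hugging Face model name is illustrative, plain tqdm stands in for the repo's tqdm_wrapper, and the helper name batched_sentiment_scores is hypothetical.

from typing import List

import torch
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative 3-class sentiment checkpoint (labels: negative, neutral, positive).
_MODEL_NAME = 'cardiffnlp/twitter-xlm-roberta-base-sentiment'


def batched_sentiment_scores(generated_outputs: List[str],
                             batch_size: int = 8) -> List[float]:
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(_MODEL_NAME)
    # Tokenize everything once; padding makes the tensors rectangular so they
    # can be sliced along the batch dimension below.
    input_tokens = tokenizer(generated_outputs,
                             return_tensors='pt',
                             padding=True)

    scores: List[float] = []
    with torch.no_grad():
        for i in tqdm(range(0, len(generated_outputs), batch_size)):
            # Slice every tensor (input_ids, attention_mask, ...) the same way.
            batch = {k: v[i:i + batch_size] for k, v in input_tokens.items()}
            # Probabilities of [negative, neutral, positive].
            probs = torch.nn.functional.softmax(model(**batch).logits, dim=1)
            # Collapse to one score in [0, 1]: neutral counts half, positive counts full.
            scores.extend((probs[:, 1] / 2 + probs[:, 2]).tolist())
    return scores

Compared with the pre-commit version, peak memory now scales with batch_size rather than with the number of generated outputs, and the progress bar advances once per batch.
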
20 changes: 15 additions & 5 deletions src/langcheck/metrics/de/source_based_text_quality.py
@@ -97,11 +97,21 @@ def factual_consistency(
     # Currently, the type checks are not working for the pipeline, since
     # too diverse types can be returned.
     translation = Translate(_factual_consistency_translation_model_path)

-    en_source = [translation(source) for source in sources]
-    en_generated_outputs = [
-        translation(gen_out) for gen_out in generated_outputs
-    ]
+    batch_size = 8
+    en_source = []
+    for i in tqdm_wrapper(range(0, len(sources), batch_size),
+                          desc='Translating sources',
+                          total=(len(sources) + batch_size - 1) // batch_size):
+        batch_sources = sources[i:i + batch_size]
+        en_source.extend([translation(src) for src in batch_sources])
+    en_generated_outputs = []
+    for i in tqdm_wrapper(range(0, len(generated_outputs), batch_size),
+                          desc='Translating generated outputs',
+                          total=(len(generated_outputs) + batch_size - 1) //
+                          batch_size):
+        batch_generated_outputs = generated_outputs[i:i + batch_size]
+        en_generated_outputs.extend(
+            [translation(gen_out) for gen_out in batch_generated_outputs])

     # Compute the factual consistency scores in English.
     metric_value = en_factual_consistency(
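
Both factual_consistency hunks apply the same idea to translation: the inputs are processed in chunks of batch_size so a progress bar can be shown, while the per-item translation call itself is unchanged. A rough sketch under assumed names follows; the model id is illustrative, plain tqdm replaces tqdm_wrapper, and langcheck's Translate wrapper and _factual_consistency_translation_model_path are not reproduced here.

from typing import List

from tqdm import tqdm
from transformers import pipeline


def translate_in_batches(texts: List[str],
                         batch_size: int = 8,
                         model_name: str = 'Helsinki-NLP/opus-mt-de-en') -> List[str]:
    # Assumed model id; any Hugging Face translation checkpoint works the same way.
    translator = pipeline('translation', model=model_name)
    results: List[str] = []
    for i in tqdm(range(0, len(texts), batch_size), desc='Translating'):
        chunk = texts[i:i + batch_size]
        # The pipeline accepts a list and returns one dict per input,
        # with the translated string under 'translation_text'.
        results.extend(d['translation_text'] for d in translator(chunk))
    return results

The ja version of this change (further below) keeps the cast(str, d['translation_text']) and # type: ignore comments because, as the pre-existing comment in the hunk notes, the pipeline's return type is too loose for the type checker.
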
18 changes: 13 additions & 5 deletions src/langcheck/metrics/en/reference_free_text_quality.py
@@ -131,12 +131,20 @@ def _sentiment_local(generated_outputs: List[str]) -> List[float]:
         return_tensors='pt',
         padding=True)

+    batch_size = 8
+    scores = []
     with torch.no_grad():
-        # Probabilities of [negative, neutral, positive]
-        probs = torch.nn.functional.softmax(
-            _sentiment_model(**input_tokens).logits, dim=1)
-
-    return (probs[:, 1] / 2 + probs[:, 2]).tolist()
+        for i in tqdm_wrapper(range(0, len(generated_outputs), batch_size),
+                              total=(len(generated_outputs) + batch_size - 1) //
+                              batch_size):
+            batch_input_tokens = {
+                k: v[i:i + batch_size] for k, v in input_tokens.items()
+            }
+            # Probabilities of [negative, neutral, positive]
+            probs = torch.nn.functional.softmax(
+                _sentiment_model(**batch_input_tokens).logits, dim=1)
+            scores.extend((probs[:, 1] / 2 + probs[:, 2]).tolist())
+    return scores


 def _sentiment_openai(
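
A small note on the total=(n + batch_size - 1) // batch_size expression that recurs in every loop above and below: it is integer ceiling division, i.e. exactly the number of chunks that range(0, n, batch_size) yields, so the progress bar ends at 100%. A quick self-check (plain Python, not from the repo):

def num_batches(n: int, batch_size: int) -> int:
    # Ceiling division without floats: equivalent to math.ceil(n / batch_size).
    return (n + batch_size - 1) // batch_size


assert num_batches(0, 8) == 0
assert num_batches(10, 8) == 2   # one full batch of 8, then a partial batch of 2
assert num_batches(16, 8) == 2
assert num_batches(17, 8) == 3
assert all(num_batches(n, 8) == len(range(0, n, 8)) for n in range(0, 100))
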
17 changes: 12 additions & 5 deletions src/langcheck/metrics/ja/reference_free_text_quality.py
@@ -136,12 +136,19 @@ def _sentiment_local(generated_outputs: List[str]) -> List[float]:
         return_tensors='pt',
         padding=True)

+    batch_size = 8
+    scores = []
     with torch.no_grad():
-        # Probabilities of [negative, neutral, positive]
-        probs = torch.nn.functional.softmax(
-            _sentiment_model(**input_tokens).logits, dim=1)
-
-    scores = (probs[:, 1] / 2 + probs[:, 2]).tolist()
+        for i in tqdm_wrapper(range(0, len(generated_outputs), batch_size),
+                              total=(len(generated_outputs) + batch_size - 1) //
+                              batch_size):
+            batch_input_tokens = {
+                k: v[i:i + batch_size] for k, v in input_tokens.items()
+            }
+            # Probabilities of [negative, neutral, positive]
+            probs = torch.nn.functional.softmax(
+                _sentiment_model(**batch_input_tokens).logits, dim=1)
+            scores.extend((probs[:, 1] / 2 + probs[:, 2]).tolist())
     return scores


38 changes: 26 additions & 12 deletions src/langcheck/metrics/ja/source_based_text_quality.py
@@ -97,18 +97,32 @@ def factual_consistency(
     # Translate the sources and generated outputs to English.
     # Currently, the type checks are not working for the pipeline, since
     # too diverse types can be returned.
-    en_source = [
-        cast(str,
-             d['translation_text'])  # type: ignore[reportGeneralTypeIssues]
-        for d in _factual_consistency_translation_pipeline(
-            sources)  # type: ignore[reportGeneralTypeIssues]
-    ]
-    en_generated_outputs = [
-        cast(str,
-             d['translation_text'])  # type: ignore[reportGeneralTypeIssues]
-        for d in _factual_consistency_translation_pipeline(
-            generated_outputs)  # type: ignore[reportGeneralTypeIssues]
-    ]
+    batch_size = 8
+    en_source = []
+    for i in tqdm_wrapper(range(0, len(sources), batch_size),
+                          desc='Translating sources',
+                          total=(len(sources) + batch_size - 1) // batch_size):
+        batch_sources = sources[i:i + batch_size]
+        en_source.extend([
+            cast(str,
+                 d['translation_text'])  # type: ignore[reportGeneralTypeIssues]
+            for d in _factual_consistency_translation_pipeline(
+                batch_sources)  # type: ignore[reportGeneralTypeIssues]
+        ])
+    en_generated_outputs = []
+    for i in tqdm_wrapper(range(0, len(generated_outputs), batch_size),
+                          desc='Translating generated outputs',
+                          total=(len(generated_outputs) + batch_size - 1) //
+                          batch_size):
+        batch_generated_outputs = generated_outputs[i:i + batch_size]
+        en_generated_outputs.extend([
+            cast(str,
+                 d['translation_text'])  # type: ignore[reportGeneralTypeIssues]
+            for d in _factual_consistency_translation_pipeline(
+                batch_generated_outputs
+            )  # type: ignore[reportGeneralTypeIssues]
+        ])

     # Compute the factual consistency scores in English.
     factual_consistency_scores = en_factual_consistency(
         generated_outputs=en_generated_outputs, sources=en_source).metric_values
