
Commit

Merge pull request #174 from citadel-ai/gemini-typing
Update import paths of google genai modules to fix pyright errors
liwii authored Feb 7, 2025
2 parents e5c7295 + 641e961 commit e627bab
Showing 3 changed files with 30 additions and 22 deletions.
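
The core change replaces the aliased package import (import google.generativeai as genai) with direct imports from the submodules that define each symbol. The pyright errors mentioned in the commit message presumably come from attribute access through the package alias, such as genai.GenerativeModel, which the type checker cannot always resolve when the symbol is only re-exported through the package's __init__. Importing GenerativeModel, configure, and embed_content from their defining modules gives pyright concrete definitions to follow. A minimal before/after sketch, illustrative only and not part of the diff:

    # Before: attribute access through the package alias. Per the commit
    # message, pyright reported errors on this style of access.
    import google.generativeai as genai
    model = genai.GenerativeModel()

    # After: import the class from the submodule that defines it, so the
    # type checker can resolve the symbol directly.
    from google.generativeai.generative_models import GenerativeModel
    model = GenerativeModel()
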
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -19,7 +19,7 @@ jobs:
   # This workflow contains a single job called "pytest"
   pytest:
     # The type of runner that the job will run on
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-latest-4-cores
 
     # Steps represent a sequence of tasks that will be executed as part of the job
     steps:
12 changes: 7 additions & 5 deletions src/langcheck/metrics/eval_clients/_gemini.py
@@ -5,8 +5,10 @@
 from typing import Any
 
 import google.ai.generativelanguage as glm
-import google.generativeai as genai
 import torch
+from google.generativeai.client import configure
+from google.generativeai.embedding import embed_content
+from google.generativeai.generative_models import GenerativeModel
 
 from langcheck.utils.progress_bar import tqdm_wrapper

@@ -20,7 +22,7 @@ class GeminiEvalClient(EvalClient):
 
     def __init__(
         self,
-        model: genai.GenerativeModel | None = None,
+        model: GenerativeModel | None = None,
         model_args: dict[str, Any] | None = None,
         generate_content_args: dict[str, Any] | None = None,
         embed_model_name: str | None = None,
@@ -49,9 +51,9 @@ def __init__(
         if model:
             self._model = model
         else:
-            genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
+            configure(api_key=os.getenv("GOOGLE_API_KEY"))
             model_args = model_args or {}
-            self._model = genai.GenerativeModel(**model_args)
+            self._model = GenerativeModel(**model_args)
 
         self._generate_content_args = generate_content_args or {}
         self._embed_model_name = embed_model_name
@@ -234,7 +236,7 @@ def __init__(self, embed_model_name: str | None):
     def _embed(self, inputs: list[str]) -> torch.Tensor:
         """Embed the inputs using the Gemini API."""
         # Embed the inputs
-        embed_response = genai.embed_content(
+        embed_response = embed_content(
             model=self.embed_model_name, content=inputs
         )

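As a usage illustration only (not part of this commit), construction of the client is unchanged: callers can pass a ready-made GenerativeModel or let __init__ build one from model_args. The import path for GeminiEvalClient and the model name below are assumptions made for this sketch.

    import os

    from google.generativeai.generative_models import GenerativeModel
    from langcheck.metrics.eval_clients import GeminiEvalClient  # import path assumed

    os.environ["GOOGLE_API_KEY"] = "your-api-key"  # placeholder key

    # Pass a pre-built model...
    client = GeminiEvalClient(model=GenerativeModel(model_name="gemini-1.5-flash"))

    # ...or let __init__ call configure() and construct one from model_args.
    client = GeminiEvalClient(model_args={"model_name": "gemini-1.5-flash"})
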
38 changes: 22 additions & 16 deletions tests/metrics/eval_clients/test_gemini.py
@@ -17,9 +17,10 @@ def test_get_text_response_gemini():
     mock_response.candidates = [Mock(finish_reason=1)]
     # Calling the google.generativeai.GenerativeModel.generate_content method
     # requires a Google API key, so we mock the return value instead
-    with patch("google.generativeai.GenerativeModel.generate_content",
-               return_value=mock_response):
-
+    with patch(
+        "google.generativeai.GenerativeModel.generate_content",
+        return_value=mock_response,
+    ):
         # Set the necessary env vars for the GeminiEvalClient
         os.environ["GOOGLE_API_KEY"] = "dummy_key"
         client = GeminiEvalClient()
@@ -41,28 +42,30 @@ def test_get_float_score_gemini(language):
     mock_response.text = short_assessment_result
 
     class FunctionCallMock(Mock):
-
         @classmethod
         def to_dict(cls, instance):
             return {"args": {"assessment": short_assessment_result}}
 
     mock_response.candidates = [
-        Mock(finish_reason=1,
-             content=Mock(parts=[Mock(function_call=FunctionCallMock())]))
+        Mock(
+            finish_reason=1,
+            content=Mock(parts=[Mock(function_call=FunctionCallMock())]),
+        )
     ]
 
     # Calling the google.generativeai.GenerativeModel.generate_content method
     # requires a Google API key, so we mock the return value instead
-    with patch("google.generativeai.GenerativeModel.generate_content",
-               return_value=mock_response):
-
+    with patch(
+        "google.generativeai.GenerativeModel.generate_content",
+        return_value=mock_response,
+    ):
         # Set the necessary env vars for the GeminiEvalClient
         os.environ["GOOGLE_API_KEY"] = "dummy_key"
         client = GeminiEvalClient()
 
-        scores = client.get_float_score("dummy_metric", language,
-                                        unstructured_assessment_result,
-                                        score_map)
+        scores = client.get_float_score(
+            "dummy_metric", language, unstructured_assessment_result, score_map
+        )
         assert len(scores) == len(unstructured_assessment_result)
         for score in scores:
             assert score == 1.0
@@ -73,14 +76,17 @@ def test_similarity_scorer_gemini():
 
     # Calling the google.generativeai.embed_content method requires a Google
     # API key, so we mock the return value instead
-    with patch("google.generativeai.embed_content",
-               Mock(return_value=mock_embedding_response)):
+    with patch(
+        "langcheck.metrics.eval_clients._gemini.embed_content",
+        Mock(return_value=mock_embedding_response),
+    ):
         # Set the necessary env vars for the GeminiEvalClient
         os.environ["GOOGLE_API_KEY"] = "dummy_key"
         gemini_client = GeminiEvalClient()
         scorer = gemini_client.similarity_scorer()
         # Since the mock embeddings are the same for the generated and reference
         # outputs, the similarity score should be 1.
-        score = scorer.score(["The cat sat on the mat."],
-                             ["The cat sat on the mat."])
+        score = scorer.score(
+            ["The cat sat on the mat."], ["The cat sat on the mat."]
+        )
         assert 0.99 <= score[0] <= 1
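
The only behavioral change in the tests is the patch target in test_similarity_scorer_gemini. Because _gemini.py now binds embed_content into its own namespace (from google.generativeai.embedding import embed_content), patching google.generativeai.embed_content would no longer intercept the call made inside _embed(); the usual unittest.mock rule is to patch the name in the namespace where it is looked up. The generate_content patches can keep targeting the class itself (google.generativeai.GenerativeModel.generate_content), since patching an attribute on the class affects every reference to that class, including the one imported into _gemini.py. A short sketch of the pattern, with an illustrative mocked return value:

    from unittest.mock import Mock, patch

    # Patch the reference held by the consuming module, not the original
    # definition site, so the code inside _embed() picks up the mock.
    with patch(
        "langcheck.metrics.eval_clients._gemini.embed_content",
        Mock(return_value={"embedding": [[0.0, 0.0, 0.0]]}),  # shape is illustrative
    ):
        ...  # exercise code that reaches _embed(), e.g. the similarity scorer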
