Skip to content

Commit

Permalink
Make sure gemini models are fetched on validation
Browse files Browse the repository at this point in the history
  • Loading branch information
machinewrapped committed May 24, 2024
1 parent 609c175 commit 0b32797
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions PySubtitle/Providers/Provider_Gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class GeminiProvider(TranslationProvider):
name = "Gemini"

information = """
<p>Select the <a href="https://ai.google.dev/models/gemini">AI model</a> to use as a translator.</p>
<p>Select the <a href="https://ai.google.dev/models/gemini">AI model</a> to use as a translator.</p>
<p>Please note that the Gemini API can currently only be accessed from IP addresses in <a href="https://ai.google.dev/available_regions">certain regions</a>.</p>
<p>You must ensure that the Generative Language API is enabled for your project and/or API key.</p>
"""
Expand All @@ -40,7 +40,7 @@ def __init__(self, settings : dict):
@property
def api_key(self):
return self.settings.get('api_key')

def GetTranslationClient(self, settings : dict) -> TranslationClient:
genai.configure(api_key=self.api_key)
client_settings = self.settings.copy()
Expand All @@ -56,7 +56,7 @@ def GetOptions(self) -> dict:
options = {
'api_key': (str, "A Google Gemini API key is required to use this provider (https://makersuite.google.com/app/apikey)")
}

if self.api_key:
try:
models = self.available_models
Expand All @@ -72,7 +72,7 @@ def GetOptions(self) -> dict:

except FailedPrecondition as e:
options['model'] = (["Unable to access the Gemini API"], str(e))

return options

def GetAvailableModels(self) -> list[str]:
Expand All @@ -93,8 +93,8 @@ def ValidateSettings(self) -> bool:
if not self.api_key:
self.validation_message = "API Key is required"
return False
if not self.gemini_models:

if not self.GetAvailableModels():
self.validation_message = "Unable to retrieve models. Gemini API may be unavailable in your region."
return False

Expand All @@ -120,7 +120,7 @@ def _get_true_name(self, display_name : str) -> str:
for m in self.gemini_models:
if m.display_name == display_name:
return m.name

raise ValueError(f"Model {display_name} not found")

except ImportError:
Expand Down

0 comments on commit 0b32797

Please sign in to comment.