diff --git a/PySubtitle/Providers/Anthropic/AnthropicClient.py b/PySubtitle/Providers/Anthropic/AnthropicClient.py index 7664ab2..3ea2be3 100644 --- a/PySubtitle/Providers/Anthropic/AnthropicClient.py +++ b/PySubtitle/Providers/Anthropic/AnthropicClient.py @@ -9,7 +9,6 @@ from PySubtitle.TranslationClient import TranslationClient from PySubtitle.Translation import Translation from PySubtitle.TranslationClient import TranslationClient - from PySubtitle.TranslationParser import TranslationParser from PySubtitle.TranslationPrompt import TranslationPrompt linesep = '\n' diff --git a/PySubtitle/Providers/Provider_Claude.py b/PySubtitle/Providers/Provider_Claude.py index 98cfe4b..66075e9 100644 --- a/PySubtitle/Providers/Provider_Claude.py +++ b/PySubtitle/Providers/Provider_Claude.py @@ -9,7 +9,6 @@ from PySubtitle.Helpers import GetEnvFloat, GetEnvInteger from PySubtitle.Providers.Anthropic.AnthropicClient import AnthropicClient from PySubtitle.TranslationClient import TranslationClient - from PySubtitle.TranslationParser import TranslationParser from PySubtitle.TranslationProvider import TranslationProvider class Provider_Claude(TranslationProvider): @@ -31,7 +30,8 @@ def __init__(self, settings : dict): "model": settings.get('model') or os.getenv('CLAUDE_MODEL'), "max_tokens": settings.get('max_tokens') or GetEnvInteger('CLAUDE_MAX_TOKENS', 4096), 'temperature': settings.get('temperature', GetEnvFloat('CLAUDE_TEMPERATURE', 0.0)), - 'rate_limit': settings.get('rate_limit', GetEnvFloat('CLAUDE_RATE_LIMIT', 10.0)) + 'rate_limit': settings.get('rate_limit', GetEnvFloat('CLAUDE_RATE_LIMIT', 10.0)), + 'model_names': settings.get('model_names', 'claude-3-haiku-20240307, claude-3-sonnet-20240229, claude-3-5-sonnet-20240620, claude-3-opus-20240229') }) self.refresh_when_changed = ['api_key', 'model'] @@ -57,7 +57,7 @@ def GetAvailableModels(self) -> list[str]: # TODO: surely the SDK has a method for this? # client = anthropic.Anthropic(api_key=self.api_key) # models = client.list_models() - models = [ 'claude-3-haiku-20240307', 'claude-3-sonnet-20240229', 'claude-3-opus-20240229' ] + models = [ name.strip() for name in self.settings.get('model_names').split(',') ] return models @@ -76,6 +76,7 @@ def GetOptions(self) -> dict: 'temperature': (float, "The temperature to use for translations (default 0.0)"), 'rate_limit': (float, "The rate limit to use for translations (default 60.0)"), 'max_tokens': (int, "The maximum number of tokens to use for translations"), + 'model_names': (str, "Comma separated list of supported Claude models (until Anthropic provide a method to retrieve them!)"), }) return options diff --git a/PySubtitle/TranslationProvider.py b/PySubtitle/TranslationProvider.py index cc21f94..c0918d6 100644 --- a/PySubtitle/TranslationProvider.py +++ b/PySubtitle/TranslationProvider.py @@ -30,7 +30,8 @@ def selected_model(self) -> str: """ The currently selected model for the provider """ - return self.settings.get('model') + name : str = self.settings.get('model') + return name.strip() if name else None @property def allow_multithreaded_translation(self) -> bool: