Skip to content

Commit

Permalink
Merge pull request #12 from MaKTaiL/feat-dual-api
Browse files Browse the repository at this point in the history
Dual API Support
  • Loading branch information
MaKTaiL authored Oct 29, 2024
2 parents ef7fbc8 + 2b8c882 commit 162ec65
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 42 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ This will translate the subtitles in the `subtitle.srt` file to French.

You can further customize the translation settings by providing optional parameters:

- `gemini_api_key2`: Second Gemini API key for additional quota.
- `output_file`: Path to save the translated subtitle file.
- `description`: Description of the translation job.
- `model_name`: Model name to use for translation.
Expand All @@ -69,6 +70,7 @@ Example:
import gemini_srt_translator as gst

gst.gemini_api_key = "your_gemini_api_key_here"
gst.gemini_api_key2 = "your_gemini_api_key2_here"
gst.target_language = "French"
gst.input_file = "subtitle.srt"
gst.output_file = "subtitle_translated.srt"
Expand Down
5 changes: 5 additions & 0 deletions gemini_srt_translator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .main import GeminiSRTTranslator

gemini_api_key: str = None
gemini_api_key2: str = None
target_language: str = None
input_file: str = None
output_file: str = None
Expand Down Expand Up @@ -82,6 +83,9 @@ def translate():
# Path to the subtitle file to translate
gst.input_file = "subtitle.srt"
# (Optional) Gemini API key 2 for additional quota
gst.gemini_api_key2 = "your_gemini_api_key2_here"
# (Optional) Path to save the translated subtitle file
gst.output_file = "translated_subtitle.srt"
Expand All @@ -106,6 +110,7 @@ def translate():
"""
params = {
'gemini_api_key': gemini_api_key,
'gemini_api_key2': gemini_api_key2,
'target_language': target_language,
'input_file': input_file,
'output_file': output_file,
Expand Down
160 changes: 120 additions & 40 deletions gemini_srt_translator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,38 @@
from google.generativeai import GenerativeModel

class SubtitleObject(typing.TypedDict):
    """
    TypedDict describing one subtitle entry in a translation batch.

    Batches of these entries are serialized with ``json.dumps`` and sent to
    the model; the reply is parsed with ``json.loads`` and read using the
    same keys.
    """
    # Position of the subtitle in the parsed file, stored as a string
    # (entries are built with ``index=str(i)``).
    index: str
    # Subtitle text: the original in a request, the translation in a reply.
    content: str
    # NOTE(review): never set or read by the code visible here; purpose
    # unclear — confirm before relying on or removing it.
    _blank: str

class GeminiSRTTranslator:
def __init__(self, gemini_api_key: str = None, target_language: str = None, input_file: str = None, output_file: str = None, description: str = None, model_name: str = "gemini-1.5-flash", batch_size: int = 30, free_quota: bool = True):
"""
A translator class that uses Gemini API to translate subtitles.
"""
def __init__(self, gemini_api_key: str = None, gemini_api_key2: str = None, target_language: str = None,
input_file: str = None, output_file: str = None, description: str = None,
model_name: str = "gemini-1.5-flash", batch_size: int = 30, free_quota: bool = True):
"""
Initialize the translator with necessary parameters.
Args:
gemini_api_key (str): Primary Gemini API key
gemini_api_key2 (str): Secondary Gemini API key for additional quota
target_language (str): Target language for translation
input_file (str): Path to input subtitle file
output_file (str): Path to output translated subtitle file
description (str): Additional instructions for translation
model_name (str): Gemini model to use
batch_size (int): Number of subtitles to process in each batch
free_quota (bool): Whether to use free quota (affects rate limiting)
"""
self.gemini_api_key = gemini_api_key
self.gemini_api_key2 = gemini_api_key2
self.current_api_key = gemini_api_key
self.current_api_number = 1
self.backup_api_number = 2
self.target_language = target_language
self.input_file = input_file
self.output_file = output_file
Expand All @@ -30,23 +55,22 @@ def __init__(self, gemini_api_key: str = None, target_language: str = None, inpu
self.free_quota = free_quota

def listmodels(self):
    """
    Print the names of the Gemini models that support content generation.

    Configures the client with the currently active API key, then prints one
    model name per line with the leading "models/" prefix removed.

    Raises:
        Exception: If no API key is currently set.
    """
    if not self.current_api_key:
        raise Exception("Please provide a valid Gemini API key.")

    genai.configure(api_key=self.current_api_key)
    for model in genai.list_models():
        if "generateContent" in model.supported_generation_methods:
            # Drop only the leading namespace; a "models/" substring that
            # appears later in a name is left untouched.
            print(model.name.removeprefix("models/"))

def translate(self):
"""
Translates a subtitle file using the Gemini API.
Main translation method. Reads the input subtitle file, translates it in batches,
and writes the translated subtitles to the output file.
"""
if not self.gemini_api_key:
if not self.current_api_key:
raise Exception("Please provide a valid Gemini API key.")

if not self.target_language:
Expand All @@ -58,8 +82,6 @@ def translate(self):
if not self.output_file:
self.output_file = ".".join(self.input_file.split(".")[:-1]) + "_translated.srt"

genai.configure(api_key=self.gemini_api_key)

instruction = f"""You are an assistant that translates subtitles to {self.target_language}.
You will receive the following JSON type:
Expand All @@ -78,21 +100,10 @@ class SubtitleObject(typing.TypedDict):
if self.description:
instruction += "\nAdditional user instruction: '" + self.description + "'"

model = genai.GenerativeModel(
model_name=self.model_name,
safety_settings={
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
},
system_instruction=instruction,
generation_config=genai.GenerationConfig(response_mime_type="application/json", temperature=0)
)
model = self._get_model(instruction)

with open(self.input_file, "r", encoding="utf-8") as original_file, open(self.output_file, "w", encoding="utf-8") as translated_file:
original_text = original_file.read()

original_subtitle = list(srt.parse(original_text))
translated_subtitle = original_subtitle.copy()

Expand All @@ -102,16 +113,24 @@ class SubtitleObject(typing.TypedDict):
previous_message = None
reverted = 0
delay = False
delay_time = 30

if 'pro' in self.model_name and self.free_quota:
delay = True
print("Pro model and free user quota detected, enabling 30s delay between requests...")
if not self.gemini_api_key2:
print("Pro model and free user quota detected, enabling 30s delay between requests...")
else:
delay_time = 15
print("Pro model and free user quota detected, using secondary API key for additional quota...")

batch.append(SubtitleObject(index=str(i), content=original_subtitle[i].content))
i += 1

print(f"Starting translation of {total} lines...")

if self.gemini_api_key2:
print(f"Starting with API {self.current_api_number}:")

while len(batch) > 0:
if i < total and len(batch) < self.batch_size:
batch.append(SubtitleObject(index=str(i), content=original_subtitle[i].content))
Expand All @@ -122,7 +141,7 @@ class SubtitleObject(typing.TypedDict):
previous_message = self._process_batch(model, batch, previous_message, translated_subtitle)
end_time = time.time()
print(f"Translated {i}/{total}")
if delay and (end_time - start_time < 30):
if delay and (end_time - start_time < delay_time):
time.sleep(30 - (end_time - start_time))
if reverted > 0:
self.batch_size += reverted
Expand All @@ -132,14 +151,15 @@ class SubtitleObject(typing.TypedDict):
batch.append(SubtitleObject(index=str(i), content=original_subtitle[i].content))
i += 1
except Exception as e:
e = str(e)
if "block" in e:
print(e)
batch.clear()
break
elif "quota" in e:
print("Quota exceeded, waiting 1 minute...")
time.sleep(60)
e_str = str(e)

if "quota" in e_str:
if self._switch_api():
print(f"\n🔄 API {self.backup_api_number} quota exceeded! Switching to API {self.current_api_number}...")
model = self._get_model(instruction)
else:
print("\nAll API quotas exceeded, waiting 1 minute...")
time.sleep(60)
else:
if self.batch_size == 1:
raise Exception("Translation failed, aborting...")
Expand All @@ -150,39 +170,99 @@ class SubtitleObject(typing.TypedDict):
i -= 1
batch.pop()
self.batch_size -= decrement
if "finish_reason" in e:
print("Gemini has blocked the translation for unknown reasons")
if "Gemini" in e_str:
print(e_str)
else:
print(e)
print("An unexpected error has occurred")
print("Decreasing batch size to {} and trying again...".format(self.batch_size))

translated_file.write(srt.compose(translated_subtitle))

def _switch_api(self) -> bool:
    """
    Make the other API key the active one, if it has been provided.

    On success the active key and its number are replaced by the other
    slot's, and the backup number records the slot that was just left.

    Returns:
        bool: True when the switch happened, False when the other key is
        missing (nothing is changed in that case).
    """
    # current slot -> (slot to activate, slot that becomes the backup,
    #                  attribute holding the key to activate)
    routes = {
        1: (2, 1, "gemini_api_key2"),
        2: (1, 2, "gemini_api_key"),
    }
    route = routes.get(self.current_api_number)
    if route is None:
        return False

    next_number, left_behind, key_attribute = route
    next_key = getattr(self, key_attribute)
    if not next_key:
        return False

    self.current_api_key = next_key
    self.current_api_number = next_number
    self.backup_api_number = left_behind
    return True

def _get_model(self, instruction: str) -> GenerativeModel:
    """
    Build a Gemini model bound to the currently active API key.

    The client is (re)configured with ``self.current_api_key`` first, so this
    is also what makes a key switch take effect.

    Args:
        instruction (str): System instruction given to the model

    Returns:
        GenerativeModel: Model for ``self.model_name`` that replies in JSON
        and has blocking disabled for the four listed harm categories
    """
    genai.configure(api_key=self.current_api_key)

    unblocked_categories = (
        HarmCategory.HARM_CATEGORY_HARASSMENT,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    )
    safety = {category: HarmBlockThreshold.BLOCK_NONE for category in unblocked_categories}
    json_reply_config = genai.GenerationConfig(response_mime_type="application/json")

    return genai.GenerativeModel(
        model_name=self.model_name,
        safety_settings=safety,
        system_instruction=instruction,
        generation_config=json_reply_config,
    )

def _process_batch(self, model: GenerativeModel, batch: list[SubtitleObject], previous_message: ContentDict, translated_subtitle: list[Subtitle]) -> ContentDict:
    """
    Translate one batch of subtitles and store the results.

    The batch is sent to the model as a JSON user message, preceded by the
    previous model reply when there is one. The reply is validated in full
    before anything is written, so a rejected reply leaves
    ``translated_subtitle`` and ``batch`` untouched.

    Args:
        model (GenerativeModel): The Gemini model instance
        batch (list[SubtitleObject]): Batch of subtitles to translate;
            cleared on success
        previous_message (ContentDict): Previous model reply, used as context
        translated_subtitle (list[Subtitle]): Subtitles whose ``content`` is
            overwritten with the translations

    Returns:
        ContentDict: The model's reply, to be passed as context for the next batch

    Raises:
        Exception: If the reply has a different number of lines than the
            batch, or contains an index that is not in the batch or that
            appears more than once.
    """
    request = {"role": "user", "parts": json.dumps(batch)}
    messages = [previous_message, request] if previous_message else [request]
    response = model.generate_content(messages)
    translated_lines: list[SubtitleObject] = json.loads(response.text)

    if len(translated_lines) != len(batch):
        raise Exception("Gemini has returned the wrong number of lines.")

    # Built once instead of once per returned line. Each index is removed as
    # it is seen, so a repeated index is rejected as well — otherwise one
    # line of the batch would silently stay untranslated.
    pending = {entry["index"] for entry in batch}
    for line in translated_lines:
        index = line["index"]
        if not isinstance(index, str) or index not in pending:
            raise Exception("Gemini has returned different indices.")
        pending.remove(index)

    for line in translated_lines:
        text = line["content"]
        if self._dominant_strong_direction(text) == "rtl":
            # Wrap in U+202B (right-to-left embedding) ... U+202C (pop
            # directional formatting) so the line renders right-to-left.
            text = f"\u202B{text}\u202C"
        translated_subtitle[int(line["index"])].content = text

    batch.clear()
    return response.candidates[0].content

def dominant_strong_direction(self, s: str) -> str:
def _dominant_strong_direction(self, s: str) -> str:
"""
Determines the dominant strong direction of a string.
Determine the dominant text direction (RTL or LTR) of a string.
Args:
s (str): Input string to analyze
Returns:
str: 'rtl' if right-to-left is dominant, 'ltr' otherwise
"""
count = Counter([ud.bidirectional(c) for c in list(s)])
rtl_count = count['R'] + count['AL'] + count['RLE'] + count["RLI"]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="gemini-srt-translator",
version="1.2.6",
version="1.3.0",
packages=find_packages(),
install_requires=[
"google-generativeai==0.8.3",
Expand Down
17 changes: 17 additions & 0 deletions tests/test_batch_translate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import os
import gemini_srt_translator as gst

# Example script: translate every .srt file in one folder into another folder.
gst.gemini_api_key = "your_gemini_api_key_here"
gst.target_language = "French"

input_dir = r"input folder"
output_dir = r"output folder"

# The translator opens the output path for writing, which fails when the
# destination folder does not exist yet.
os.makedirs(output_dir, exist_ok=True)

# os.listdir returns entries in arbitrary order; sort for a predictable run.
for filename in sorted(os.listdir(input_dir)):
    if filename.endswith(".srt"):
        # The package-level translate() reads its settings from module
        # attributes, so they are reassigned for every file.
        gst.input_file = os.path.join(input_dir, filename)
        gst.output_file = os.path.join(output_dir, filename)
        gst.description = f"Translation of {filename}"
        gst.translate()
2 changes: 1 addition & 1 deletion tests/test_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
gst.target_language = "French"
gst.input_file = "subtitle.srt"

gst.translate()
gst.translate()
13 changes: 13 additions & 0 deletions tests/test_translate_alloptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import gemini_srt_translator as gst

# Example script: set each translation option explicitly, then translate.
# The two API keys are left blank here and must be filled in before running.
options = {
    "gemini_api_key": "",
    "gemini_api_key2": "",
    "target_language": "French",
    "input_file": "subtitle.srt",
    "output_file": "translated_subtitle.srt",
    "description": "This is a medical TV Show",
    "model_name": "gemini-1.5-flash",
    "batch_size": 30,
    "free_quota": True,
}

# The package reads its settings from module-level attributes.
for name, value in options.items():
    setattr(gst, name, value)

gst.translate()

0 comments on commit 162ec65

Please sign in to comment.