From 756ab9f4932b68dcce7841fbc678f87c9b96a360 Mon Sep 17 00:00:00 2001 From: bookfere Date: Wed, 6 Nov 2024 16:19:41 +0800 Subject: [PATCH] fix: Fixed the bug preventing Gemini translation. #365 --- engines/base.py | 6 ++++-- engines/openai.py | 14 +++++++++----- lib/utils.py | 10 +++------- tests/test_engine.py | 38 +++++++++++++++++++++++--------------- tests/test_utils.py | 25 ++----------------------- 5 files changed, 41 insertions(+), 52 deletions(-) diff --git a/engines/base.py b/engines/base.py index edca9cf..54bbb95 100644 --- a/engines/base.py +++ b/engines/base.py @@ -166,8 +166,10 @@ def _is_auto_lang(self): def translate(self, text): try: response = request( - self.get_endpoint(), self.get_body(text), self.get_headers(), - self.method, self.request_timeout, self.proxy_uri, self.stream) + url=self.get_endpoint(), data=self.get_body(text), + headers=self.get_headers(), method=self.method, + timeout=self.request_timeout, proxy_uri=self.proxy_uri, + raw_object=self.stream) return self.get_result(response) except Exception as e: # Combine the error messages for investigation. diff --git a/engines/openai.py b/engines/openai.py index ebd700a..d6cd482 100644 --- a/engines/openai.py +++ b/engines/openai.py @@ -3,6 +3,8 @@ import json import uuid +from mechanize._response import response_seek_wrapper as Response + from .. import EbookTranslator from ..lib.utils import request from ..lib.exception import UnsupportedModel @@ -205,14 +207,16 @@ def retrieve(self, output_file_id): del headers['Content-Type'] response = request( '%s/%s/content' % (self.file_endpoint, output_file_id), - headers=headers, as_bytes=True) + headers=headers, raw_object=True) + assert isinstance(response, Response) translations = {} - for line in io.BytesIO(response): + for line in io.BytesIO(response.read()): result = json.loads(line) - response = result['response'] - if response.get('status_code') == 200: - content = response['body']['choices'][0]['message']['content'] + response_item = result['response'] + if response_item.get('status_code') == 200: + content = response_item[ + 'body']['choices'][0]['message']['content'] translations[result.get('custom_id')] = content return translations diff --git a/lib/utils.py b/lib/utils.py index 412297c..88ab594 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -147,14 +147,14 @@ def traceback_error(): def request( url, data=None, headers={}, method='GET', timeout=30, proxy_uri=None, - as_bytes=False, stream=False): + raw_object=False) -> Response | str: br = Browser() br.set_handle_robots(False) # Do not verify SSL certificates br.set_ca_data( context=ssl._create_unverified_context(cert_reqs=ssl.CERT_NONE)) # Set up proxy - proxies = {} + proxies: dict = {} if proxy_uri is not None: proxies.update(http=proxy_uri, https=proxy_uri) else: @@ -171,8 +171,4 @@ def request( _request = Request(url, data, headers=headers, timeout=timeout) br.open(_request) response: Response = br.response() - if stream: - return response - if as_bytes: - return response.read() - return response.read().decode('utf-8').strip() + return response if raw_object else response.read().decode('utf-8').strip() diff --git a/tests/test_engine.py b/tests/test_engine.py index 08f25ee..5931ceb 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -192,9 +192,10 @@ def test_translate(self, mock_request): '{"text": "你好世界"}', self.translator.translate('Hello World')) mock_request.assert_called_once_with( - 'https://example.com/api', '{"text": "Hello World"}', - {'Authorization': 'Bearer a', 'Content-Type': 'application/json'}, - 'POST', 10.0, None, False) + url='https://example.com/api', data='{"text": "Hello World"}', + headers={ + 'Authorization': 'Bearer a', 'Content-Type': 'application/json' + }, method='POST', timeout=10.0, proxy_uri=None, raw_object=False) @patch(module_name + '.base.request') def test_translate_with_stream(self, mock_request): @@ -205,9 +206,10 @@ def test_translate_with_stream(self, mock_request): self.assertIs(mock_response, self.translator.translate('Hello World')) mock_request.assert_called_once_with( - 'https://example.com/api', '{"text": "Hello World"}', - {'Authorization': 'Bearer a', 'Content-Type': 'application/json'}, - 'POST', 10.0, None, True) + url='https://example.com/api', data='{"text": "Hello World"}', + headers={ + 'Authorization': 'Bearer a', 'Content-Type': 'application/json' + }, method='POST', timeout=10.0, proxy_uri=None, raw_object=True) @patch(module_name + '.base.request') def test_translate_with_http_error(self, mock_request): @@ -457,7 +459,8 @@ def test_translate_stream(self, mock_request, mock_et): result = self.translator.translate('Hello World!') mock_request.assert_called_with( - url, data, headers, 'POST', 30.0, None, True) + url=url, data=data, headers=headers, method='POST', timeout=30.0, + proxy_uri=None, raw_object=True) self.assertIsInstance(result, GeneratorType) self.assertEqual('你好世界!', ''.join(result)) @@ -626,7 +629,7 @@ def test_retrieve(self, mock_request): line_2 = ( b'{"custom_id":"def","response":{"status_code":200,"body":{' b'"choices": [{"message": {"content": "B"}}]}}}') - mock_request.return_value = line_1 + b'\n' + line_2 + mock_request.return_value.read.return_value = line_1 + b'\n' + line_2 self.mock_translator.get_headers.return_value = { 'Content-Type': 'application/json', 'Authorization': 'Bearer abc', @@ -641,7 +644,8 @@ def test_retrieve(self, mock_request): 'User-Agent': 'Ebook-Translator/v1.0.0'} mock_request.assert_called_once_with( 'https://api.openai.com/v1/files/test-batch-id/content', - headers=headers, as_bytes=True) + headers=headers, raw_object=True) + mock_request().read.assert_called_once() @patch(module_name + '.openai.request') def test_create(self, mock_request): @@ -804,7 +808,8 @@ def test_translate_stream(self, mock_request): self.translator.endpoint = url result = self.translator.translate('Hello World!') mock_request.assert_called_with( - url, data, headers, 'POST', 30.0, None, True) + url=url, data=data, headers=headers, method='POST', timeout=30.0, + proxy_uri=None, raw_object=True) self.assertIsInstance(result, GeneratorType) self.assertEqual('你好世界!', ''.join(result)) @@ -873,7 +878,8 @@ def test_translate(self, mock_request, mock_et): result = self.translator.translate('Hello World!') mock_request.assert_called_with( - url, data, headers, 'POST', 30.0, None, False) + url=url, data=data, headers=headers, method='POST', timeout=30.0, + proxy_uri=None, raw_object=False) self.assertEqual('你好世界!', result) @patch(module_name + '.anthropic.EbookTranslator') @@ -943,7 +949,8 @@ def test_translate_stream(self, mock_request, mock_et): self.translator.model = 'claude-2.1' result = self.translator.translate('Hello World!') mock_request.assert_called_with( - url, data, headers, 'POST', 30.0, None, True) + url=url, data=data, headers=headers, method='POST', timeout=30.0, + proxy_uri=None, raw_object=True) self.assertIsInstance(result, GeneratorType) self.assertEqual('你好世界!', ''.join(result)) @@ -1076,9 +1083,10 @@ def test_translate(self, mock_request): mock_request.return_value = '{"text": "你好世界"}' self.assertEqual('你好世界', translator.translate('Hello "World"')) mock_request.assert_called_with( - 'https://example.api', - b'{"source": "en", "target": "zh", "text": "Hello \\"World\\""}', - {'Content-Type': 'application/json'}, 'POST', 10.0, None, False) + url='https://example.api', data=b'{"source": "en", "target": "zh",' + b' "text": "Hello \\"World\\""}', + headers={'Content-Type': 'application/json'}, method='POST', + timeout=10.0, proxy_uri=None, raw_object=False) # XML response translator.response = 'response.text' mock_request.return_value = '你好世界' diff --git a/tests/test_utils.py b/tests/test_utils.py index 24d6fb1..94f8126 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -101,33 +101,12 @@ def test_request_output_as_string( @patch(module_name + '.ssl') @patch(module_name + '.Request') @patch(module_name + '.Browser') - def test_request_output_as_bytes( + def test_request_output_as_raw_object( self, mock_browser, mock_request, mock_ssl): browser = mock_browser() self.assertIs( - request('https://example.com/api', 'test data', as_bytes=True), - browser.response().read()) - - browser.set_handle_robots.assert_called_once_with(False) - mock_ssl._create_unverified_context.assert_called_once_with( - cert_reqs=mock_ssl.CERT_NONE) - browser.set_ca_data.assert_called_once_with( - context=mock_ssl._create_unverified_context()) - - mock_request.assert_called_once_with( - 'https://example.com/api', 'test data', headers={}, timeout=30, - method='GET') - browser.open.assert_called_once_with(mock_request()) - - @patch(module_name + '.ssl') - @patch(module_name + '.Request') - @patch(module_name + '.Browser') - def test_request_with_stream(self, mock_browser, mock_request, mock_ssl): - browser = mock_browser() - - self.assertIs( - request('https://example.com/api', 'test data', stream=True), + request('https://example.com/api', 'test data', raw_object=True), browser.response()) browser.set_handle_robots.assert_called_once_with(False)