From a17dd1f0c6100b89d203043b6902f23fb9130213 Mon Sep 17 00:00:00 2001 From: Google Colaboratory Team Date: Tue, 5 Mar 2024 09:54:39 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 612877063 --- google/colab/_import_hooks/_generativeai.py | 25 +++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/google/colab/_import_hooks/_generativeai.py b/google/colab/_import_hooks/_generativeai.py index 209651d4..a92cf4f3 100644 --- a/google/colab/_import_hooks/_generativeai.py +++ b/google/colab/_import_hooks/_generativeai.py @@ -13,8 +13,8 @@ # limitations under the License. """Import hook for google.generativeai in Colab. -This will enable the IP geolocation restrictions for the PaLM and Gemini -APIs to be based on the location of the user instead of the runtime VM. +This will enable the IP geolocation restrictions for the Gemini API to be based +on the location of the user instead of the runtime VM. """ import imp # pylint: disable=deprecated-module @@ -24,7 +24,7 @@ class _GenerativeAIImportHook: - """Enables the PaLM and Gemini API clients libraries to be customized upon import.""" + """Enables the Gemini API client library to be customized upon import.""" def find_module(self, fullname, path=None): if fullname != 'google.generativeai': @@ -53,6 +53,7 @@ def load_module(self, fullname): try: import functools # pylint:disable=g-import-not-at-top import json # pylint:disable=g-import-not-at-top + import google.api_core.exceptions # pylint:disable=g-import-not-at-top from google.colab import output # pylint:disable=g-import-not-at-top from google.colab.html import _background_server # pylint:disable=g-import-not-at-top import portpicker # pylint:disable=g-import-not-at-top @@ -121,12 +122,28 @@ def start(): return p start() + + api_endpoint = f'http://localhost:{port}' orig_configure = generativeai_module.configure generativeai_module.configure = functools.partial( orig_configure, transport='rest', - client_options={'api_endpoint': f'http://localhost:{port}'}, + client_options={'api_endpoint': api_endpoint}, ) + + # Change error messages to use the generative language API endpoint + # instead of the proxy endpoint. + orig_from_http_response = google.api_core.exceptions.from_http_response + + @functools.wraps(orig_from_http_response) + def new_from_http_response(*args, **kwargs): + error = orig_from_http_response(*args, **kwargs) + error.message = error.message.replace( + api_endpoint, 'https://generativelanguage.googleapis.com' + ) + return error + + google.api_core.exceptions.from_http_response = new_from_http_response except: # pylint: disable=bare-except logging.exception('Error customizing google.generativeai.') os.environ['COLAB_GENERATIVEAI_IMPORT_HOOK_EXCEPTION'] = '1'