Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612877063
  • Loading branch information
colaboratory-team committed Mar 5, 2024
1 parent 7f20e6e commit e1e01e3
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions google/colab/_import_hooks/_generativeai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.
"""Import hook for google.generativeai in Colab.
This will enable the IP geolocation restrictions for the PaLM and Gemini
APIs to be based on the location of the user instead of the runtime VM.
This will enable the IP geolocation restrictions for the Gemini API to be based
on the location of the user instead of the runtime VM.
"""

import imp # pylint: disable=deprecated-module
Expand All @@ -24,7 +24,7 @@


class _GenerativeAIImportHook:
"""Enables the PaLM and Gemini API clients libraries to be customized upon import."""
"""Enables the Gemini API client library to be customized upon import."""

def find_module(self, fullname, path=None):
if fullname != 'google.generativeai':
Expand Down Expand Up @@ -53,6 +53,7 @@ def load_module(self, fullname):
try:
import functools # pylint:disable=g-import-not-at-top
import json # pylint:disable=g-import-not-at-top
import google.api_core.exceptions # pylint:disable=g-import-not-at-top
from google.colab import output # pylint:disable=g-import-not-at-top
from google.colab.html import _background_server # pylint:disable=g-import-not-at-top
import portpicker # pylint:disable=g-import-not-at-top
Expand Down Expand Up @@ -121,12 +122,28 @@ def start():
return p

start()

api_endpoint = f'http://localhost:{port}'
orig_configure = generativeai_module.configure
generativeai_module.configure = functools.partial(
orig_configure,
transport='rest',
client_options={'api_endpoint': f'http://localhost:{port}'},
client_options={'api_endpoint': api_endpoint},
)

# Change error messages to use the generative language API endpoint
# instead of the proxy endpoint.
orig_from_http_response = google.api_core.exceptions.from_http_response

@functools.wraps(orig_from_http_response)
def new_from_http_response(*args, **kwargs):
error = orig_from_http_response(*args, **kwargs)
error.message = error.message.replace(
api_endpoint, 'https://generativelanguage.googleapis.com'
)
return error

google.api_core.exceptions.from_http_response = new_from_http_response
except: # pylint: disable=bare-except
logging.exception('Error customizing google.generativeai.')
os.environ['COLAB_GENERATIVEAI_IMPORT_HOOK_EXCEPTION'] = '1'
Expand Down

0 comments on commit e1e01e3

Please sign in to comment.