diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index dade56592..e8d95f0cc 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -1,6 +1,7 @@ import multiprocessing import os.path import tempfile +import threading import typing from enum import Enum @@ -932,9 +933,6 @@ def _translate_text( return result.strip() -_session = None - - def _MinT_translate_one_text( text: str, source_language: str, target_language: str ) -> str: @@ -953,18 +951,27 @@ def _MinT_translate_one_text( return tanslation.get("translation", text) -def get_google_auth_session(): +_session = None +_session_lock = threading.Lock() + + +def get_google_auth_session(scopes: typing.Optional[list[str]] = None): global _session if _session is None: - import google.auth - from google.auth.transport.requests import AuthorizedSession + with _session_lock: + if _session is None: + import google.auth + from google.auth.transport.requests import AuthorizedSession - creds, project = google.auth.default( - scopes=["https://www.googleapis.com/auth/cloud-platform"] - ) - # takes care of refreshing the token and adding it to request headers - _session = AuthorizedSession(credentials=creds), project + if not scopes: + scopes = ["https://www.googleapis.com/auth/cloud-platform"] + + creds, project = google.auth.default( + scopes=scopes, + ) + # takes care of refreshing the token and adding it to request headers + _session = AuthorizedSession(credentials=creds), project return _session diff --git a/daras_ai_v2/gdrive_downloader.py b/daras_ai_v2/gdrive_downloader.py index ce1e23620..5e77c7a2f 100644 --- a/daras_ai_v2/gdrive_downloader.py +++ b/daras_ai_v2/gdrive_downloader.py @@ -1,7 +1,6 @@ import io import typing from furl import furl -import requests from daras_ai_v2.exceptions import UserError from daras_ai_v2.functional import flatmap_parallel @@ -74,6 +73,7 @@ def gdrive_download( ) -> tuple[bytes, str]: from googleapiclient import discovery from googleapiclient.http import MediaIoBaseDownload + from daras_ai_v2.asr import get_google_auth_session if export_links is None: export_links = {} @@ -87,7 +87,9 @@ def gdrive_download( # export google docs to appropriate type export_mime_type = DOCS_EXPORT_MIMETYPES.get(mime_type, mime_type) if f_url_export := export_links.get(export_mime_type, None): - r = requests.get(f_url_export) + drive_scopes = ["https://www.googleapis.com/auth/drive.readonly"] + session, _ = get_google_auth_session(drive_scopes) + r = session.get(f_url_export) file_bytes = r.content raise_for_status(r, is_user_url=True) return file_bytes, export_mime_type diff --git a/daras_ai_v2/vector_search.py b/daras_ai_v2/vector_search.py index 8f493cdd6..af6597ef2 100644 --- a/daras_ai_v2/vector_search.py +++ b/daras_ai_v2/vector_search.py @@ -377,9 +377,23 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata: meta = gdrive_metadata(url_to_gdrive_file_id(f)) except HttpError as e: if e.status_code == 404: + from google.oauth2.service_account import Credentials + + service_account_client_email = Credentials.from_service_account_file( + settings.service_account_key_path + ).service_account_email + raise UserError( - f"Could not download the google doc at {f_url} " - f"Please make sure to make the document public for viewing." + # language=HTML + f"""

This knowledge base Google Doc is not accessible: {f_url}

+

To address this:

+ """ ) from e else: raise