Skip to content

Commit c75772b

Browse files
committed
Refactor gdrive_downloader.py - Google auth session for private documents
- Used a double-check locking to ensure that multiple threads won't create redundant auth sessions simultaneously.
1 parent 76b80be commit c75772b

File tree

3 files changed

+38
-14
lines changed

3 files changed

+38
-14
lines changed

daras_ai_v2/asr.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import multiprocessing
33
import os.path
44
import tempfile
5+
import threading
56
import typing
67
from enum import Enum
78

@@ -873,9 +874,6 @@ def _translate_text(
873874
return result.strip()
874875

875876

876-
_session = None
877-
878-
879877
def _MinT_translate_one_text(
880878
text: str, source_language: str, target_language: str
881879
) -> str:
@@ -894,18 +892,27 @@ def _MinT_translate_one_text(
894892
return tanslation.get("translation", text)
895893

896894

897-
def get_google_auth_session():
895+
_session = None
896+
_session_lock = threading.Lock()
897+
898+
899+
def get_google_auth_session(scopes: typing.Optional[list[str]] = None):
898900
global _session
899901

900902
if _session is None:
901-
import google.auth
902-
from google.auth.transport.requests import AuthorizedSession
903+
with _session_lock:
904+
if _session is None:
905+
import google.auth
906+
from google.auth.transport.requests import AuthorizedSession
903907

904-
creds, project = google.auth.default(
905-
scopes=["https://www.googleapis.com/auth/cloud-platform"]
906-
)
907-
# takes care of refreshing the token and adding it to request headers
908-
_session = AuthorizedSession(credentials=creds), project
908+
if not scopes:
909+
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
910+
911+
creds, project = google.auth.default(
912+
scopes=scopes,
913+
)
914+
# takes care of refreshing the token and adding it to request headers
915+
_session = AuthorizedSession(credentials=creds), project
909916

910917
return _session
911918

daras_ai_v2/gdrive_downloader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def gdrive_download(
7474
) -> tuple[bytes, str]:
7575
from googleapiclient import discovery
7676
from googleapiclient.http import MediaIoBaseDownload
77+
from daras_ai_v2.asr import get_google_auth_session
7778

7879
if export_links is None:
7980
export_links = {}
@@ -87,7 +88,9 @@ def gdrive_download(
8788
# export google docs to appropriate type
8889
export_mime_type = DOCS_EXPORT_MIMETYPES.get(mime_type, mime_type)
8990
if f_url_export := export_links.get(export_mime_type, None):
90-
r = requests.get(f_url_export)
91+
drive_scopes = ["https://www.googleapis.com/auth/drive.readonly"]
92+
session, _ = get_google_auth_session(drive_scopes)
93+
r = session.get(f_url_export)
9194
file_bytes = r.content
9295
raise_for_status(r, is_user_url=True)
9396
return file_bytes, export_mime_type

daras_ai_v2/vector_search.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,23 @@ def doc_url_to_file_metadata(f_url: str) -> FileMetadata:
373373
meta = gdrive_metadata(url_to_gdrive_file_id(f))
374374
except HttpError as e:
375375
if e.status_code == 404:
376+
from google.oauth2.service_account import Credentials
377+
378+
service_account_client_email = Credentials.from_service_account_file(
379+
settings.service_account_key_path
380+
).service_account_email
381+
376382
raise UserError(
377-
f"Could not download the google doc at {f_url} "
378-
f"Please make sure to make the document public for viewing."
383+
# language=HTML
384+
f"""<p>This knowledge base Google Doc is not accessible: <a href="{f_url}" target="_blank">{f_url}</a></p>
385+
<p>To address this:</p>
386+
<ul>
387+
<li>Please make the Google Doc publicly viewable, or</li>
388+
<li>Share the Doc or its parent folder with <br>
389+
<a href="mailto:{service_account_client_email}">{service_account_client_email}</a>
390+
as an authorized viewer.
391+
</li>
392+
</ul>"""
379393
) from e
380394
else:
381395
raise

0 commit comments

Comments
 (0)