-
Notifications
You must be signed in to change notification settings - Fork 415
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Async for files.py functions #586
base: main
Are you sure you want to change the base?
Changes from all commits
f0a4ef9
16cf750
668c520
3c3c661
97f0b46
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,12 +22,22 @@ | |
from google.generativeai import protos | ||
from itertools import islice | ||
from io import IOBase | ||
import asyncio | ||
|
||
from google.generativeai.types import file_types | ||
|
||
from google.generativeai.client import get_default_file_client | ||
|
||
__all__ = ["upload_file", "get_file", "list_files", "delete_file"] | ||
__all__ = [ | ||
"upload_file", | ||
"get_file", | ||
"list_files", | ||
"delete_file", | ||
"upload_file_async", | ||
"get_file_async", | ||
"list_files_async", | ||
"delete_file_async", | ||
] | ||
|
||
mimetypes.add_type("image/webp", ".webp") | ||
|
||
|
@@ -88,6 +98,10 @@ def upload_file( | |
return file_types.File(response) | ||
|
||
|
||
async def upload_file_async(*args, **kwargs): | ||
return await asyncio.to_thread(upload_file, *args, **kwargs) | ||
|
||
|
||
def list_files(page_size=100) -> Iterable[file_types.File]: | ||
"""Calls the API to list files using a supported file service.""" | ||
client = get_default_file_client() | ||
|
@@ -97,6 +111,10 @@ def list_files(page_size=100) -> Iterable[file_types.File]: | |
yield file_types.File(proto) | ||
|
||
|
||
async def list_files_async(*args, **kwargs): | ||
return await asyncio.to_thread(list_files, *args, **kwargs) | ||
|
||
|
||
def get_file(name: str) -> file_types.File: | ||
"""Calls the API to retrieve a specified file using a supported file service.""" | ||
if "/" not in name: | ||
|
@@ -105,6 +123,10 @@ def get_file(name: str) -> file_types.File: | |
return file_types.File(client.get_file(name=name)) | ||
|
||
|
||
async def get_file_async(*args, **kwargs): | ||
return await asyncio.to_thread(get_file, *args, **kwargs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We have an async client, we should use it. Let's copy-paste the code, and add the missing |
||
|
||
|
||
def delete_file(name: str | file_types.File | protos.File): | ||
"""Calls the API to permanently delete a specified file using a supported file service.""" | ||
if isinstance(name, (file_types.File, protos.File)): | ||
|
@@ -114,3 +136,7 @@ def delete_file(name: str | file_types.File | protos.File): | |
request = protos.DeleteFileRequest(name=name) | ||
client = get_default_file_client() | ||
client.delete_file(request=request) | ||
|
||
|
||
async def delete_file_async(*args, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same. |
||
return await asyncio.to_thread(get_file, *args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,11 +12,14 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
from absl.testing import absltest | ||
|
||
import google | ||
import google.generativeai as genai | ||
import pathlib | ||
import tempfile | ||
import asyncio | ||
|
||
media = pathlib.Path(__file__).parents[1] / "third_party" | ||
|
||
|
@@ -127,5 +130,29 @@ def test_files_delete(self): | |
# [END files_delete] | ||
|
||
|
||
class AsyncTests(absltest.TestCase, unittest.IsolatedAsyncioTestCase): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should test all the methods, I think list is broken right now. |
||
async def test_upload_file_async(self): | ||
import google.generativeai.files as files | ||
|
||
tempdir = pathlib.Path(tempfile.mkdtemp()) | ||
results = [] | ||
|
||
async def create_and_upload_file(n: int): | ||
fname = tempdir / str(n) | ||
fname.write_text(str(n)) | ||
file_obj = await files.upload_file_async(fname, mime_type="text/plain") | ||
results.append(file_obj) | ||
|
||
tasks = [] | ||
for n in range(5): | ||
tasks.append(asyncio.create_task(create_and_upload_file(n))) | ||
|
||
for task in tasks: | ||
await task | ||
|
||
self.assertLen(results, 5) | ||
self.assertEqual(sorted(int(f.display_name) for f in results), list(range(5))) | ||
|
||
|
||
if __name__ == "__main__": | ||
absltest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this will fail because we want an async iterator to be returned here.
Can you copy the list_files code and add the missing
async
s?