diff --git a/google/cloud/bigquery/async_client.py b/google/cloud/bigquery/async_client.py index c471cfaaa8..81bb9a197b 100644 --- a/google/cloud/bigquery/async_client.py +++ b/google/cloud/bigquery/async_client.py @@ -8,14 +8,12 @@ ) from google.api_core import retry_async as retries import asyncio -import google.auth.transport._aiohttp_requests -class AsyncClient(): +class AsyncClient: def __init__(self, *args, **kwargs): self._client = Client(*args, **kwargs) - async def query_and_wait( self, query, @@ -29,14 +27,14 @@ async def query_and_wait( job_retry: retries.AsyncRetry = DEFAULT_ASYNC_JOB_RETRY, page_size: Optional[int] = None, max_results: Optional[int] = None, - ) -> RowIterator: - + ) -> RowIterator: if project is None: project = self._client.project if location is None: location = self._client.location + # for some reason these cannot find the function call # if job_config is not None: # self._client._verify_job_config_type(job_config, QueryJobConfig) @@ -62,7 +60,7 @@ async def query_and_wait( ) -async def async_query_and_wait( +async def async_query_and_wait( client: "Client", query: str, *, @@ -76,23 +74,24 @@ async def async_query_and_wait( page_size: Optional[int] = None, max_results: Optional[int] = None, ) -> table.RowIterator: - # Some API parameters aren't supported by the jobs.query API. In these # cases, fallback to a jobs.insert call. if not _job_helpers._supported_by_jobs_query(job_config): return await async_wait_or_cancel( - asyncio.to_thread(_job_helpers.query_jobs_insert( # throw in a background thread - client=client, - query=query, - job_id=None, - job_id_prefix=None, - job_config=job_config, - location=location, - project=project, - retry=retry, - timeout=api_timeout, - job_retry=job_retry, - )), + asyncio.to_thread( + _job_helpers.query_jobs_insert( + client=client, + query=query, + job_id=None, + job_id_prefix=None, + job_config=job_config, + location=location, + project=project, + retry=retry, + timeout=api_timeout, + job_retry=job_retry, + ) + ), api_timeout=api_timeout, wait_timeout=wait_timeout, retry=retry, @@ -113,14 +112,12 @@ async def async_query_and_wait( if os.getenv("QUERY_PREVIEW_ENABLED", "").casefold() == "true": request_body["jobCreationMode"] = "JOB_CREATION_OPTIONAL" - request_body["requestId"] = _job_helpers.make_job_id() span_attributes = {"path": path} - # For easier testing, handle the retries ourselves. if retry is not None: - response = retry(client._call_api)( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth) - retry=None, # We're calling the retry decorator ourselves, async_retries + response = client._call_api( # ASYNCHRONOUS HTTP CALLS aiohttp (optional of google-auth), add back retry() + retry=None, # We're calling the retry decorator ourselves, async_retries, need to implement after making HTTP calls async span_name="BigQuery.query", span_attributes=span_attributes, method="POST", @@ -128,6 +125,7 @@ async def async_query_and_wait( data=request_body, timeout=api_timeout, ) + else: response = client._call_api( retry=None, @@ -141,9 +139,7 @@ async def async_query_and_wait( # Even if we run with JOB_CREATION_OPTIONAL, if there are more pages # to fetch, there will be a job ID for jobs.getQueryResults. - query_results = google.cloud.bigquery.query._QueryResults.from_api_repr( - await response - ) + query_results = google.cloud.bigquery.query._QueryResults.from_api_repr(response) page_token = query_results.page_token more_pages = page_token is not None @@ -161,7 +157,7 @@ async def async_query_and_wait( max_results=max_results, ) - result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff + result = table.RowIterator( # async of RowIterator? async version without all the pandas stuff client=client, api_request=functools.partial(client._call_api, retry, timeout=api_timeout), path=None, @@ -177,12 +173,12 @@ async def async_query_and_wait( num_dml_affected_rows=query_results.num_dml_affected_rows, ) - if job_retry is not None: - return job_retry(result) # AsyncRetries, new default objects, default_job_retry_async, default_retry_async + return job_retry(result) else: return result + async def async_wait_or_cancel( job: job.QueryJob, api_timeout: Optional[float], @@ -192,12 +188,14 @@ async def async_wait_or_cancel( max_results: Optional[int], ) -> table.RowIterator: try: - return asyncio.to_thread(job.result( # run in a background thread - page_size=page_size, - max_results=max_results, - retry=retry, - timeout=wait_timeout, - )) + return asyncio.to_thread( + job.result( # run in a background thread + page_size=page_size, + max_results=max_results, + retry=retry, + timeout=wait_timeout, + ) + ) except Exception: # Attempt to cancel the job since we can't return the results. try: @@ -205,4 +203,4 @@ async def async_wait_or_cancel( except Exception: # Don't eat the original exception if cancel fails. pass - raise \ No newline at end of file + raise diff --git a/google/cloud/bigquery/retry.py b/google/cloud/bigquery/retry.py index 9acbf13820..c5fbb7fda7 100644 --- a/google/cloud/bigquery/retry.py +++ b/google/cloud/bigquery/retry.py @@ -91,8 +91,11 @@ def _job_should_retry(exc): The default job retry object. """ -DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry(predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE) # deadline is deprecated +DEFAULT_ASYNC_RETRY = retry_async.AsyncRetry( + predicate=_should_retry, deadline=_DEFAULT_RETRY_DEADLINE +) # deadline is deprecated DEFAULT_ASYNC_JOB_RETRY = retry_async.AsyncRetry( - predicate=_job_should_retry, deadline=_DEFAULT_JOB_DEADLINE # deadline is deprecated -) \ No newline at end of file + predicate=_job_should_retry, + deadline=_DEFAULT_JOB_DEADLINE, # deadline is deprecated +) diff --git a/setup.py b/setup.py index b33f556ab4..006a4ad8bb 100644 --- a/setup.py +++ b/setup.py @@ -36,6 +36,7 @@ # NOTE: Maintainers, please do not require google-cloud-core>=2.x.x # Until this issue is closed # https://github.com/googleapis/google-cloud-python/issues/10566 + "google-auth >= 2.14.1, <3.0.0dev", "google-cloud-core >= 1.6.0, <3.0.0dev", "google-resumable-media >= 0.6.0, < 3.0dev", "packaging >= 20.0.0", @@ -83,9 +84,9 @@ "proto-plus >= 1.15.0, <2.0.0dev", "protobuf>=3.19.5,<5.0.0dev,!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5", # For the legacy proto-based types. ], - "google-auth": [ - "aiohttp", - ] + "aiohttp": [ + "google-auth[aiohttp]", + ], } all_extras = [] diff --git a/testing/constraints-3.7.txt b/testing/constraints-3.7.txt index 9f71bf11ab..cf46d35395 100644 --- a/testing/constraints-3.7.txt +++ b/testing/constraints-3.7.txt @@ -8,6 +8,7 @@ db-dtypes==0.3.0 geopandas==0.9.0 google-api-core==1.31.5 +google-auth==2.14.1 google-cloud-bigquery-storage==2.6.0 google-cloud-core==1.6.0 google-resumable-media==0.6.0 diff --git a/testing/constraints-3.9.txt b/testing/constraints-3.9.txt index d4c3028675..f4adf95c3f 100644 --- a/testing/constraints-3.9.txt +++ b/testing/constraints-3.9.txt @@ -4,5 +4,6 @@ # # NOTE: Not comprehensive yet, will eventually be maintained semi-automatically by # the renovate bot. +aiohttp==3.6.2 grpcio==1.47.0 pyarrow>=4.0.0 diff --git a/tests/unit/test_async_client.py b/tests/unit/test_async_client.py index a190b5973b..4725047117 100644 --- a/tests/unit/test_async_client.py +++ b/tests/unit/test_async_client.py @@ -77,6 +77,17 @@ else: PANDAS_INSTALLED_VERSION = "0.0.0" +from google.cloud.bigquery.retry import ( + DEFAULT_ASYNC_JOB_RETRY, + DEFAULT_ASYNC_RETRY, + DEFAULT_TIMEOUT, +) +from google.api_core import retry_async as retries +from google.cloud.bigquery import async_client +from google.cloud.bigquery.async_client import AsyncClient +from google.cloud.bigquery.job import query as job_query + + def asyncio_run(async_func): def wrapper(*args, **kwargs): return asyncio.run(async_func(*args, **kwargs)) @@ -94,7 +105,6 @@ def _make_credentials(): return mock.Mock(spec=google.auth.credentials.Credentials) - class TestClient(unittest.TestCase): PROJECT = "PROJECT" DS_ID = "DATASET_ID" @@ -123,12 +133,17 @@ def _make_table_resource(self): }, } + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) def test_ctor_defaults(self): from google.cloud.bigquery._http import Connection creds = _make_credentials() http = object() - client = self._make_one(project=self.PROJECT, credentials=creds, _http=http)._client + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http + )._client self.assertIsInstance(client._connection, Connection) self.assertIs(client._connection.credentials, creds) self.assertIs(client._connection.http, http) @@ -137,6 +152,9 @@ def test_ctor_defaults(self): client._connection.API_BASE_URL, Connection.DEFAULT_API_ENDPOINT ) + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) def test_ctor_w_empty_client_options(self): from google.api_core.client_options import ClientOptions @@ -154,7 +172,133 @@ def test_ctor_w_empty_client_options(self): ) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_dict(self): + creds = _make_credentials() + http = object() + client_options = {"api_endpoint": "https://www.foo-googleapis.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_client_options_object(self): + from google.api_core.client_options import ClientOptions + + creds = _make_credentials() + http = object() + client_options = ClientOptions(api_endpoint="https://www.foo-googleapis.com") + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual( + client._connection.API_BASE_URL, "https://www.foo-googleapis.com" + ) + + @pytest.mark.skipif( + packaging.version.parse(getattr(google.api_core, "__version__", "0.0.0")) + < packaging.version.Version("2.15.0"), + reason="universe_domain not supported with google-api-core < 2.15.0", + ) + def test_ctor_w_client_options_universe(self): + creds = _make_credentials() + http = object() + client_options = {"universe_domain": "foo.com"} + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + client_options=client_options, + )._client + self.assertEqual(client._connection.API_BASE_URL, "https://bigquery.foo.com") + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_location(self): + from google.cloud.bigquery._http import Connection + + creds = _make_credentials() + http = object() + location = "us-central" + client = self._make_one( + project=self.PROJECT, credentials=creds, _http=http, location=location + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_query_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import QueryJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = QueryJobConfig() + job_config.dry_run = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_query_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_query_job_config, QueryJobConfig) + self.assertTrue(client._default_query_job_config.dry_run) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" + ) + def test_ctor_w_load_job_config(self): + from google.cloud.bigquery._http import Connection + from google.cloud.bigquery import LoadJobConfig + + creds = _make_credentials() + http = object() + location = "us-central" + job_config = LoadJobConfig() + job_config.create_session = True + + client = self._make_one( + project=self.PROJECT, + credentials=creds, + _http=http, + location=location, + default_load_job_config=job_config, + )._client + self.assertIsInstance(client._connection, Connection) + self.assertIs(client._connection.credentials, creds) + self.assertIs(client._connection.http, http) + self.assertEqual(client.location, location) + + self.assertIsInstance(client._default_load_job_config, LoadJobConfig) + self.assertTrue(client._default_load_job_config.create_session) + + @pytest.mark.skipif( + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_defaults(self): @@ -200,7 +344,7 @@ async def test_query_and_wait_defaults(self): self.assertFalse(sent["useLegacySql"]) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_default_query_job_config(self): @@ -237,7 +381,7 @@ async def test_query_and_wait_w_default_query_job_config(self): self.assertEqual(sent["labels"], {"default-label": "default-value"}) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_job_config(self): @@ -275,7 +419,7 @@ async def test_query_and_wait_w_job_config(self): self.assertEqual(sent["labels"], {"job_config-label": "job_config-value"}) @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_location(self): @@ -300,7 +444,7 @@ async def test_query_and_wait_w_location(self): self.assertEqual(sent["location"], "not-the-client-location") @pytest.mark.skipif( - sys.version_info < (3, 9), reason="requires python3.9 or higher" + sys.version_info < (3, 9), reason="requires python3.9 or higher" ) @asyncio_run async def test_query_and_wait_w_project(self): @@ -320,4 +464,7 @@ async def test_query_and_wait_w_project(self): # conn.api_request.assert_called_once() _, req = conn.api_request.call_args self.assertEqual(req["method"], "POST") - self.assertEqual(req["path"], "/projects/not-the-client-project/queries") \ No newline at end of file + self.assertEqual(req["path"], "/projects/not-the-client-project/queries") + + +# Add tests for async_query_and_wait and async_wait_or_cancel