
use _make_client in more places
daniel-sanche committed Jan 19, 2024
1 parent 52ec52a commit b2bf56f
Showing 2 changed files with 25 additions and 28 deletions.
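
This commit swaps direct BigtableDataClientAsync construction for the shared _make_client helper in more of the async unit tests, so each test picks up the helper's emulator-mode defaults instead of needing real credentials or per-test mocking. A minimal before/after sketch of the pattern, based on the hunks below (the _make_client signature comes from the diff; the surrounding test bodies are illustrative only):

# Before: each test imported and constructed the client directly.
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync

async def test_something_old():
    async with BigtableDataClientAsync(project="project-id") as client:
        ...

# After: tests go through the shared helper, which defaults to emulator mode
# and accepts use_emulator=False for tests that exercise channel
# pooling/refresh background tasks.
async def test_something_new():
    async with _make_client(project="project-id") as client:
        ...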
7 changes: 3 additions & 4 deletions tests/system/data/test_system.py
@@ -595,6 +595,7 @@ async def test_check_and_mutate(
expected_value = true_mutation_value if expected_result else false_mutation_value
assert (await _retrieve_cell_value(table, row_key)) == expected_value


@pytest.mark.skipif(
bool(os.environ.get(BIGTABLE_EMULATOR)),
reason="emulator doesn't raise InvalidArgument",
@@ -610,13 +611,11 @@ async def test_check_and_mutate_empty_request(client, table):

with pytest.raises(exceptions.InvalidArgument) as e:
await table.check_and_mutate_row(
b'row_key',
None,
true_case_mutations=None,
false_case_mutations=None
b"row_key", None, true_case_mutations=None, false_case_mutations=None
)
assert "No mutations provided" in str(e.value)


@pytest.mark.usefixtures("table")
@retry.AsyncRetry(predicate=retry.if_exception_type(ClientError), initial=1, maximum=5)
@pytest.mark.asyncio
46 changes: 22 additions & 24 deletions tests/unit/data/_async/test_client.py
@@ -48,6 +48,7 @@
def _make_client(*args, use_emulator=True, **kwargs):
import os
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync

env_mask = {}
# by default, use emulator mode to avoid auth issues in CI
# emulator mode must be disabled by tests that check channel pooling/refresh background tasks
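
(The rest of the helper body is collapsed behind the diff's expand control. Purely for orientation, here is a hedged sketch of how such a helper could finish the environment masking; the emulator host value, the AnonymousCredentials fallback, and the mock.patch.dict approach are all assumptions, not the repository's actual code.)

import os
from unittest import mock

from google.auth.credentials import AnonymousCredentials
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync


def _make_client(*args, use_emulator=True, **kwargs):
    # Hypothetical completion of the collapsed body: emulator mode points the
    # client at a local emulator host so no real credentials are needed in CI;
    # use_emulator=False falls back to anonymous credentials so the channel
    # pooling/refresh background tasks run as they would against production.
    env_mask = {}
    if use_emulator:
        env_mask["BIGTABLE_EMULATOR_HOST"] = "localhost"
    else:
        kwargs.setdefault("credentials", AnonymousCredentials())
        kwargs.setdefault("project", "project-id")
    with mock.patch.dict(os.environ, env_mask):
        return BigtableDataClientAsync(*args, **kwargs)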
@@ -271,7 +272,9 @@ async def test__start_background_channel_refresh_tasks_exist(self):
@pytest.mark.parametrize("pool_size", [1, 3, 7])
async def test__start_background_channel_refresh(self, pool_size):
# should create background tasks for each channel
client = self._make_one(project="project-id", pool_size=pool_size, use_emulator=False)
client = self._make_one(
project="project-id", pool_size=pool_size, use_emulator=False
)
ping_and_warm = AsyncMock()
client._ping_and_warm_instances = ping_and_warm
client._start_background_channel_refresh()
@@ -291,7 +294,9 @@ async def test__start_background_channel_refresh(self, pool_size):
async def test__start_background_channel_refresh_tasks_names(self):
# if tasks exist, should do nothing
pool_size = 3
client = self._make_one(project="project-id", pool_size=pool_size, use_emulator=False)
client = self._make_one(
project="project-id", pool_size=pool_size, use_emulator=False
)
for i in range(pool_size):
name = client._channel_refresh_tasks[i].get_name()
assert str(i) in name
@@ -938,9 +943,13 @@ async def test_multiple_pool_sizes(self):
# should be able to create multiple clients with different pool sizes without issue
pool_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]
for pool_size in pool_sizes:
client = self._make_one(project="project-id", pool_size=pool_size, use_emulator=False)
client = self._make_one(
project="project-id", pool_size=pool_size, use_emulator=False
)
assert len(client._channel_refresh_tasks) == pool_size
client_duplicate = self._make_one(project="project-id", pool_size=pool_size, use_emulator=False)
client_duplicate = self._make_one(
project="project-id", pool_size=pool_size, use_emulator=False
)
assert len(client_duplicate._channel_refresh_tasks) == pool_size
assert str(pool_size) in str(client.transport)
await client.close()
@@ -953,7 +962,9 @@ async def test_close(self):
)

pool_size = 7
client = self._make_one(project="project-id", pool_size=pool_size, use_emulator=False)
client = self._make_one(
project="project-id", pool_size=pool_size, use_emulator=False
)
assert len(client._channel_refresh_tasks) == pool_size
tasks_list = list(client._channel_refresh_tasks)
for task in client._channel_refresh_tasks:
@@ -1003,10 +1014,9 @@ async def test_context_manager(self):

def test_client_ctor_sync(self):
# initializing client in a sync context should raise RuntimeError
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync

with pytest.warns(RuntimeWarning) as warnings:
client = BigtableDataClientAsync(project="project-id")
client = _make_client(project="project-id")
expected_warning = [w for w in warnings if "client.py" in w.filename]
assert len(expected_warning) == 1
assert (
@@ -1020,7 +1030,6 @@ def test_client_ctor_sync(self):
class TestTableAsync:
@pytest.mark.asyncio
async def test_table_ctor(self):
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
from google.cloud.bigtable.data._async.client import TableAsync
from google.cloud.bigtable.data._async.client import _WarmedInstanceKey

@@ -1033,7 +1042,7 @@ async def test_table_ctor(self):
expected_read_rows_attempt_timeout = 0.5
expected_mutate_rows_operation_timeout = 2.5
expected_mutate_rows_attempt_timeout = 0.75
client = BigtableDataClientAsync()
client = _make_client()
assert not client._active_instances

table = TableAsync(
@@ -1088,12 +1097,11 @@ async def test_table_ctor_defaults(self):
"""
should provide default timeout values and app_profile_id
"""
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
from google.cloud.bigtable.data._async.client import TableAsync

expected_table_id = "table-id"
expected_instance_id = "instance-id"
client = BigtableDataClientAsync()
client = _make_client()
assert not client._active_instances

table = TableAsync(
@@ -1119,10 +1127,9 @@ async def test_table_ctor_invalid_timeout_values(self):
"""
bad timeout values should raise ValueError
"""
from google.cloud.bigtable.data._async.client import BigtableDataClientAsync
from google.cloud.bigtable.data._async.client import TableAsync

client = BigtableDataClientAsync()
client = _make_client()

timeout_pairs = [
("default_operation_timeout", "default_attempt_timeout"),
@@ -1240,10 +1247,8 @@ async def test_customizable_retryable_errors(
Test that retryable functions support user-configurable arguments, and that the configured retryables are passed
down to the gapic layer.
"""
from google.cloud.bigtable.data import BigtableDataClientAsync

with mock.patch(retry_fn_path) as retry_fn_mock:
async with BigtableDataClientAsync() as client:
async with _make_client() as client:
table = client.get_table("instance-id", "table-id")
expected_predicate = lambda a: a in expected_retryables # noqa
retry_fn_mock.side_effect = RuntimeError("stop early")
@@ -1291,14 +1296,13 @@ async def test_customizable_retryable_errors(
async def test_call_metadata(self, include_app_profile, fn_name, fn_args, gapic_fn):
"""check that all requests attach proper metadata headers"""
from google.cloud.bigtable.data import TableAsync
from google.cloud.bigtable.data import BigtableDataClientAsync

profile = "profile" if include_app_profile else None
with mock.patch(
f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}", mock.AsyncMock()
) as gapic_mock:
gapic_mock.side_effect = RuntimeError("stop early")
async with BigtableDataClientAsync() as client:
async with _make_client() as client:
table = TableAsync(client, "instance-id", "table-id", profile)
try:
test_fn = table.__getattribute__(fn_name)
@@ -1825,7 +1829,6 @@ async def test_row_exists(self, return_value, expected_result):


class TestReadRowsSharded:

@pytest.mark.asyncio
async def test_read_rows_sharded_empty_query(self):
async with _make_client() as client:
@@ -1984,7 +1987,6 @@ async def test_read_rows_sharded_batching(self):


class TestSampleRowKeys:

async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]):
from google.cloud.bigtable_v2.types import SampleRowKeysResponse

@@ -2133,7 +2135,6 @@ async def test_sample_row_keys_non_retryable_errors(self, non_retryable_exceptio


class TestMutateRow:

@pytest.mark.asyncio
@pytest.mark.parametrize(
"mutation_arg",
@@ -2306,7 +2307,6 @@ async def test_mutate_row_no_mutations(self, mutations):


class TestBulkMutateRows:

async def _mock_response(self, response_list):
from google.cloud.bigtable_v2.types import MutateRowsResponse
from google.rpc import status_pb2
@@ -2683,7 +2683,6 @@ async def test_bulk_mutate_error_recovery(self):


class TestCheckAndMutateRow:

@pytest.mark.parametrize("gapic_result", [True, False])
@pytest.mark.asyncio
async def test_check_and_mutate(self, gapic_result):
@@ -2832,7 +2831,6 @@ async def test_check_and_mutate_mutations_parsing(self):


class TestReadModifyWriteRow:

@pytest.mark.parametrize(
"call_rules,expected_rules",
[
