diff --git a/tests/test_groups.py b/tests/test_groups.py index f2b18df..419059f 100644 --- a/tests/test_groups.py +++ b/tests/test_groups.py @@ -1,50 +1,37 @@ -import pytest - -from pytest_test_groups import get_group, get_group_size - - -def test_group_size_computed_correctly_for_even_group(): - expected = 8 - actual = get_group_size(32, 4) # 32 total tests; 4 groups - - assert expected == actual - +from itertools import chain -def test_group_size_computed_correctly_for_odd_group(): - expected = 8 - actual = get_group_size(31, 4) # 31 total tests; 4 groups +import pytest - assert expected == actual +from pytest_test_groups import get_group def test_group_is_the_proper_size(): items = [str(i) for i in range(32)] group = get_group(items, 8, 1) - assert len(group) == 8 + assert len(group) == 4 def test_all_groups_together_form_original_set_of_tests(): - items = [str(i) for i in range(32)] - - groups = [get_group(items, 8, i) for i in range(1, 5)] - - combined = [] - for group in groups: - combined.extend(group) - - assert combined == items + group_count = 8 + for item_size in range(group_count, 32): + items = [str(i) for i in range(item_size)] + groups = [get_group(items, group_count, i) for i in range(1, group_count + 1)] + combined = set(chain.from_iterable(groups)) + assert combined == set(items) def test_group_that_is_too_high_raises_value_error(): items = [str(i) for i in range(32)] with pytest.raises(ValueError): - get_group(items, 8, 5) + # When group_count=4, group_id=5 is out of bounds + get_group(items, 4, 5) def test_group_that_is_too_low_raises_value_error(): items = [str(i) for i in range(32)] with pytest.raises(ValueError): - get_group(items, 8, 0) + # When group_count=4, group_id=0 is out of bounds + get_group(items, 4, 0)