diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 583eda76..a27afb12 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -409,26 +409,9 @@ def test_full_like(x, fill_value, kw): finite_kw = {"allow_nan": False, "allow_infinity": False} -def int_stops( - start: int, num, dtype: DataType, endpoint: bool -) -> st.SearchStrategy[int]: - min_gap = num - if endpoint: - min_gap += 1 - m, M = dh.dtype_ranges[dtype] - max_pos_gap = M - start - max_neg_gap = start - m - max_pos_mul = max_pos_gap // min_gap - max_neg_mul = max_neg_gap // min_gap - return st.one_of( - st.integers(0, max_pos_mul).map(lambda n: start + min_gap * n), - st.integers(0, max_neg_mul).map(lambda n: start - min_gap * n), - ) - - @given( num=hh.sizes, - dtype=st.none() | xps.numeric_dtypes(), + dtype=st.none() | xps.floating_dtypes(), endpoint=st.booleans(), data=st.data(), ) @@ -436,16 +419,10 @@ def test_linspace(num, dtype, endpoint, data): _dtype = dh.default_float if dtype is None else dtype start = data.draw(xps.from_dtype(_dtype, **finite_kw), label="start") - if dh.is_float_dtype(_dtype): - stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop") - # avoid overflow errors - assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype))) - assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype))) - else: - if num == 0: - stop = start - else: - stop = data.draw(int_stops(start, num, _dtype, endpoint), label="stop") + stop = data.draw(xps.from_dtype(_dtype, **finite_kw), label="stop") + # avoid overflow errors + assume(not ah.isnan(ah.asarray(stop - start, dtype=_dtype))) + assume(not ah.isnan(ah.asarray(start - stop, dtype=_dtype))) kw = data.draw( hh.specified_kwargs( diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index be4a22ca..321263d3 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -315,7 +315,7 @@ def test_matmul(x1, x2): @given( x=finite_matrices(), kw=kwargs(keepdims=booleans(), - ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc'])) + ord=sampled_from([-float('inf'), -2, -1, 1, 2, float('inf'), 'fro', 'nuc'])) ) def test_matrix_norm(x, kw): res = linalg.matrix_norm(x, **kw)