Skip to content

Commit

Permalink
Update test_select_n.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rkazants authored Dec 31, 2024
1 parent 07aa342 commit b596253
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions tests/layer_tests/jax_tests/test_select_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
class TestSelectN(JaxLayerTest):
def _prepare_input(self):
cases = []
if(self.case_num == 2):
if (self.case_num == 2):
which = rng.choice([True, False], self.input_shape)
else:
which = rng.uniform(0,self.case_num, self.input_shape).astype(self.input_type)
which = rng.uniform(0, self.case_num, self.input_shape).astype(self.input_type)
which = np.array(which)
for i in range(self.case_num):
cases.append(jnp.array(np.random.uniform(-1000, 1000, self.input_shape).astype(self.input_type)))
Expand All @@ -31,15 +31,13 @@ def create_model(self, input_shape, input_type, case_num):

def jax_select_n(which, cases):
return jax.lax.select_n(which, *cases)

return jax_select_n, None, 'select_n'


@pytest.mark.parametrize("input_shape", [[],[1],[2,3],[4,5,6],[7,8,9,10]])
@pytest.mark.parametrize("input_shape", [[], [1], [2, 3], [4, 5, 6], [7, 8, 9, 10]])
@pytest.mark.parametrize("input_type", [np.int32, np.int64])
@pytest.mark.parametrize("case_num", [2,3,4])
@pytest.mark.parametrize("case_num", [2, 3, 4])
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.precommit_jax_fe
def test_select_n(self, ie_device, precision, ir_version, input_shape, input_type, case_num):
self._test(*self.create_model(input_shape, input_type, case_num),
Expand Down

0 comments on commit b596253

Please sign in to comment.