diff --git a/tests/test_test.py b/tests/test_test.py index 92163428..03798392 100644 --- a/tests/test_test.py +++ b/tests/test_test.py @@ -37,6 +37,11 @@ def socket() -> FakeSocket: return FakeSocket(1) +@pytest.fixture +def sub_socket() -> FakeSocket: + return FakeSocket(2) + + def test_socket_unbind(socket: FakeSocket): socket.bind("abc") socket.unbind() @@ -50,10 +55,9 @@ def test_socket_disconnect(socket: FakeSocket): @pytest.mark.parametrize("topic", ("string", b"bytes")) -def test_socket_subscribe(socket: FakeSocket, topic): - socket.socket_type = 2 - socket.subscribe(topic) - assert isinstance(socket._subscriptions[-1], bytes) +def test_socket_subscribe(sub_socket: FakeSocket, topic): + sub_socket.subscribe(topic) + assert isinstance(sub_socket._subscriptions[-1], bytes) def test_subscribe_fails_for_not_SUB(socket: FakeSocket): @@ -61,11 +65,11 @@ def test_subscribe_fails_for_not_SUB(socket: FakeSocket): socket.subscribe("abc") -@pytest.mark.parametrize("topic", ("string", b"bytes")) -def test_socket_unsubscribe(socket: FakeSocket, topic): - socket.socket_type = 2 - socket.unsubscribe(topic) - assert isinstance(socket._subscriptions[-1], bytes) +@pytest.mark.parametrize("topic", ("topic", b"topic")) +def test_socket_unsubscribe(sub_socket: FakeSocket, topic): + sub_socket._subscriptions.append(b"topic") + sub_socket.unsubscribe(topic) + assert b"topic" not in sub_socket._subscriptions def test_unsubscribe_fails_for_not_SUB(socket: FakeSocket): diff --git a/tests/utils/test_extended_message_handler.py b/tests/utils/test_extended_message_handler.py index dbb00d72..c1db857d 100644 --- a/tests/utils/test_extended_message_handler.py +++ b/tests/utils/test_extended_message_handler.py @@ -115,14 +115,14 @@ def test_handle_pickled_message(self, handler_hfl: ExtendedMessageHandler): handler_hfl.handle_full_legacy_subscription_message( DataMessage("topic", data=pickle.dumps(data), message_type=234) ) - handler_hfl.handle_subscription_data.assert_called_once_with({"topic": data}) + handler_hfl.handle_subscription_data.assert_called_once_with({"topic": data}) # type: ignore def test_handle_json_message(self, handler_hfl: ExtendedMessageHandler): data = ["some", "data", 5] handler_hfl.handle_full_legacy_subscription_message( DataMessage("topic", data=json.dumps(data), message_type=235) ) - handler_hfl.handle_subscription_data.assert_called_once_with({"topic": data}) + handler_hfl.handle_subscription_data.assert_called_once_with({"topic": data}) # type: ignore def test_handle_unknown_message_type(self, handler_hfl: ExtendedMessageHandler): with pytest.raises(ValueError):