Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for allow_version_mismatch (L1-2516) #287

Merged
merged 2 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# zhinst-toolkit Changelog

## Version 0.8.0
* The constructor of `Session` fails when attempting to connect to a data-server on a different LabOne version. This behavior can be overridden by setting the newly added allow_version_mismatch keyword argument to True. When allow_version_mismatch=True is passed to the `Session` constructor the connection to the data-server succeeds even if the version doesn't match.

## Version 0.7.0
* Add QHub driver

Expand Down
7 changes: 1 addition & 6 deletions src/zhinst/toolkit/driver/modules/shfqa_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from pathlib import Path

import numpy as np
from zhinst.core import ziDAQServer
from zhinst.utils.shf_sweeper import AvgConfig, EnvelopeConfig, RfConfig
from zhinst.utils.shf_sweeper import ShfSweeper as CoreSweeper
from zhinst.utils.shf_sweeper import SweepConfig, TriggerConfig
Expand Down Expand Up @@ -71,11 +70,7 @@ def __init__(self, session: "Session"):
"force_sw_trigger": "sw_trigger_mode",
}
super().__init__(self._create_nodetree(), tuple())
self._daq_server = ziDAQServer(
session.daq_server.host,
session.daq_server.port,
6,
)
self._daq_server = session.clone_underlying_session()
self._raw_module = CoreSweeper(self._daq_server, "")
self._session = session
self.root.update_nodes(
Expand Down
49 changes: 45 additions & 4 deletions src/zhinst/toolkit/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module for managing a session to a Data Server through zhinst.core."""

import json
import typing as t
from collections.abc import MutableMapping
Expand Down Expand Up @@ -654,6 +655,13 @@ class Session(Node):
connection: Existing DAQ server object. If specified the session will
not create a new session to the data server but reuse the passed
one. (default = None)
allow_version_mismatch: if set to True, the connection to the data-server
will succeed even if the data-server is on a different version of LabOne.
If False, an exception will be raised if the data-server is on a
different version. (default = False)

.. versionchanged:: 0.8.0
Added `allow_version_mismatch` argument.
"""

def __init__(
Expand All @@ -663,6 +671,7 @@ def __init__(
*,
hf2: t.Optional[bool] = None,
connection: t.Optional[core.ziDAQServer] = None,
allow_version_mismatch: bool = False,
):
self._is_hf2_server = bool(hf2)
if connection is not None:
Expand All @@ -683,10 +692,8 @@ def __init__(
if self._is_hf2_server and server_port == 8004:
server_port = 8005
try:
self._daq_server = core.ziDAQServer(
server_host,
server_port,
1 if self._is_hf2_server else 6,
self._daq_server = self._create_daq(
server_host, server_port, allow_version_mismatch
)
except RuntimeError as error:
if "Unsupported API level" not in error.args[0]:
Expand Down Expand Up @@ -987,3 +994,37 @@ def server_host(self) -> str:
def server_port(self) -> int:
"""Server port."""
return self._daq_server.port

def clone_underlying_session(self) -> core.ziDAQServer:
"""Create a new session to the data server.

Create a new core.ziDAQServer connected to the same data-server this
session is connected to.
"""
# Don't execute version checking. When clone_underlying_session is called,
# a connection has already been made, so checking again would be redundant.
return self._create_daq(self.server_host, self.server_port, True)

def _create_daq(
self,
server_host: str,
server_port: int,
allow_version_mismatch: bool,
):
"""Create a new session to the data server.

Attempt to pass the allow_version_mismatch flag. Fallback in case
zhinst.core does not support it yet.
"""
api_level = 1 if self._is_hf2_server else 6
try:
return core.ziDAQServer(
server_host,
server_port,
api_level,
allow_version_mismatch=allow_version_mismatch,
)
except TypeError as error:
if "allow_version_mismatch" not in error.args[0]:
raise
return core.ziDAQServer(server_host, server_port, api_level)
11 changes: 0 additions & 11 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def hf2_session(mock_connection):

@pytest.fixture()
def shfqa(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfqa.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -62,7 +61,6 @@ def shfqa(data_dir, mock_connection, session):

@pytest.fixture()
def shfsg(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfsg.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -74,7 +72,6 @@ def shfsg(data_dir, mock_connection, session):

@pytest.fixture()
def shfqc(data_dir, mock_connection, session):

json_path = data_dir / "nodedoc_dev1234_shfqc.json"
with json_path.open("r", encoding="UTF-8") as file:
nodes_json = file.read()
Expand All @@ -96,11 +93,3 @@ def nodedoc_dev1234_json(data_dir):
json_path = data_dir / "nodedoc_dev1234.json"
with json_path.open("r", encoding="UTF-8") as file:
return file.read()


@pytest.fixture()
def mock_sweeper_daq():
with patch(
"zhinst.toolkit.driver.modules.shfqa_sweeper.ziDAQServer", autospec=True
) as connection:
yield connection
61 changes: 54 additions & 7 deletions tests/test_data_server_session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

Expand All @@ -9,7 +9,9 @@


def test_setup(mock_connection, session):
mock_connection.assert_called_once_with("localhost", 8004, 6)
mock_connection.assert_called_once_with(
"localhost", 8004, 6, allow_version_mismatch=False
)
mock_connection.return_value.listNodesJSON.assert_called_once_with("/zi/*")
assert repr(session) == "DataServerSession(localhost:8004)"
assert not session.is_hf2_server
Expand All @@ -18,12 +20,60 @@ def test_setup(mock_connection, session):


def test_setup_hf2(mock_connection, hf2_session):
mock_connection.assert_called_once_with("localhost", 8005, 1)
mock_connection.assert_called_once_with(
"localhost", 8005, 1, allow_version_mismatch=False
)
mock_connection.return_value.listNodesJSON.assert_not_called()
assert repr(hf2_session) == "HF2DataServerSession(localhost:8005)"
assert hf2_session.is_hf2_server


def test_allow_mismatch_not_supported(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
if "allow_version_mismatch" in kwargs:
raise TypeError("allow_version_mismatch not recognized")
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004)
mock_connection.assert_any_call("localhost", 8004, 6, allow_version_mismatch=False)
mock_connection.assert_called_with("localhost", 8004, 6)


# Passing "allow_version_mismatch" does not cause an error, even if the underlying
# zhinst.core does not recognize this flag.
def test_allow_mismatch_passed_but_not_supported(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
if "allow_version_mismatch" in kwargs:
raise TypeError("allow_version_mismatch not recognized")
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004, allow_version_mismatch=True)
mock_connection.assert_any_call("localhost", 8004, 6, allow_version_mismatch=True)
mock_connection.assert_called_with("localhost", 8004, 6)


def test_allow_mismatch_default(mock_connection, nodedoc_zi_json):
mock_daq = MagicMock()
mock_daq.listNodesJSON.return_value = nodedoc_zi_json

def create_daq(*args, **kwargs):
return mock_daq

mock_connection.side_effect = create_daq
Session("localhost", 8004)
mock_connection.assert_called_once_with(
"localhost", 8004, 6, allow_version_mismatch=False
)


def test_existing_connection(nodedoc_zi_json, mock_connection):
mock_connection.listNodesJSON.return_value = nodedoc_zi_json
mock_connection.getString.return_value = "DataServer"
Expand Down Expand Up @@ -77,7 +127,6 @@ def test_unkown_init_error(mock_connection):
def test_connect_device(
zi_devices_json, mock_connection, session, nodedoc_dev1234_json
):

connected_devices = ""

def get_string_side_effect(arg):
Expand Down Expand Up @@ -137,7 +186,6 @@ def connect_device_side_effect(serial, _):
def test_connect_device_autodetection(
zi_devices_json, mock_connection, session, nodedoc_dev1234_json
):

connected_devices = ""
selected_interface = ""

Expand Down Expand Up @@ -478,7 +526,7 @@ def test_sweeper_module(data_dir, mock_connection, session):
assert isinstance(sweeper_module.device, Node)


def test_shfqa_sweeper(session, mock_sweeper_daq):
def test_shfqa_sweeper(session):
sweeper = session.modules.shfqa_sweeper
assert sweeper == session.modules.shfqa_sweeper
assert isinstance(sweeper, tk_modules.SHFQASweeper)
Expand All @@ -487,7 +535,6 @@ def test_shfqa_sweeper(session, mock_sweeper_daq):
def test_session_wide_transaction(
mock_connection, nodedoc_dev1234_json, session, shfqa, shfsg
):

# Hack devices into the created once
session._devices._devices = {"dev1": shfqa, "dev2": shfsg}

Expand Down
21 changes: 6 additions & 15 deletions tests/test_shfqa_sweeper.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,15 @@ def mock_TriggerConfig():


@pytest.fixture()
def sweeper_module(session, mock_sweeper_daq, mock_shf_sweeper):
def sweeper_module(session, mock_shf_sweeper):
yield SHFQASweeper(session)


def test_repr(sweeper_module):
assert "SHFQASweeper(DataServerSession(localhost:8004))" in repr(sweeper_module)


def test_missing_node(mock_connection, mock_shf_sweeper, mock_sweeper_daq, session):
def test_missing_node(mock_connection, mock_shf_sweeper, session):
with patch(
"zhinst.toolkit.driver.modules.shfqa_sweeper.SweepConfig",
make_dataclass("Y", fields=[("s", str, 0)], bases=(SweepConfig,)),
Expand All @@ -50,13 +50,11 @@ def test_missing_node(mock_connection, mock_shf_sweeper, mock_sweeper_daq, sessi
assert sweeper_module.sweep.s() == 0


def test_device(
mock_connection, sweeper_module, session, mock_shf_sweeper, mock_sweeper_daq
):
def test_device(mock_connection, sweeper_module, session, mock_shf_sweeper):
assert sweeper_module.device() == ""

sweeper_module.device("dev1234")
mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234")
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "dev1234")
assert sweeper_module.device() == "dev1234"

connected_devices = "dev1234"
Expand All @@ -73,17 +71,15 @@ def get_string_side_effect(arg):
assert sweeper_module.device() == session.devices["dev1234"]


def test_update_settings(
mock_connection, sweeper_module, mock_shf_sweeper, mock_sweeper_daq
):
def test_update_settings(sweeper_module, mock_shf_sweeper):
assert not sweeper_module.envelope.enable()
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "")
# device needs to be set first
with pytest.raises(RuntimeError) as e_info:
sweeper_module._update_settings()

sweeper_module.device("dev1234")
mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234")
mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "dev1234")
sweeper_module._update_settings()
mock_shf_sweeper.return_value.configure.assert_called_once()
assert "sweep_config" in mock_shf_sweeper.return_value.configure.call_args[1]
Expand Down Expand Up @@ -135,7 +131,6 @@ def test_update_settings_broken(
mock_shf_sweeper,
caplog,
):

sweeper_module.device("dev1234")
sweeper_module._update_settings()
assert (
Expand All @@ -157,31 +152,27 @@ def test_update_settings_broken(


def test_run(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.run()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.run.assert_called_once()


def test_get_result(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.get_result()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.get_result.assert_called_once()


def test_plot(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.plot()
mock_shf_sweeper.return_value.configure.assert_called_once()
mock_shf_sweeper.return_value.plot.assert_called_once()


def test_get_offset_freq_vector(sweeper_module, mock_shf_sweeper):

sweeper_module.device("dev1234")
sweeper_module.get_offset_freq_vector()
mock_shf_sweeper.return_value.configure.assert_called_once()
Expand Down
Loading