diff --git a/CHANGELOG.md b/CHANGELOG.md index fe31d76d..116de53e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ sequences in toolkit (`#141`). * Add support for session wide transactions that bundle set command from all devices connected to the data server. (`#134`) +* Add `from_existing_connection()` to `zhinst.toolkit.Session` to help reusing the existing DataServer connection. * Bugfix: Nodes with nameless options don't raise an exception when their enum attribute is called (`#165`). * Bugfix: Values of enumerated nodes can now be pickled (`#129`). * Bugfix: `SHFScope` `run()` and `stop()` shows specified timeout value when `TimeoutError` is raised. diff --git a/src/zhinst/toolkit/driver/modules/shfqa_sweeper.py b/src/zhinst/toolkit/driver/modules/shfqa_sweeper.py index d1d8e17e..95044c8e 100644 --- a/src/zhinst/toolkit/driver/modules/shfqa_sweeper.py +++ b/src/zhinst/toolkit/driver/modules/shfqa_sweeper.py @@ -56,7 +56,7 @@ class SHFQASweeper(Node): session: Session to the Data Server. """ - def __init__(self, daq_server: ziDAQServer, session: "Session"): + def __init__(self, session: "Session"): self._config_classes = { SweepConfig: ("sweep", "sweep_config"), RfConfig: ("rf", "rf_config"), @@ -70,8 +70,12 @@ def __init__(self, daq_server: ziDAQServer, session: "Session"): "force_sw_trigger": "sw_trigger_mode", } super().__init__(self._create_nodetree(), tuple()) - self._daq_server = daq_server - self._raw_module = CoreSweeper(daq_server, "") + self._daq_server = ziDAQServer( + session.daq_server.host, + session.daq_server.port, + 6, + ) + self._raw_module = CoreSweeper(self._daq_server, "") self._session = session self.root.update_nodes( { diff --git a/src/zhinst/toolkit/session.py b/src/zhinst/toolkit/session.py index 692d4e4c..dbf554fb 100644 --- a/src/zhinst/toolkit/session.py +++ b/src/zhinst/toolkit/session.py @@ -221,13 +221,11 @@ class ModuleHandler: server_port: Port of the session """ - def __init__(self, session: "Session", server_host: str, server_port: int): + def __init__(self, session: "Session"): self._session = session - self._server_host = server_host - self._server_port = server_port def __repr__(self): - return f"LabOneModules({self._server_host}:{self._server_port})" + return f"LabOneModules({self._session.daq_server.host}:{self._session.daq_server.port})" def create_awg_module(self) -> tk_modules.BaseModule: """Create an instance of the AwgModule. @@ -457,14 +455,7 @@ def create_shfqa_sweeper(self) -> tk_modules.SHFQASweeper: Returns: Created object """ - return tk_modules.SHFQASweeper( - core.ziDAQServer( - self._server_host, - self._server_port, - 6, - ), - self._session, - ) + return tk_modules.SHFQASweeper(self._session) @lazy_property def awg(self) -> tk_modules.BaseModule: @@ -663,34 +654,33 @@ class Session(Node): def __init__( self, server_host: str, - server_port: int = None, + server_port: t.Optional[int] = None, *, - hf2: bool = None, - connection: core.ziDAQServer = None, + hf2: t.Optional[bool] = None, + connection: t.Optional[core.ziDAQServer] = None, ): self._is_hf2_server = bool(hf2) - self._server_host = server_host - self._server_port = server_port if server_port else 8004 if connection is not None: self._is_hf2_server = "HF2" in connection.getString("/zi/about/dataserver") if hf2 and not self._is_hf2_server: raise RuntimeError( - "hf2_server Flag was set but the passed " - "DAQServer instance is no HF2 data server." + "hf2 flag was set but the passed " + "DAQServer instance is not a HF2 data server." ) if hf2 is False and self._is_hf2_server: raise RuntimeError( - "hf2_server Flag was reset but the passed " + "hf2 flag was set but the passed " "DAQServer instance is a HF2 data server." ) self._daq_server = connection else: - if self._is_hf2_server and self._server_port == 8004: - self._server_port = 8005 + server_port = server_port if server_port else 8004 + if self._is_hf2_server and server_port == 8004: + server_port = 8005 try: self._daq_server = core.ziDAQServer( - self._server_host, - self._server_port, + server_host, + server_port, 1 if self._is_hf2_server else 6, ) except RuntimeError as error: @@ -699,14 +689,14 @@ def __init__( if hf2 is None: self._is_hf2_server = True self._daq_server = core.ziDAQServer( - self._server_host, - self._server_port, + server_host, + server_port, 1, ) elif not hf2: raise RuntimeError( - "hf2_server Flag was reset but the specified " - f"server at {self._server_host}:{self._server_port} is a " + "hf2 Flag was reset but the specified " + f"server at {server_host}:{server_port} is a " "HF2 data server." ) from error @@ -714,16 +704,16 @@ def __init__( "/zi/about/dataserver" ): raise RuntimeError( - "hf2_server Flag was set but the specified " - f"server at {self._server_host}:{self._server_port} is not a " + "hf2 Flag was set but the specified " + f"server at {server_host}:{server_port} is not a " "HF2 data server." ) self._devices = HF2Devices(self) if self._is_hf2_server else Devices(self) - self._modules = ModuleHandler(self, self._server_host, self._server_port) + self._modules = ModuleHandler(self) hf2_node_doc = Path(__file__).parent / "resources/nodedoc_hf2_data_server.json" nodetree = NodeTree( - self.daq_server, + self._daq_server, prefix_hide="zi", list_nodes=["/zi/*"], preloaded_json=json.loads(hf2_node_doc.open("r").read()) @@ -736,7 +726,24 @@ def __init__( def __repr__(self): return str( f"{'HF2' if self._is_hf2_server else ''}DataServerSession(" - f"{self._server_host}:{self._server_port})" + f"{self._daq_server.host}:{self._daq_server.port})" + ) + + @classmethod + def from_existing_connection(cls, connection: core.ziDAQServer) -> "Session": + """Initialize Session from an existing connection. + + Args: + connection: Existing connection. + + .. versionadded:: 0.3.6 + """ + is_hf2_server = "HF2" in connection.getString("/zi/about/dataserver") + return cls( + server_host=connection.host, + server_port=connection.port, + hf2=is_hf2_server, + connection=connection, ) def connect_device( @@ -969,9 +976,9 @@ def daq_server(self) -> core.ziDAQServer: @property def server_host(self) -> str: """Server host.""" - return self._server_host + return self._daq_server.host @property def server_port(self) -> int: """Server port.""" - return self._server_port + return self._daq_server.port diff --git a/tests/conftest.py b/tests/conftest.py index 208e6d09..5479dac2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, PropertyMock import pytest @@ -17,6 +17,9 @@ def data_dir(request): @pytest.fixture() def mock_connection(): with patch("zhinst.toolkit.session.core.ziDAQServer", autospec=True) as connection: + type(connection.return_value).port = PropertyMock(return_value=8004) + type(connection.return_value).host = PropertyMock(return_value="localhost") + type(connection.return_value).api_level = PropertyMock(return_value=6) yield connection @@ -36,6 +39,7 @@ def session(nodedoc_zi_json, mock_connection): @pytest.fixture() def hf2_session(data_dir, mock_connection): mock_connection.return_value.getString.return_value = "HF2DataServer" + type(mock_connection.return_value).port = PropertyMock(return_value=8005) yield Session("localhost", hf2=True) @@ -92,3 +96,11 @@ def nodedoc_dev1234_json(data_dir): json_path = data_dir / "nodedoc_dev1234.json" with json_path.open("r", encoding="UTF-8") as file: return file.read() + + +@pytest.fixture() +def mock_sweeper_daq(): + with patch( + "zhinst.toolkit.driver.modules.shfqa_sweeper.ziDAQServer", autospec=True + ) as connection: + yield connection diff --git a/tests/test_awg.py b/tests/test_awg.py index 79845dbe..d8057fa8 100644 --- a/tests/test_awg.py +++ b/tests/test_awg.py @@ -76,10 +76,7 @@ def test_load_sequencer_program(mock_connection, shfsg): mock_connection.return_value.set.call_args[0][0] == "/dev1234/sgchannels/0/awg/elf/data" ) - assert all( - mock_connection.return_value.set.call_args[0][1] - == np.frombuffer(elf, dtype="uint32") - ) + assert mock_connection.return_value.set.call_args[0][1] == elf assert info == info_original # Compiler error @@ -109,10 +106,7 @@ def test_load_sequencer_program_qc(mock_connection, shfqc): mock_connection.return_value.set.call_args[0][0] == "/dev1234/sgchannels/0/awg/elf/data" ) - assert all( - mock_connection.return_value.set.call_args[0][1] - == np.frombuffer(elf, dtype="uint32") - ) + assert mock_connection.return_value.set.call_args[0][1] == elf assert info == info_original diff --git a/tests/test_data_server_session.py b/tests/test_data_server_session.py index 79471b2d..caa25e04 100644 --- a/tests/test_data_server_session.py +++ b/tests/test_data_server_session.py @@ -396,7 +396,7 @@ def test_sweeper_module(data_dir, mock_connection, session): assert isinstance(sweeper_module.device, Node) -def test_shfqa_sweeper(session): +def test_shfqa_sweeper(session, mock_sweeper_daq): sweeper = session.modules.shfqa_sweeper assert sweeper == session.modules.shfqa_sweeper assert isinstance(sweeper, SHFQASweeper) diff --git a/tests/test_generator.py b/tests/test_generator.py index 0447aed2..2f67a549 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -78,10 +78,7 @@ def test_load_sequencer_program(shfqa, generator, mock_connection): mock_connection.return_value.set.call_args[0][0] == "/dev1234/qachannels/0/generator/elf/data" ) - assert all( - mock_connection.return_value.set.call_args[0][1] - == np.frombuffer(elf, dtype="uint32") - ) + assert mock_connection.return_value.set.call_args[0][1] == elf assert info == info_original def test_write_to_waveform_memory(shfqa, generator, mock_connection): diff --git a/tests/test_shfqa_sweeper.py b/tests/test_shfqa_sweeper.py index 2b4d243f..48a6605e 100644 --- a/tests/test_shfqa_sweeper.py +++ b/tests/test_shfqa_sweeper.py @@ -33,28 +33,30 @@ def mock_TriggerConfig(): @pytest.fixture() -def sweeper_module(mock_connection, mock_shf_sweeper, session): - yield SHFQASweeper(mock_connection.return_value, session) +def sweeper_module(session, mock_sweeper_daq, mock_shf_sweeper): + yield SHFQASweeper(session) def test_repr(sweeper_module): assert "SHFQASweeper(DataServerSession(localhost:8004))" in repr(sweeper_module) -def test_missing_node(mock_connection, mock_shf_sweeper, session): +def test_missing_node(mock_connection, mock_shf_sweeper, mock_sweeper_daq, session): with patch( "zhinst.toolkit.driver.modules.shfqa_sweeper.SweepConfig", make_dataclass("Y", fields=[("s", str, 0)], bases=(SweepConfig,)), ) as sweeper_config: - sweeper_module = SHFQASweeper(mock_connection.return_value, session) + sweeper_module = SHFQASweeper(session) assert sweeper_module.sweep.s() == 0 -def test_device(mock_connection, sweeper_module, session, mock_shf_sweeper): +def test_device( + mock_connection, sweeper_module, session, mock_shf_sweeper, mock_sweeper_daq +): assert sweeper_module.device() == "" sweeper_module.device("dev1234") - mock_shf_sweeper.assert_called_with(mock_connection.return_value, "dev1234") + mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234") assert sweeper_module.device() == "dev1234" connected_devices = "dev1234" @@ -71,15 +73,17 @@ def get_string_side_effect(arg): assert sweeper_module.device() == session.devices["dev1234"] -def test_update_settings(mock_connection, sweeper_module, mock_shf_sweeper): +def test_update_settings( + mock_connection, sweeper_module, mock_shf_sweeper, mock_sweeper_daq +): assert not sweeper_module.envelope.enable() - mock_shf_sweeper.assert_called_with(mock_connection.return_value, "") + mock_shf_sweeper.assert_called_with(sweeper_module._daq_server, "") # device needs to be set first with pytest.raises(RuntimeError) as e_info: sweeper_module._update_settings() sweeper_module.device("dev1234") - mock_shf_sweeper.assert_called_with(mock_connection.return_value, "dev1234") + mock_shf_sweeper.assert_called_with(mock_sweeper_daq(), "dev1234") sweeper_module._update_settings() mock_shf_sweeper.return_value.configure.assert_called_once() assert "sweep_config" in mock_shf_sweeper.return_value.configure.call_args[1]