diff --git a/pyproject.toml b/pyproject.toml index 3de08a88..9ce04cf4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ dev = [ "pydata-sphinx-theme>=0.12", "pytest", "pytest-cov", - "pytest-lazy-fixture", "sphinx-autobuild", "sphinx-copybutton", "sphinx-design", diff --git a/src/pytac/data/DIAD/simple_devices.csv b/src/pytac/data/DIAD/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIAD/simple_devices.csv +++ b/src/pytac/data/DIAD/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/DIADSP/simple_devices.csv b/src/pytac/data/DIADSP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIADSP/simple_devices.csv +++ b/src/pytac/data/DIADSP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/DIADTHz/simple_devices.csv b/src/pytac/data/DIADTHz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/DIADTHz/simple_devices.csv +++ b/src/pytac/data/DIADTHz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04/simple_devices.csv b/src/pytac/data/I04/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04/simple_devices.csv +++ b/src/pytac/data/I04/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04SP/simple_devices.csv b/src/pytac/data/I04SP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04SP/simple_devices.csv +++ b/src/pytac/data/I04SP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/I04THz/simple_devices.csv b/src/pytac/data/I04THz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/I04THz/simple_devices.csv +++ b/src/pytac/data/I04THz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/SRI0913_MOGA/simple_devices.csv b/src/pytac/data/SRI0913_MOGA/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/SRI0913_MOGA/simple_devices.csv +++ b/src/pytac/data/SRI0913_MOGA/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMX/simple_devices.csv b/src/pytac/data/VMX/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMX/simple_devices.csv +++ b/src/pytac/data/VMX/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMXSP/simple_devices.csv b/src/pytac/data/VMXSP/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMXSP/simple_devices.csv +++ b/src/pytac/data/VMXSP/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/src/pytac/data/VMXTHz/simple_devices.csv b/src/pytac/data/VMXTHz/simple_devices.csv index c84f75d9..82423fcf 100644 --- a/src/pytac/data/VMXTHz/simple_devices.csv +++ b/src/pytac/data/VMXTHz/simple_devices.csv @@ -1,2 +1,2 @@ el_id,field,value,readonly -0,energy,3e9,true +0,energy,3000000000,True diff --git a/utils/load_mml.m b/src/pytac/data/utils/load_mml.m similarity index 98% rename from utils/load_mml.m rename to src/pytac/data/utils/load_mml.m index 2394be94..0d5514a0 100644 --- a/utils/load_mml.m +++ b/src/pytac/data/utils/load_mml.m @@ -18,7 +18,7 @@ function load_mml(ringmode) dir = fileparts(mfilename('fullpath')); cd(dir); - datadir = fullfile(dir, '..', 'pytac', 'data', ringmode); + datadir = fullfile(dir, '..', ringmode); if ~exist(datadir, 'dir') fprintf('Data directory %s does not exist. Please create it.\n', datadir); fprintf('Script will exit.\n'); @@ -43,7 +43,7 @@ function load_mml(ringmode) ao = getao(); % Hard-coded beam energy value. - fprintf(f_simple_devices, '0,energy,3e9,true\n'); + fprintf(f_simple_devices, '0,energy,3e9,True\n'); % The individual BPM PVs are not stored in middlelayer. BPMS = get_bpm_pvs(ao); diff --git a/utils/load_unitconv.m b/src/pytac/data/utils/load_unitconv.m similarity index 95% rename from utils/load_unitconv.m rename to src/pytac/data/utils/load_unitconv.m index d40addcb..d201e0b0 100644 --- a/utils/load_unitconv.m +++ b/src/pytac/data/utils/load_unitconv.m @@ -1,9 +1,17 @@ function load_unitconv(ringmode, renamedIndexes) dir = fileparts(mfilename('fullpath')); cd(dir); -units_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'unitconv.csv'); -poly_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'uc_poly_data.csv'); -pchip_file = fullfile(dir, '..', 'pytac', 'data', ringmode, 'uc_pchip_data.csv'); +datadir = fullfile(dir, '..', ringmode); +if ~exist(datadir, 'dir') + fprintf('Data directory %s does not exist. Please create it.\n', datadir); + fprintf('Script will exit.\n'); + return; +end + +% Open the CSV files that store the Pytac data. +units_file = fullfile(datadir, 'unitconv.csv'); +poly_file = fullfile(datadir, 'uc_poly_data.csv'); +pchip_file = fullfile(datadir, 'uc_pchip_data.csv'); fprintf('Loading unit conversions...\n'); diff --git a/src/pytac/data_source.py b/src/pytac/data_source.py index 53949d73..ec75fa74 100644 --- a/src/pytac/data_source.py +++ b/src/pytac/data_source.py @@ -1,4 +1,5 @@ """Module containing pytac data source classes.""" + import pytac from pytac.exceptions import DataSourceException, FieldException diff --git a/src/pytac/device.py b/src/pytac/device.py index f385f39b..c139f0f8 100644 --- a/src/pytac/device.py +++ b/src/pytac/device.py @@ -5,6 +5,7 @@ DLS is a sextupole magnet that contains also horizontal and vertical corrector magnets and a skew quadrupole. """ + from typing import List, Union import pytac diff --git a/src/pytac/element.py b/src/pytac/element.py index a3632298..160ce8f1 100644 --- a/src/pytac/element.py +++ b/src/pytac/element.py @@ -1,4 +1,5 @@ """Module containing the element class.""" + import pytac from pytac.data_source import DataSource, DataSourceManager from pytac.exceptions import DataSourceException, FieldException @@ -89,13 +90,13 @@ def __str__(self): """ repn = "" return repn __repr__ = __str__ diff --git a/src/pytac/lattice.py b/src/pytac/lattice.py index 86e0c949..033ca142 100644 --- a/src/pytac/lattice.py +++ b/src/pytac/lattice.py @@ -1,6 +1,7 @@ """Representation of a lattice object which contains all the elements of the machine. """ + import logging from typing import List, Optional diff --git a/src/pytac/load_csv.py b/src/pytac/load_csv.py index 6c7eba09..2bb8b4c9 100644 --- a/src/pytac/load_csv.py +++ b/src/pytac/load_csv.py @@ -9,23 +9,23 @@ * uc_poly_data.csv * uc_pchip_data.csv """ + +import ast import collections import contextlib import copy import csv +import logging from pathlib import Path from typing import Dict, Iterator import pytac from pytac import data_source, element, utils from pytac.device import EpicsDevice, SimpleDevice -from pytac.exceptions import ControlSystemException +from pytac.exceptions import ControlSystemException, UnitsException from pytac.lattice import EpicsLattice, Lattice from pytac.units import NullUnitConv, PchipUnitConv, PolyUnitConv, UnitConv -# Create a default unit conversion object that returns the input unchanged. -DEFAULT_UC = NullUnitConv() - ELEMENTS_FILENAME = "elements.csv" EPICS_DEVICES_FILENAME = "epics_devices.csv" SIMPLE_DEVICES_FILENAME = "simple_devices.csv" @@ -86,6 +86,52 @@ def load_pchip_unitconv(filepath: Path) -> Dict[int, PchipUnitConv]: return unitconvs +def resolve_unitconv( + uc_params: Dict, unitconvs: Dict, polyconv_file: Path, pchipconv_file: Path +) -> UnitConv: + """Create a unit conversion object based on the dictionary of parameters passed. + + Args: + uc_params (Dict): A dictionary of parameters specifying the unit conversion + object's properties. + unitconvs (Dict): A dictionary of all loaded unit conversion objects. + polyconv_file (Path): The path to the .csv file from which all PolyUnitConv + objects are loaded. + pchipconv_file (Path): The path to the .csv file from which all PchipUnitConv + objects are loaded. + Returns: + UnitConv: The unit conversion object as specified by uc_params. + + Raises: + UnitsException: if the "uc_id" given in uc_params isn't in the unitconvs Dict. + """ + error_msg = ( + f"Unable to resolve {uc_params['uc_type']} unit conversion with ID " + f"{uc_params['uc_id']}, " + ) + if uc_params["uc_type"] == "null": + uc = NullUnitConv(uc_params["eng_units"], uc_params["phys_units"]) + else: + # Each element needs its own UnitConv object as it may have different limits. + try: + uc = copy.copy(unitconvs[int(uc_params["uc_id"])]) + except KeyError: + if uc_params["uc_type"] == "poly" and not polyconv_file.exists(): + raise UnitsException(error_msg + f"{polyconv_file} not found.") + elif uc_params["uc_type"] == "pchip" and not pchipconv_file.exists(): + raise UnitsException(error_msg + f"{pchipconv_file} not found.") + else: + raise UnitsException(error_msg + "unrecognised UnitConv type.") + uc.phys_units = uc_params["phys_units"] + uc.eng_units = uc_params["eng_units"] + lower, upper = [ + float(lim) if lim != "" else None + for lim in [uc_params["lower_lim"], uc_params["upper_lim"]] + ] + uc.set_conversion_limits(lower, upper) + return uc + + def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: """Load the unit conversion objects from a file. @@ -95,56 +141,38 @@ def load_unitconv(mode_dir: Path, lattice: Lattice) -> None: """ unitconvs: Dict[int, UnitConv] = {} # Assemble datasets from the polynomial file - unitconvs.update(load_poly_unitconv(mode_dir / POLY_FILENAME)) + polyconv_file = mode_dir / POLY_FILENAME + if polyconv_file.exists(): + unitconvs.update(load_poly_unitconv(polyconv_file)) + else: + logging.warning(f"{polyconv_file} not found, unable to load PolyUnitConvs.") # Assemble datasets from the pchip file - unitconvs.update(load_pchip_unitconv(mode_dir / PCHIP_FILENAME)) + pchipconv_file = mode_dir / PCHIP_FILENAME + if pchipconv_file.exists(): + unitconvs.update(load_pchip_unitconv(pchipconv_file)) + else: + logging.warning(f"{pchipconv_file} not found, unable to load PchipUnitConvs.") # Add the unitconv objects to the elements with csv_loader(mode_dir / UNITCONV_FILENAME) as csv_reader: for item in csv_reader: + uc = resolve_unitconv(item, unitconvs, polyconv_file, pchipconv_file) # Special case for element 0: the lattice itself. if int(item["el_id"]) == 0: - if item["uc_type"] != "null": - # Each element needs its own unitconv object as - # it may for example have different limit. - uc = copy.copy(unitconvs[int(item["uc_id"])]) - uc.phys_units = item["phys_units"] - uc.eng_units = item["eng_units"] - upper, lower = ( - float(lim) if lim != "" else None - for lim in [item["upper_lim"], item["lower_lim"]] - ) - uc.set_conversion_limits(lower, upper) - else: - uc = NullUnitConv(item["eng_units"], item["phys_units"]) lattice.set_unitconv(item["field"], uc) else: element = lattice[int(item["el_id"]) - 1] # For certain magnet types, we need an additional rigidity # conversion factor as well as the raw conversion. - if item["uc_type"] == "null": - uc = NullUnitConv(item["eng_units"], item["phys_units"]) - else: - # Each element needs its own unitconv object as - # it may for example have different limit. - uc = copy.copy(unitconvs[int(item["uc_id"])]) - if any( - element.is_in_family(f) - for f in ("HSTR", "VSTR", "Quadrupole", "Sextupole", "Bend") - ): - energy = lattice.get_value("energy", units=pytac.PHYS) - uc.set_post_eng_to_phys(utils.get_div_rigidity(energy)) - uc.set_pre_phys_to_eng(utils.get_mult_rigidity(energy)) - uc.phys_units = item["phys_units"] - uc.eng_units = item["eng_units"] - upper, lower = ( - float(lim) if lim != "" else None - for lim in [item["upper_lim"], item["lower_lim"]] - ) - uc.set_conversion_limits(lower, upper) + # TODO: This should probably be moved into the .csv files somewhere. + rigidity_families = {"hstr", "vstr", "quadrupole", "sextupole", "bend"} + if item["uc_type"] != "null" and element._families & rigidity_families: + energy = lattice.get_value("energy", units=pytac.PHYS) + uc.set_post_eng_to_phys(utils.get_div_rigidity(energy)) + uc.set_pre_phys_to_eng(utils.get_mult_rigidity(energy)) element.set_unitconv(item["field"], uc) -def load(mode, control_system=None, directory=None, symmetry=None): +def load(mode, control_system=None, directory=None, symmetry=None) -> EpicsLattice: """Load the elements of a lattice from a directory. Args: @@ -173,9 +201,8 @@ def load(mode, control_system=None, directory=None, symmetry=None): control_system = cothread_cs.CothreadControlSystem() except ImportError: raise ControlSystemException( - "Please install cothread to load a " - "lattice using the default control " - "system (found in cothread_cs.py)." + "Please install cothread to load a lattice using the default control system" + " (found in cothread_cs.py)." ) if directory is None: directory = Path(__file__).resolve().parent / "data" @@ -191,31 +218,44 @@ def load(mode, control_system=None, directory=None, symmetry=None): lat.add_element(e) with csv_loader(mode_dir / EPICS_DEVICES_FILENAME) as csv_reader: for item in csv_reader: - name = item["name"] index = int(item["el_id"]) get_pv = item["get_pv"] if item["get_pv"] else None set_pv = item["set_pv"] if item["set_pv"] else None - pve = True - d = EpicsDevice(name, control_system, pve, get_pv, set_pv) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - target.add_device(item["field"], d, DEFAULT_UC) + # Create with a default UnitConv that returns the input unchanged. + target.add_device( # type: ignore[attr-defined] + item["field"], + EpicsDevice(item["name"], control_system, rb_pv=get_pv, sp_pv=set_pv), + NullUnitConv(), + ) # Add basic devices to the lattice. positions = [] - for elem in lat: + for elem in lat: # type: ignore[attr-defined] positions.append(elem.s) - lat.add_device("s_position", SimpleDevice(positions, readonly=True), True) + lat.add_device( + "s_position", SimpleDevice(positions, readonly=True), NullUnitConv() + ) simple_devices_file = mode_dir / SIMPLE_DEVICES_FILENAME if simple_devices_file.exists(): with csv_loader(simple_devices_file) as csv_reader: for item in csv_reader: index = int(item["el_id"]) - field = item["field"] - value = float(item["value"]) - readonly = item["readonly"].lower() == "true" + try: + readonly = ast.literal_eval(item["readonly"]) + assert isinstance(readonly, bool) + except (ValueError, AssertionError): + raise ValueError( + f"Unable to evaluate {item['readonly']} as a boolean." + ) # Devices on index 0 are attached to the lattice not elements. target = lat if index == 0 else lat[index - 1] - target.add_device(field, SimpleDevice(value, readonly=readonly), True) + # Create with a default UnitConv that returns the input unchanged. + target.add_device( # type: ignore[attr-defined] + item["field"], + SimpleDevice(float(item["value"]), readonly=readonly), + NullUnitConv(), + ) with csv_loader(mode_dir / FAMILIES_FILENAME) as csv_reader: for item in csv_reader: lat[int(item["el_id"]) - 1].add_to_family(item["family"]) diff --git a/src/pytac/units.py b/src/pytac/units.py index fc8ac1fe..93f9e6f3 100644 --- a/src/pytac/units.py +++ b/src/pytac/units.py @@ -1,4 +1,5 @@ """Classes for use in unit conversion.""" + import numpy from scipy.interpolate import PchipInterpolator diff --git a/src/pytac/utils.py b/src/pytac/utils.py index bccab0ba..251f83ba 100644 --- a/src/pytac/utils.py +++ b/src/pytac/utils.py @@ -1,4 +1,5 @@ """Utility functions.""" + import math import scipy.constants diff --git a/tests/conftest.py b/tests/conftest.py index 1438b8eb..d5097151 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -178,3 +178,18 @@ def simple_epics_lattice(simple_epics_element, mock_cs, unit_uc): lat.add_device("x", x_device, unit_uc) lat.add_device("y", y_device, unit_uc) return lat + + +@pytest.fixture +def mode_dir(): + return CURRENT_DIR_PATH / "data/dummy" + + +@pytest.fixture +def polyconv_file(mode_dir): + return mode_dir / load_csv.POLY_FILENAME + + +@pytest.fixture +def pchipconv_file(mode_dir): + return mode_dir / load_csv.PCHIP_FILENAME diff --git a/tests/data/dummy/unitconv.csv b/tests/data/dummy/unitconv.csv new file mode 100644 index 00000000..7b91c79b --- /dev/null +++ b/tests/data/dummy/unitconv.csv @@ -0,0 +1,3 @@ +el_id,field,uc_type,uc_id,phys_units,eng_units,lower_lim,upper_lim +2,b1,null,1,m^-2,A,0,200 +4,b2,null,2,m^-3,A,-100,100 diff --git a/tests/test_cothread_cs.py b/tests/test_cothread_cs.py index 999a8a69..ca8b6ac7 100644 --- a/tests/test_cothread_cs.py +++ b/tests/test_cothread_cs.py @@ -4,6 +4,7 @@ See pytest_sessionstart() in conftest.py for more. """ + import pytest from constants import RB_PV, SP_PV from cothread.catools import ca_nothing, caget, caput diff --git a/tests/test_data_source.py b/tests/test_data_source.py index e73fb4df..45532e1c 100644 --- a/tests/test_data_source.py +++ b/tests/test_data_source.py @@ -1,82 +1,57 @@ import pytest from constants import DUMMY_VALUE_2 -from pytest_lazyfixture import lazy_fixture import pytac @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_device(simple_object, y_device): +def test_get_device(simple_object, y_device, request): + simple_object = request.getfixturevalue(simple_object) assert simple_object.get_device("y") == y_device @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_unitconv(simple_object, unit_uc): +def test_get_unitconv(simple_object, unit_uc, request): + simple_object = request.getfixturevalue(simple_object) assert simple_object.get_unitconv("x") == unit_uc @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_fields(simple_object): +def test_get_fields(simple_object, request): + simple_object = request.getfixturevalue(simple_object) fields = simple_object.get_fields()[pytac.LIVE] assert set(fields) == {"x", "y"} @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_set_value(simple_object): +def test_set_value(simple_object, request): + simple_object = request.getfixturevalue(simple_object) simple_object.set_value("x", DUMMY_VALUE_2, pytac.ENG, pytac.LIVE) simple_object.get_device("x").set_value.assert_called_with(DUMMY_VALUE_2, True) @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_get_value_sim(simple_object): +def test_get_value_sim(simple_object, request): + simple_object = request.getfixturevalue(simple_object) assert ( simple_object.get_value("x", pytac.RB, pytac.PHYS, pytac.SIM) == DUMMY_VALUE_2 ) @pytest.mark.parametrize( - "simple_object", - [ - lazy_fixture("simple_element"), - lazy_fixture("simple_lattice"), - lazy_fixture("simple_data_source_manager"), - ], + "simple_object", ["simple_element", "simple_lattice", "simple_data_source_manager"] ) -def test_unit_conversion(simple_object, double_uc): +def test_unit_conversion(simple_object, double_uc, request): + simple_object = request.getfixturevalue(simple_object) simple_object.set_value("y", DUMMY_VALUE_2, pytac.PHYS, pytac.LIVE) simple_object.get_device("y").set_value.assert_called_with(DUMMY_VALUE_2 / 2, True) diff --git a/tests/test_load.py b/tests/test_load.py index da7208e3..b7d91381 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -1,9 +1,10 @@ from unittest.mock import patch import pytest +from testfixtures import LogCapture import pytac -from pytac.load_csv import load +from pytac.load_csv import load, load_unitconv, resolve_unitconv @pytest.fixture @@ -70,3 +71,62 @@ def test_families_loaded(lattice): ["drift", "sext", "quad", "ds", "qf", "qs", "sd"] ) assert lattice.get_elements("quad")[0].families == set(["quad", "qf", "qs"]) + + +def test_load_unitconv_warns_if_pchip_or_poly_data_file_not_found( + lattice, mode_dir, polyconv_file, pchipconv_file +): + with LogCapture() as log: + load_unitconv(mode_dir, lattice) + log.check( + ( + "root", + "WARNING", + f"{polyconv_file} not found, unable to load PolyUnitConvs.", + ), + ( + "root", + "WARNING", + f"{pchipconv_file} not found, unable to load PchipUnitConvs.", + ), + ) + + +def test_resolve_unitconv_raises_UnitsException_if_pchip_or_poly_data_file_not_found( + polyconv_file, pchipconv_file +): + uc_params = { + "uc_type": "poly", + "uc_id": 1, + "phys_units": "m^-2", + "eng_units": "A", + "lower_lim": 0, + "upper_lim": 200, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) + uc_params = { + "uc_type": "pchip", + "uc_id": 2, + "phys_units": "m^-3", + "eng_units": "A", + "lower_lim": -100, + "upper_lim": 100, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) + + +def test_resolve_unitconv_raises_UnitsException_if_unrecognised_UnitConv_type( + polyconv_file, pchipconv_file +): + uc_params = { + "uc_type": "unrecognised", + "uc_id": 0, + "phys_units": "", + "eng_units": "", + "lower_lim": 0, + "upper_lim": 0, + } + with pytest.raises(pytac.exceptions.UnitsException): + resolve_unitconv(uc_params, {}, polyconv_file, pchipconv_file) diff --git a/tests/test_machine.py b/tests/test_machine.py index 933306e4..14b7c001 100644 --- a/tests/test_machine.py +++ b/tests/test_machine.py @@ -2,12 +2,12 @@ files in the data directory. These are more like integration tests, and allows us to check that the pytac setup is working correctly. """ + import re from unittest import mock import numpy import pytest -from pytest_lazyfixture import lazy_fixture import pytac @@ -27,22 +27,18 @@ def test_load_lattice_using_default_dir(): @pytest.mark.parametrize( "lattice, name, n_elements, length", - [ - (lazy_fixture("vmx_ring"), "VMX", 2142, 561.571), - (lazy_fixture("diad_ring"), "DIAD", 2144, 561.571), - ], + [("vmx_ring", "VMX", 2142, 561.571), ("diad_ring", "DIAD", 2144, 561.571)], ) -def test_load_lattice(lattice, name, n_elements, length): +def test_load_lattice(lattice, name, n_elements, length, request): + lattice = request.getfixturevalue(lattice) assert len(lattice) == n_elements assert lattice.name == name assert (lattice.get_length() - length) < EPS -@pytest.mark.parametrize( - "lattice, n_bpms", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 173)], -) -def test_get_pv_names(lattice, n_bpms): +@pytest.mark.parametrize("lattice, n_bpms", [("vmx_ring", 173), ("diad_ring", 173)]) +def test_get_pv_names(lattice, n_bpms, request): + lattice = request.getfixturevalue(lattice) bpm_x_pvs = lattice.get_element_pv_names("BPM", "x", handle="readback") assert len(bpm_x_pvs) == n_bpms for pv in bpm_x_pvs: @@ -55,11 +51,9 @@ def test_get_pv_names(lattice, n_bpms): assert re.match("SR.*HBPM.*SLOW:DISABLED", pv) -@pytest.mark.parametrize( - "lattice, n_bpms", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 173)], -) -def test_load_bpms(lattice, n_bpms): +@pytest.mark.parametrize("lattice, n_bpms", [("vmx_ring", 173), ("diad_ring", 173)]) +def test_load_bpms(lattice, n_bpms, request): + lattice = request.getfixturevalue(lattice) bpms = lattice.get_elements("BPM") bpm_fields = { "x", @@ -80,20 +74,16 @@ def test_load_bpms(lattice, n_bpms): assert bpms[-1].cell == 24 -@pytest.mark.parametrize( - "lattice, n_drifts", - [(lazy_fixture("vmx_ring"), 1308), (lazy_fixture("diad_ring"), 1311)], -) -def test_load_drift_elements(lattice, n_drifts): +@pytest.mark.parametrize("lattice, n_drifts", [("vmx_ring", 1308), ("diad_ring", 1311)]) +def test_load_drift_elements(lattice, n_drifts, request): + lattice = request.getfixturevalue(lattice) drifts = lattice.get_elements("DRIFT") assert len(drifts) == n_drifts -@pytest.mark.parametrize( - "lattice, n_quads", - [(lazy_fixture("vmx_ring"), 248), (lazy_fixture("diad_ring"), 248)], -) -def test_load_quadrupoles(lattice, n_quads): +@pytest.mark.parametrize("lattice, n_quads", [("vmx_ring", 248), ("diad_ring", 248)]) +def test_load_quadrupoles(lattice, n_quads, request): + lattice = request.getfixturevalue(lattice) quads = lattice.get_elements("Quadrupole") assert len(quads) == n_quads for quad in quads: @@ -104,10 +94,10 @@ def test_load_quadrupoles(lattice, n_quads): @pytest.mark.parametrize( - "lattice, n_q1b, n_q1d", - [(lazy_fixture("vmx_ring"), 34, 12), (lazy_fixture("diad_ring"), 34, 12)], + "lattice, n_q1b, n_q1d", [("vmx_ring", 34, 12), ("diad_ring", 34, 12)] ) -def test_load_quad_family(lattice, n_q1b, n_q1d): +def test_load_quad_family(lattice, n_q1b, n_q1d, request): + lattice = request.getfixturevalue(lattice) q1b = lattice.get_elements("Q1B") assert len(q1b) == n_q1b q1d = lattice.get_elements("Q1D") @@ -115,10 +105,10 @@ def test_load_quad_family(lattice, n_q1b, n_q1d): @pytest.mark.parametrize( - "lattice, n_correctors", - [(lazy_fixture("vmx_ring"), 173), (lazy_fixture("diad_ring"), 172)], + "lattice, n_correctors", [("vmx_ring", 173), ("diad_ring", 172)] ) -def test_load_correctors(lattice, n_correctors): +def test_load_correctors(lattice, n_correctors, request): + lattice = request.getfixturevalue(lattice) hcm = lattice.get_elements("HSTR") vcm = lattice.get_elements("VSTR") assert len(hcm) == n_correctors @@ -135,11 +125,9 @@ def test_load_correctors(lattice, n_correctors): ) -@pytest.mark.parametrize( - "lattice, n_squads", - [(lazy_fixture("vmx_ring"), 98), (lazy_fixture("diad_ring"), 98)], -) -def test_load_squads(lattice, n_squads): +@pytest.mark.parametrize("lattice, n_squads", [("vmx_ring", 98), ("diad_ring", 98)]) +def test_load_squads(lattice, n_squads, request): + lattice = request.getfixturevalue(lattice) squads = lattice.get_elements("SQUAD") assert len(squads) == n_squads for squad in squads: @@ -149,21 +137,19 @@ def test_load_squads(lattice, n_squads): assert re.match("SR.*SQ.*:SETI", device.sp_pv) -@pytest.mark.parametrize( - "lattice", (lazy_fixture("diad_ring"), lazy_fixture("vmx_ring")) -) -def test_cell(lattice): +@pytest.mark.parametrize("lattice", ["diad_ring", "vmx_ring"]) +def test_cell(lattice, request): + lattice = request.getfixturevalue(lattice) # there are squads in every cell sq = lattice.get_elements("SQUAD") assert sq[0].cell == 1 assert sq[-1].cell == 24 -@pytest.mark.parametrize( - "lattice", (lazy_fixture("diad_ring"), lazy_fixture("vmx_ring")) -) +@pytest.mark.parametrize("lattice", ["diad_ring", "vmx_ring"]) @pytest.mark.parametrize("field", ("x", "y")) -def test_bpm_unitconv(lattice, field): +def test_bpm_unitconv(lattice, field, request): + lattice = request.getfixturevalue(lattice) bpm = lattice.get_elements("BPM")[0] uc = bpm._data_source_manager._uc[field] @@ -171,6 +157,15 @@ def test_bpm_unitconv(lattice, field): assert uc.phys_to_eng(2) == 2000 +def test_hstr_unitconv(vmx_ring): + # From MML: hw2physics('HTRIM', 'Monitor', 2.5, [1]) + htrim = vmx_ring.get_elements("HTRIM")[0] + # This test depends on the lattice having an energy of 3000Mev. + uc = htrim._data_source_manager._uc["x_kick"] + numpy.testing.assert_allclose(uc.eng_to_phys(2.5), 0.0001925) + numpy.testing.assert_allclose(uc.phys_to_eng(0.0001925), 2.5) + + def test_quad_unitconv(vmx_ring): # From MML: hw2physics('Q1D', 'Monitor', 70, [1]) q1d = vmx_ring.get_elements("Q1D")