From b2e7748ebf102de0204313bfde4eb15234aebe0d Mon Sep 17 00:00:00 2001 From: Simon Humpohl Date: Mon, 11 May 2020 11:38:19 +0200 Subject: [PATCH] Add central function for duration comparability test (cherry picked from commit 14f114a0218ae3efcc405ba2f63e9dddcfdb1fc7) # Conflicts: # qupulse/_program/waveforms.py --- qupulse/_program/waveforms.py | 8 +++- .../pulses/multi_channel_pulse_template.py | 23 +++++++--- qupulse/utils/numeric.py | 43 ++++++++++++++++++- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/qupulse/_program/waveforms.py b/qupulse/_program/waveforms.py index 701db1408..c628da328 100644 --- a/qupulse/_program/waveforms.py +++ b/qupulse/_program/waveforms.py @@ -17,6 +17,9 @@ from qupulse import ChannelID from qupulse._program.transformation import Transformation +from qupulse.utils import checked_int_cast, isclose +from qupulse.utils.types import TimeType, time_from_float +from qupulse.utils.numeric import are_durations_compatible from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar from qupulse.pulses.interpolation import InterpolationStrategy @@ -464,8 +467,9 @@ def get_sub_waveform_sort_key(waveform): waveform.defined_channels & self.__defined_channels) self.__defined_channels |= waveform.defined_channels - if not all(isclose(waveform.duration, self._sub_waveforms[0].duration) for waveform in self._sub_waveforms[1:]): - # meaningful error message: + durations = [subwaveform.duration for subwaveform in self._sub_waveforms] + if not are_durations_compatible(*durations): + # generate a useful error message durations = {} for waveform in self._sub_waveforms: diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 75784eb7a..f9e42c526 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -17,6 +17,7 @@ from qupulse.utils import isclose from qupulse.utils.sympy import almost_equal, Sympifyable from qupulse.utils.types import ChannelID, TimeType +from qupulse.utils.numeric import are_durations_compatible from qupulse._program.waveforms import MultiChannelWaveform, Waveform, TransformingWaveform from qupulse._program.transformation import ParallelConstantChannelTransformation, Transformation, chain_transformations from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate @@ -87,13 +88,21 @@ def __init__(self, category=DeprecationWarning) if not duration: - duration = self._subtemplates[0].duration - for subtemplate in self._subtemplates[1:]: - if almost_equal(duration.sympified_expression, subtemplate.duration.sympified_expression): - continue - else: - raise ValueError('Could not assert duration equality of {} and {}'.format(duration, - subtemplate.duration)) + durations = [subtemplate.duration for subtemplate in subtemplates] + are_compatible = are_durations_compatible(*durations) + + if are_compatible is False: + # durations definitely not compatible + raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), + repr(subtemplate.duration))) + elif are_compatible is None: + # cannot assert compatibility + raise ValueError('Could not assert duration equality of {} and {}'.format(repr(duration), + repr(subtemplate.duration))) + + else: + assert are_compatible is True + self._duration = None elif duration is True: self._duration = None diff --git a/qupulse/utils/numeric.py b/qupulse/utils/numeric.py index 53c640bbf..caf8431d8 100644 --- a/qupulse/utils/numeric.py +++ b/qupulse/utils/numeric.py @@ -1,5 +1,5 @@ -from typing import Tuple, Type -from numbers import Rational +from typing import Tuple, Type, Optional +from numbers import Rational, Real from math import gcd @@ -98,3 +98,42 @@ def approximate_rational(x: Rational, abs_err: Rational, fraction_type: Type[Rat def approximate_double(x: float, abs_err: float, fraction_type: Type[Rational]) -> Rational: """Return the fraction with the smallest denominator in (x - abs_err, x + abs_err).""" return approximate_rational(fraction_type(x), fraction_type(abs_err), fraction_type=fraction_type) + + +def are_durations_compatible(first_duration: Real, *other_durations: Real, + max_abs_spread=1e-10, max_rel_spread=1e-10) -> Optional[bool]: + """Durations and maximum allowed spreads must be positive. + + For the durations to be considered compatible, the difference between them must be smaller than at least one of + the allowed spreads. + + Args: + first_duration: Singled out duration for performance reasons. Not handled differently by the algorithm. + *other_durations: Other durations to compare for compatibility + max_abs_spread: Maximum difference for being considered "compatible", regardless of the magnitude of the input + max_rel_spread: maximum difference for being considered "compatible", relative to the magnitude of the + maximum input duration + + Returns: + True or False if decidable else None + """ + min_duration = max_duration = first_duration + for duration in other_durations: + min_duration = min(min_duration, duration) + max_duration = max(max_duration, duration) + assert 0 < max_duration, "At least one duration must be positive" + # spread = max_duration - min_duration + # allowed_spread = max(max_rel_spread * max_duration, max_abs_spread) + are_compatible = max_duration - min_duration < max(max_rel_spread * max_duration, max_abs_spread) + if are_compatible in (False, True): + return are_compatible + + # durations are sympy expressions with clear ordering + elif are_compatible.is_Boolean: + return bool(are_compatible) + + else: + # Not decidable + return None + +