diff --git a/lasy/profiles/transverse/__init__.py b/lasy/profiles/transverse/__init__.py index 1a6109ae..781f10f0 100644 --- a/lasy/profiles/transverse/__init__.py +++ b/lasy/profiles/transverse/__init__.py @@ -4,6 +4,11 @@ from .super_gaussian_profile import SuperGaussianTransverseProfile from .jinc_profile import JincTransverseProfile from .transverse_profile_from_data import TransverseProfileFromData +from .transverse_profile import ( + TransverseProfile, + SummedTransverseProfile, + ScaledTransverseProfile, +) __all__ = [ "GaussianTransverseProfile", @@ -12,4 +17,7 @@ "SuperGaussianTransverseProfile", "JincTransverseProfile", "TransverseProfileFromData", + "TransverseProfile", + "SummedTransverseProfile", + "ScaledTransverseProfile", ] diff --git a/lasy/profiles/transverse/transverse_profile.py b/lasy/profiles/transverse/transverse_profile.py index a30fdc93..4f4c3328 100644 --- a/lasy/profiles/transverse/transverse_profile.py +++ b/lasy/profiles/transverse/transverse_profile.py @@ -34,6 +34,18 @@ def _evaluate(self, x, y): # (This should be replaced by any class that inherits from this one.) return np.zeros(x.shape, dtype="complex128") + def __add__(self, other): + """Return the sum of two transverse profiles.""" + return SummedTransverseProfile(self, other) + + def __mul__(self, factor): + """Return the scaled transverse profile.""" + return ScaledTransverseProfile(self, factor) + + def __rmul__(self, factor): + """Return the scaled transverse profile.""" + return ScaledTransverseProfile(self, factor) + def evaluate(self, x, y): """ Return the transverse envelope modified by any spatial offsets. @@ -72,3 +84,62 @@ def set_offset(self, x_offset, y_offset): self.y_offset = y_offset return self + + +class SummedTransverseProfile(TransverseProfile): + """ + Base class for transverse profiles that are the sum of several other transverse profiles. + + Transverse Profile class that represents the sum of multiple transverse profiles. + + Parameters + ---------- + transverse_profiles: list of TransverseProfile objects + List of transverse profiles to be summed. + """ + + def __init__(self, *transverse_profiles): + """Initialize the summed profile.""" + TransverseProfile.__init__(self) + # Check that all transverse_profiles are TransverseProfile objects + assert all( + [isinstance(tp, TransverseProfile) for tp in transverse_profiles] + ), "All summands must be Profile objects." + self.transverse_profiles = transverse_profiles + + def evaluate(self, x, y): + """Return the envelope field of the summed profile.""" + # Sum the fields of each profile + return sum([tp.evaluate(x, y) for tp in self.transverse_profiles]) + + +class ScaledTransverseProfile(TransverseProfile): + """ + Base class for transverse profiles that are scaled by a factor. + + Transverse Profile class that represents scaled transverse profiles. + + Parameters + ---------- + transverse_profile: TrasnverseProfile object + Trasnverse profile to be scaled. + factor: int or float + Factor by which to scale the profile. + """ + + def __init__(self, transverse_profile, factor): + """Initialize the summed profile.""" + TransverseProfile.__init__(self) + # Check that the factor is a number + assert isinstance(factor, (int, float, complex)), "The factor must be a number." + # Check that the profile is a Profile object + assert isinstance( + transverse_profile, TransverseProfile + ), "The profile must be a TransverseProfile object." + self.transverse_profile = transverse_profile + self.factor = factor + + def evaluate(self, x, y): + """Return the envelope field of the scaled transverse profile.""" + # Sum the fields of each profile + return self.transverse_profile.evaluate(x, y) * self.factor diff --git a/tests/test_laser_profiles.py b/tests/test_laser_profiles.py index 2073743e..6d16bb1e 100644 --- a/tests/test_laser_profiles.py +++ b/tests/test_laser_profiles.py @@ -15,6 +15,9 @@ HermiteGaussianTransverseProfile, JincTransverseProfile, TransverseProfileFromData, + TransverseProfile, + SummedTransverseProfile, + ScaledTransverseProfile, ) from lasy.utils.exp_data_utils import find_center_of_mass @@ -32,6 +35,19 @@ def evaluate(self, x, y, t): return np.ones_like(x, dtype="complex128") * self.value +class MockTransverseProfile(TransverseProfile): + """ + A mock TransverseProfile class that always returns a constant value. + """ + + def __init__(self, value): + super().__init__() + self.value = value + + def evaluate(self, x, y): + return np.ones_like(x, dtype="complex128") * self.value + + @pytest.fixture(scope="function") def gaussian(): # Cases with Gaussian laser @@ -255,3 +271,46 @@ def test_scale_error_if_not_scalar(): profile_1 * profile_1 with pytest.raises(AssertionError): profile_1 * [1.0, 2.0] + + +def test_add_transverse_profiles(): + # Add the two profiles together + trans_profile_1 = MockTransverseProfile(1.0) + trans_profile_2 = MockTransverseProfile(2.0) + summed_trans_profile = trans_profile_1 + trans_profile_2 + # Check that the result is a SummedTransverseProfile object + assert isinstance(summed_trans_profile, SummedTransverseProfile) + # Check that the profiles are stored correctly + assert summed_trans_profile.transverse_profiles[0] == trans_profile_1 + assert summed_trans_profile.transverse_profiles[1] == trans_profile_2 + # Check that the evaluate method works + assert np.allclose(summed_trans_profile.evaluate(0, 0), 3.0) + + +def test_add_transverse_error_if_not_all_transverse_profiles(): + trans_profile_1 = MockTransverseProfile(1.0) + with pytest.raises(AssertionError): + trans_profile_1 + 1.0 + + +def test_scale_transverse_profiles(): + # Add the two profiles together + trans_profile_1 = MockTransverseProfile(1.0) + scaled_trans_profile = 2.0 * trans_profile_1 + scaled_trans_profile_right = trans_profile_1 * 2.0 + # Check that the result is a ScaledProfile object + assert isinstance(scaled_trans_profile, ScaledTransverseProfile) + assert isinstance(scaled_trans_profile_right, ScaledTransverseProfile) + # Check that the profiles are stored correctly + assert scaled_trans_profile.transverse_profile == trans_profile_1 + # Check that the evaluate method works + assert np.allclose(scaled_trans_profile.evaluate(0, 0), 2.0) + assert np.allclose(scaled_trans_profile.evaluate(0, 0), 2.0) + + +def test_scale_trans_error_if_not_scalar(): + trans_profile_1 = MockTransverseProfile(1.0) + with pytest.raises(AssertionError): + trans_profile_1 * trans_profile_1 + with pytest.raises(AssertionError): + trans_profile_1 * [1.0, 2.0]