Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closes #215 , add overloading for addition and multiplication operators in TransverseProfile #216

Merged
Merged
8 changes: 8 additions & 0 deletions lasy/profiles/transverse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from .super_gaussian_profile import SuperGaussianTransverseProfile
from .jinc_profile import JincTransverseProfile
from .transverse_profile_from_data import TransverseProfileFromData
from .transverse_profile import (
TransverseProfile,
SummedTransverseProfile,
ScaledTransverseProfile,
)

__all__ = [
"GaussianTransverseProfile",
Expand All @@ -12,4 +17,7 @@
"SuperGaussianTransverseProfile",
"JincTransverseProfile",
"TransverseProfileFromData",
"TransverseProfile",
"SummedTransverseProfile",
"ScaledTransverseProfile",
]
71 changes: 71 additions & 0 deletions lasy/profiles/transverse/transverse_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def _evaluate(self, x, y):
# (This should be replaced by any class that inherits from this one.)
return np.zeros(x.shape, dtype="complex128")

def __add__(self, other):
"""Return the sum of two transverse profiles."""
return SummedTransverseProfile(self, other)

def __mul__(self, factor):
"""Return the scaled transverse profile."""
return ScaledTransverseProfile(self, factor)

def __rmul__(self, factor):
"""Return the scaled transverse profile."""
return ScaledTransverseProfile(self, factor)

def evaluate(self, x, y):
"""
Return the transverse envelope modified by any spatial offsets.
Expand Down Expand Up @@ -72,3 +84,62 @@ def set_offset(self, x_offset, y_offset):
self.y_offset = y_offset

return self


class SummedTransverseProfile(TransverseProfile):
"""
Base class for transverse profiles that are the sum of several other transverse profiles.

Transverse Profile class that represents the sum of multiple transverse profiles.

Parameters
----------
transverse_profiles: list of TransverseProfile objects
List of transverse profiles to be summed.
"""

def __init__(self, *transverse_profiles):
"""Initialize the summed profile."""
TransverseProfile.__init__(self)
# Check that all transverse_profiles are TransverseProfile objects
assert all(
[isinstance(tp, TransverseProfile) for tp in transverse_profiles]
), "All summands must be Profile objects."
self.transverse_profiles = transverse_profiles

def evaluate(self, x, y):
"""Return the envelope field of the summed profile."""
# Sum the fields of each profile
return sum([tp.evaluate(x, y) for tp in self.transverse_profiles])


class ScaledTransverseProfile(TransverseProfile):
"""
Base class for transverse profiles that are scaled by a factor.

Transverse Profile class that represents scaled transverse profiles.

Parameters
----------
transverse_profile: TrasnverseProfile object
Trasnverse profile to be scaled.
factor: int or float
Factor by which to scale the profile.
"""

def __init__(self, transverse_profile, factor):
"""Initialize the summed profile."""
TransverseProfile.__init__(self)
# Check that the factor is a number
assert isinstance(factor, (int, float, complex)), "The factor must be a number."
# Check that the profile is a Profile object
assert isinstance(
transverse_profile, TransverseProfile
), "The profile must be a TransverseProfile object."
self.transverse_profile = transverse_profile
self.factor = factor

def evaluate(self, x, y):
"""Return the envelope field of the scaled transverse profile."""
# Sum the fields of each profile
return self.transverse_profile.evaluate(x, y) * self.factor
58 changes: 58 additions & 0 deletions tests/test_laser_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
HermiteGaussianTransverseProfile,
JincTransverseProfile,
TransverseProfileFromData,
TransverseProfile,
SummedTransverseProfile,
ScaledTransverseProfile,
)
from lasy.utils.exp_data_utils import find_center_of_mass

Expand All @@ -32,6 +35,18 @@
return np.ones_like(x, dtype="complex128") * self.value


class MockTransverseProfile(TransverseProfile):
"""
A mock TransverseProfile class that always returns a constant value.
"""

def __init__(self, value):
self.value = value

def evaluate(self, x, y):
return np.ones_like(x, dtype="complex128") * self.value


@pytest.fixture(scope="function")
def gaussian():
# Cases with Gaussian laser
Expand Down Expand Up @@ -255,3 +270,46 @@
profile_1 * profile_1
with pytest.raises(AssertionError):
profile_1 * [1.0, 2.0]


def test_add_transverse_profiles():
# Add the two profiles together
trans_profile_1 = MockTransverseProfile(1.0)
trans_profile_2 = MockTransverseProfile(2.0)
summed_trans_profile = trans_profile_1 + trans_profile_2
# Check that the result is a SummedTransverseProfile object
assert isinstance(summed_trans_profile, SummedTransverseProfile)
# Check that the profiles are stored correctly
assert summed_trans_profile.transverse_profiles[0] == trans_profile_1
assert summed_trans_profile.transverse_profiles[1] == trans_profile_2
# Check that the evaluate method works
assert np.allclose(summed_trans_profile.evaluate(0, 0), 3.0)


def test_add_transverse_error_if_not_all_transverse_profiles():
trans_profile_1 = MockTransverseProfile(1.0)
with pytest.raises(AssertionError):
trans_profile_1 + 1.0


def test_scale_transverse_profiles():
# Add the two profiles together
trans_profile_1 = MockTransverseProfile(1.0)
scaled_trans_profile = 2.0 * trans_profile_1
scaled_trans_profile_right = trans_profile_1 * 2.0
# Check that the result is a ScaledProfile object
assert isinstance(scaled_trans_profile, ScaledTransverseProfile)
assert isinstance(scaled_trans_profile_right, ScaledTransverseProfile)
# Check that the profiles are stored correctly
assert scaled_trans_profile.transverse_profile == trans_profile_1
# Check that the evaluate method works
assert np.allclose(scaled_trans_profile.evaluate(0, 0), 2.0)
assert np.allclose(scaled_trans_profile.evaluate(0, 0), 2.0)


def test_scale_trans_error_if_not_scalar():
trans_profile_1 = MockTransverseProfile(1.0)
with pytest.raises(AssertionError):
trans_profile_1 * trans_profile_1
with pytest.raises(AssertionError):
trans_profile_1 * [1.0, 2.0]
Loading