diff --git a/lighthouse/SSP/basic_ssp.py b/lighthouse/SSP/basic_ssp.py index 7422d3c..482f7b1 100644 --- a/lighthouse/SSP/basic_ssp.py +++ b/lighthouse/SSP/basic_ssp.py @@ -3,6 +3,8 @@ from time import process_time as time import matplotlib.pyplot as plt +import numpy as np +from scipy.integrate import quad from .. import utils @@ -10,54 +12,139 @@ class Basic_SSP(): def __init__( self, - isochrone: "Isochrone", + isochrone_model: "Isochrone", imf: "Initial_Mass_Function", sas: "Stellar_Atmosphere_Spectrum", ): - self.isochrone = isochrone + self.isochrone_grid = isochrone_model self.imf = imf self.sas = sas - def forward(self, metalicity, Tage, alpha) -> torch.Tensor: - isochrone = self.isochrone.get_isochrone(metalicity, Tage) + self.isochrone = None + self.spectrum = None + self._imf_weights = None + self._imf_weights_v2 = None + + + @property + def imf_weights_v2(self) -> torch.Tensor: + + # like alf + + if self._imf_weights_v2 is None: + + weights = self.imf.get_imf(self.isochrone["initial_mass"], mass_weighted=False) + weights = weights/self.imf.t0_normalization + + self._imf_weights_v2 = weights + + return self._imf_weights_v2 + + + + @property + def imf_weights(self) -> torch.Tensor: + + # like fsps + + + if self._imf_weights is None: + + initial_stellar_masses = self.isochrone["initial_mass"] + + lower_limit = self.imf.lower_limit + upper_limit = self.imf.upper_limit + + weights = torch.zeros(initial_stellar_masses.shape) + for i, mass in enumerate(initial_stellar_masses): + if initial_stellar_masses[i] < lower_limit or initial_stellar_masses[i] > upper_limit: + print("Bounds of isochrone exceed limits of full IMF") + if i == 0: + m1 = lower_limit # ala fsps aka + else: + m1 = initial_stellar_masses[i] - 0.5*(initial_stellar_masses[i] - initial_stellar_masses[i-1]) + if i == len(initial_stellar_masses) - 1: + m2 = initial_stellar_masses[i] + else: + m2 = initial_stellar_masses[i] + 0.5*(initial_stellar_masses[i+1] - initial_stellar_masses[i]) + + if m2 < m1: + print("IMF_WEIGHT WARNING: non-monotonic mass!", m1, m2, m2-m1) + continue + + if m2 == m1: + print("m2 == m1") + continue + + weights[i], error = quad(self.imf.get_imf, + m1, m2, + args=(False,) ) # i.e., not mass-weighting + + self._imf_weights = weights/self.imf.t0_normalization + + return self._imf_weights + + + + def spectral_synthesis(self, metalicity, Tage) -> torch.Tensor: + + isochrone = self.isochrone + imf_weights = self.imf_weights_v2 + + + ## https://waps.cfa.harvard.edu/MIST/README_tables.pdf + ## FSPS phase type defined as follows: + ## -1=PMS, 0=MS, 2=RGB, 3=CHeB, 4=EAGB, + ## 5=TPAGB, 6=postAGB, 9=WR # Main Sequence isochrone integration - CHOOSE = isochrone["phase"] <= 2 - spectra = torch.stack(tuple( - self.sas.get_spectrum( - tf, - lg, - metalicity, - ) for lg, tf in zip(isochrone["log_g"][CHOOSE], isochrone["Teff"][CHOOSE]) + MS = torch.logical_and(isochrone["phase"] >= 0, isochrone["phase"] <= 2) + # Horizontal Branch isochrone integration + HB = torch.logical_and(isochrone["phase"] > 2, isochrone["phase"] <= 5) + + + ms_spectra = torch.stack(tuple( + self.sas.get_spectrum(tf, lg, metalicity, + ) for lg, tf in zip(isochrone["log_g"][MS], isochrone["Teff"][MS]) )).T - spectrum = torch.zeros(spectra.shape[0]) - spectrum += torch.vmap(partial(torch.trapz, x = isochrone["initial_mass"][CHOOSE]))( - (spectra * self.imf.get_weight( - isochrone["initial_mass"][CHOOSE], - alpha, - ) * 10**isochrone["log_l"][CHOOSE]), - ) + ms_spectra *= 10**isochrone["log_l"][MS] - # Horizontal Branch isochrone integration - CHOOSE = torch.logical_and(isochrone["phase"] > 2, isochrone["phase"] <= 5) - spectra = torch.stack(tuple( - self.sas.get_spectrum( - tf, - lg, - metalicity, - ) for lg, tf in zip(isochrone["log_g"][CHOOSE], isochrone["Teff"][CHOOSE]) + hb_spectra = torch.stack(tuple( + self.sas.get_spectrum(tf, lg, metalicity, + ) for lg, tf in zip(isochrone["log_g"][HB], isochrone["Teff"][HB]) )).T - spectrum += torch.vmap(partial(torch.trapz, x = isochrone["initial_mass"][CHOOSE]))( - (spectra * self.imf.get_weight( - isochrone["initial_mass"][CHOOSE], - alpha, - ) * 10**isochrone["log_l"][CHOOSE]), - ) + hb_spectra *= 10**isochrone["log_l"][HB] + + spectrum = torch.zeros(ms_spectra.shape[0]) + + spectrum += torch.vmap(partial(torch.trapz, x = isochrone["initial_mass"][MS])) ( + imf_weights[MS]*ms_spectra + ) + spectrum += torch.vmap(partial(torch.trapz, x = isochrone["initial_mass"][HB])) ( + imf_weights[HB]*hb_spectra + ) + + # SSP in L_sun Hz^-1, CvD models in L_sun micron^-1, convert - spectrum *= utils.light_speed/self.sas.wavelength**2 + spectrum *= utils.light_speed_micron/self.sas.wavelength**2 #for detailed comparisons to alf + #spectrum *= utils.light_speed_cgs/self.sas.wavelength**2 # the correct thing to do;consistent normalization with fsps + + return spectrum + + + def forward(self, metalicity, Tage, synthesize_spectrum=True) -> torch.Tensor: + + self.isochrone = self.isochrone_grid.get_isochrone(metalicity, Tage) + + if synthesize_spectrum: + spectrum = self.spectral_synthesis(metalicity, Tage) + else: + print("HAVE NOT SYNTHEISIZED A SPECTRUM") + spectrum = None + return spectrum diff --git a/lighthouse/initial_mass_function/__init__.py b/lighthouse/initial_mass_function/__init__.py index 4bae004..d59b45e 100644 --- a/lighthouse/initial_mass_function/__init__.py +++ b/lighthouse/initial_mass_function/__init__.py @@ -1,3 +1,5 @@ from .kroupa import * +from .salpeter import * +from .two_slope_powerlaw import * from .initial_mass_function import * diff --git a/lighthouse/initial_mass_function/initial_mass_function.py b/lighthouse/initial_mass_function/initial_mass_function.py index 16dc14d..f8512ee 100644 --- a/lighthouse/initial_mass_function/initial_mass_function.py +++ b/lighthouse/initial_mass_function/initial_mass_function.py @@ -1,14 +1,38 @@ from abc import ABC, abstractmethod -from torch import Tensor +# from torch import Tensor + +import torch + +from scipy.integrate import quad +import numpy as np + __all__ = ("Initial_Mass_Function", ) class Initial_Mass_Function(ABC): - @abstractmethod - def get_weight(self, mass) -> Tensor: - ... + def __init__(self): + + self.lower_limit = torch.tensor(0.08, dtype=torch.float64) + self.upper_limit = torch.tensor(100., dtype=torch.float64) + + + self._t0_normalization = None + @abstractmethod - def to(self, dtype=None, device=None): - ... + def get_imf(self, mass, mass_weighted=False) -> torch.Tensor: + pass + + + @property + def t0_normalization(self): ## TODO: change this name + + if self._t0_normalization is None: + self._t0_normalization = quad(self.get_imf, + self.lower_limit, + self.upper_limit, + args=(True,) )[0] + + return self._t0_normalization + diff --git a/lighthouse/initial_mass_function/kroupa.py b/lighthouse/initial_mass_function/kroupa.py index 5ada94c..bc4b980 100644 --- a/lighthouse/initial_mass_function/kroupa.py +++ b/lighthouse/initial_mass_function/kroupa.py @@ -1,14 +1,18 @@ import torch + from .initial_mass_function import Initial_Mass_Function __all__ = ("Kroupa", ) class Kroupa(Initial_Mass_Function): - def get_weight(self, mass, alpha) -> torch.Tensor: + def get_imf(self, mass, mass_weighted=False) -> torch.Tensor: + + alpha = torch.tensor([1.3, 2.3, 2.3]) + mass = torch.tensor(mass) - weight = torch.where( + imf = torch.where( mass < 0.5, mass**(-alpha[0]), # mass < 0.5 torch.where( @@ -18,8 +22,12 @@ def get_weight(self, mass, alpha) -> torch.Tensor: ) ) - return weight + if mass_weighted: + return mass*imf + else: + return imf + def to(self, dtype=None, device=None): pass @@ -30,8 +38,8 @@ def to(self, dtype=None, device=None): K = Kroupa() M = torch.linspace(0.1, 100, 1000) - W = K.get_weight(M, torch.tensor([1.3, 2.3, 2.7])) + IMF = K.get_imf(M) - plt.plot(torch.log10(M), torch.log10(W)) + plt.plot(torch.log10(M), torch.log10(IMF)) plt.show() diff --git a/lighthouse/initial_mass_function/salpeter.py b/lighthouse/initial_mass_function/salpeter.py new file mode 100644 index 0000000..9f84151 --- /dev/null +++ b/lighthouse/initial_mass_function/salpeter.py @@ -0,0 +1,33 @@ +import torch + + + + +from .initial_mass_function import Initial_Mass_Function + +__all__ = ("Salpeter", ) + +class Salpeter(Initial_Mass_Function): + + def get_imf(self, mass, mass_weighted=False) -> torch.Tensor: + + salpeter_index = torch.tensor(2.35, dtype = torch.float64) + imf = mass**(-salpeter_index) + + if mass_weighted: + imf = imf*mass + + return imf + + +if __name__ == "__main__": + + import matplotlib.pyplot as plt + K = Salpeter() + + M = torch.linspace(0.08, 100, 1000) + IMF = K.get_imf(M) + + plt.plot(torch.log10(M), torch.log10(IMF)) + plt.show() + diff --git a/lighthouse/initial_mass_function/two_slope_powerlaw.py b/lighthouse/initial_mass_function/two_slope_powerlaw.py new file mode 100644 index 0000000..5b55342 --- /dev/null +++ b/lighthouse/initial_mass_function/two_slope_powerlaw.py @@ -0,0 +1,38 @@ +import torch + + +from .initial_mass_function import Initial_Mass_Function + +__all__ = ("Two_Slope_Powerlaw", ) + +class Two_Slope_Powerlaw(Initial_Mass_Function): + + """ + Fixed lower mass cutoff + Fixed break + """ + + def get_imf(self, mass, mass_weighted=False, alpha1=1.0, alpah2=1.0) -> torch.Tensor: + + alpha = torch.tensor([alpha1, alpha2, -2.30]) + mass = torch.tensor(mass) + + imf = torch.where( + mass < 0.5, + mass**(-alpha[0]), # mass < 0.5 + torch.where( + mass < 1.0, + 0.5**(-alpha[0] + alpha[1]) * mass**(-alpha[1]), # 0.5 <= mass < 1.0 + 0.5**(-alpha[0] + alpha[1]) * mass**(-alpha[2]), # mass >= 1.0 + ) + ) + + if mass_weighted: + return mass*imf + else: + return imf + + +if __name__ == "__main__": + + pass \ No newline at end of file diff --git a/lighthouse/isochrone/MIST_Isochrone.py b/lighthouse/isochrone/MIST_Isochrone.py index 11bf384..4c3ab6c 100644 --- a/lighthouse/isochrone/MIST_Isochrone.py +++ b/lighthouse/isochrone/MIST_Isochrone.py @@ -14,23 +14,45 @@ class MIST(Isochrone): - def __init__(self, iso_file = 'MIST_v1.2_vvcrit0.0_basic_isos.hdf5'): - directory_path = Path(__file__) - data_path = Path(directory_path.parent, 'data/MIST/') + def __init__(self, iso_file = 'MIST_v1.2_vvcrit0.4_basic_isos.hdf5'): + data_path = Path(os.environ['LightHouse_HOME'], 'lighthouse/data/MIST/') + with h5py.File(os.path.join(data_path, iso_file), 'r') as f: self.isochrone_grid = torch.tensor(f["isochrone_grid"][:], dtype = torch.float64) self.metallicities = torch.tensor(f["metallicities"][:], dtype = torch.float64) - self.ages = torch.tensor(f["ages"][:], dtype = torch.float64) + self.log10ages = torch.tensor(f["ages"][:], dtype = torch.float64) self.param_order = list(p.decode("UTF-8") for p in f["parameters"][:]) def get_isochrone(self, metallicity, age, *args, low_m_limit = 0.08, high_m_limit = 100) -> dict: - metallicity_index = torch.clamp(torch.sum(self.metallicities < metallicity) - 1, 0) - age_index = torch.clamp(torch.sum(self.ages < age) - 1, 0) # TODO: figure out a better way later + age_step = torch.tensor(0.05, dtype=torch.float64) + + metallicity = torch.tensor(metallicity, dtype = torch.float64) + age = torch.tensor(age, dtype = torch.float64) + + + + + + metallicity_index = torch.isclose(self.metallicities, metallicity, 1e-2).nonzero(as_tuple=False).squeeze() + + print(torch.log10(age)) + tmp = torch.ceil(torch.log10(age) / age_step) + age = tmp*age_step + + print(age) + print(self.log10ages) + print(torch.isclose(self.log10ages, age, 1e-3) ) + age_index = torch.isclose(self.log10ages, age, 1e-3).nonzero(as_tuple=False).squeeze() + + + isochrone = self.isochrone_grid[metallicity_index, age_index].clone() #TODO: do we need to be worried about copy vs deep copy kind of situation here? + bad_phase_mask = ( (isochrone[3] != 6) & (isochrone[2] >= low_m_limit) & (isochrone[2] <= high_m_limit) ) + isochrone = isochrone[:, bad_phase_mask] - isochrone = self.isochrone_grid[metallicity_index, age_index] - isochrone = isochrone[:,isochrone[3] > -999] + bad_values = (isochrone[3] > -999) + isochrone = isochrone[:, bad_values] return dict((p, isochrone[i]) for i, p in enumerate(self.param_order)) diff --git a/lighthouse/isochrone/get_isochrones.py b/lighthouse/isochrone/get_isochrones.py index 5a12dcb..18179b7 100644 --- a/lighthouse/isochrone/get_isochrones.py +++ b/lighthouse/isochrone/get_isochrones.py @@ -17,19 +17,16 @@ __all__ = ("get_mist_isochrones", ) -def get_mist_isochrones(saveto = None, iso_version = 'MIST_v1.2_vvcrit0.0_basic_isos.txz', url='https://waps.cfa.harvard.edu/MIST/data/tarballs_v1.2/{}'): +def get_mist_isochrones(saveto = None, iso_version = 'MIST_v1.2_vvcrit0.4_basic_isos.txz', url='https://waps.cfa.harvard.edu/MIST/data/tarballs_v1.2/{}'): # Collect isochrone data from the internet ###################################################################### import requests # Path to where MIST data will live - if saveto is None: - directory_path = Path(__file__) #Path().absolute() - data_path = Path(directory_path.parent, 'data/MIST/') - else: - data_path = saveto - + data_path = Path(os.environ['LightHouse_HOME'], 'lighthouse/data/MIST/') + + # Ensure the directoty exists to place the files try: os.mkdir(data_path.parent) @@ -39,7 +36,7 @@ def get_mist_isochrones(saveto = None, iso_version = 'MIST_v1.2_vvcrit0.0_basic_ # Specific file path for the requested version of MIST file_path = os.path.join(data_path, iso_version) - + # Skip download if files already exit if not os.path.exists(os.path.splitext(file_path)[0]): # Pull the isochrone files from the internet @@ -50,12 +47,12 @@ def get_mist_isochrones(saveto = None, iso_version = 'MIST_v1.2_vvcrit0.0_basic_ print("Writing MIST") with open(file_path, 'wb') as f: f.write(r.content) - + # Extract the tar file into the individual .iso files print("Extracting MIST") with tarfile.open(file_path) as T: T.extractall(path = data_path) - + # Remove the old tar file, no longer needed print("Deleting tar file") os.remove(file_path) @@ -69,69 +66,60 @@ def get_mist_isochrones(saveto = None, iso_version = 'MIST_v1.2_vvcrit0.0_basic_ # Run through all the files and collect information about metalicities and ages metallicities = [] longest_track = 0 - for isochrone_file in isochrone_files: + for isochrone_file in isochrone_files: isochrone = ISO(isochrone_file, verbose=False) - + metallicities.append(isochrone.abun['[Fe/H]']) - ages = [round(x, 2) for x in isochrone.ages] for age in isochrone.ages: i = isochrone.age_index(age) - j = np.where((isochrone.isos[i]['phase'] != 6) & - (isochrone.isos[i]['initial_mass'] >= 0.08) & - (isochrone.isos[i]['initial_mass'] <= 100.) - ) - tracklength = len(isochrone.isos[i]['log_g'][j]) + tracklength = len(isochrone.isos[i]['log_g']) if longest_track < tracklength: longest_track = tracklength - - metallicities = np.array(list(sorted(metallicities))) - ages = np.array(ages) - isochrone_grid = np.zeros((len(isochrone_files), len(ages), 5, longest_track)) - 999 - metallicities_order = np.argsort(metallicities) + + isochrone_grid = np.zeros((len(isochrone_files), isochrone.num_ages, 6, longest_track)) - 999 + # Go through the isochrone files and collect all the data - for n, isochrone_file in enumerate(isochrone_files): + for n, isochrone_file in enumerate(isochrone_files): + isochrone = ISO(isochrone_file, verbose=False) - for x, age in enumerate(isochrone.ages): + i = isochrone.age_index(age) - - j = np.where((isochrone.isos[i]['phase'] != 6) & - (isochrone.isos[i]['initial_mass'] >= 0.08) & - (isochrone.isos[i]['initial_mass'] <= 100.) - ) - - track_length = len(isochrone.isos[i]['log_g'][j]) - isochrone_grid[metallicities_order[n], i, 0][:track_length] = np.array(isochrone.isos[i]['log_g'][j]) - isochrone_grid[metallicities_order[n], i, 1][:track_length] = np.array(10**isochrone.isos[i]['log_Teff'][j]) - isochrone_grid[metallicities_order[n], i, 2][:track_length] = np.array(isochrone.isos[i]['initial_mass'][j]) - isochrone_grid[metallicities_order[n], i, 3][:track_length] = np.array(isochrone.isos[i]['phase'][j]) - isochrone_grid[metallicities_order[n], i, 4][:track_length] = np.array(isochrone.isos[i]['log_L'][j]) - - # Write the isochrones to a database + + track_length = len(isochrone.isos[i]['log_g']) + + isochrone_grid[n, x, 0][:track_length] = np.array(isochrone.isos[i]['log_g']) + isochrone_grid[n, x, 1][:track_length] = np.array(10**isochrone.isos[i]['log_Teff']) + isochrone_grid[n, x, 2][:track_length] = np.array(isochrone.isos[i]['initial_mass']) + isochrone_grid[n, x, 3][:track_length] = np.array(isochrone.isos[i]['phase']) + isochrone_grid[n, x, 4][:track_length] = np.array(isochrone.isos[i]['log_L']) + isochrone_grid[n, x, 5][:track_length] = np.array(isochrone.isos[i]['star_mass']) + + ###################################################################### print("Writing MIST to hdf5 database") with h5py.File(os.path.splitext(file_path)[0] + ".hdf5", 'w') as f: # Isochrne grid data_iso_grid = f.create_dataset("isochrone_grid", data = isochrone_grid) - data_iso_grid.attrs["description"] = "This is a 4D tensor of isochrones orgnaized by: metalicity, age, parameter, track length. Parameter has length 5 and goes by: log_g, Teff, initial_mass, phase, log_L" + data_iso_grid.attrs["description"] = "This is a 4D tensor of isochrones orgnaized by: metalicity, age, parameter, track length. Parameter has length 6 and goes by: log_g, Teff, initial_mass, phase, log_L, star_mass" # Meta data for each axis data_metallicities = f.create_dataset("metallicities", data = metallicities) data_metallicities.attrs["description"] = "For the metallicity axis of the isochrone_grid, this is the associated matalicities" - data_ages = f.create_dataset("ages", data = ages) + data_ages = f.create_dataset("ages", data = isochrone.ages) data_ages.attrs["description"] = "For the ages axis of the isochrone_grid, this is the associated ages" dt = h5py.special_dtype(vlen=str) - params = ["log_g", "Teff", "initial_mass", "phase", "log_l"] + params = ["log_g", "Teff", "initial_mass", "phase", "log_l", "current_mass"] data_params = f.create_dataset("parameters", (len(params), ), dtype = dt) data_params[:] = params data_params.attrs["description"] = "For the parameters axis of the isochrone_grid, this lists the relevant parameters in the correct order" - # Cleanup - shutil.rmtree(os.path.splitext(file_path)[0]) + ## Cleanup + #shutil.rmtree(os.path.splitext(file_path)[0]) # TODO: this should be optional, with the default keeping the .iso files if __name__=='__main__': diff --git a/lighthouse/isochrone/isochrone.py b/lighthouse/isochrone/isochrone.py index b434fb4..c9ff44b 100644 --- a/lighthouse/isochrone/isochrone.py +++ b/lighthouse/isochrone/isochrone.py @@ -11,3 +11,17 @@ def get_isochrone(self, metalicity, Tage, *args, low_m_limit = 0.08, high_m_limi @abstractmethod def to(self, dtype=None, device=None): ... + + + + + + + + def write_isochrone(self): + pass + + + def plot_isochrone(self, ax): + + ax.plot(self.isochrone["Teff"], self.isochrone["log_g"], label='Light-House', color='r', lw=3) diff --git a/lighthouse/stellar_atmosphere_spectrum/get_stellar_templates.py b/lighthouse/stellar_atmosphere_spectrum/get_stellar_templates.py index 3a844f4..31d0274 100644 --- a/lighthouse/stellar_atmosphere_spectrum/get_stellar_templates.py +++ b/lighthouse/stellar_atmosphere_spectrum/get_stellar_templates.py @@ -1,11 +1,11 @@ -import os +import os from pathlib import Path __all__ = ("get_polynomial_coefficients_villaume2017a", ) def get_polynomial_coefficients_villaume2017a(): """ - + """ import requests @@ -15,8 +15,8 @@ def get_polynomial_coefficients_villaume2017a(): stellar_types = ['Cool_Dwarfs', 'Cool_Giants', 'Warm_Dwarfs', 'Warm_Giants', 'Hot_Stars'] - directory_path = Path(__file__) #Path().absolute() - data_path = Path(directory_path.parent, 'data/Villaume2017a/') + data_path = Path(os.environ['LightHouse_HOME'], 'lighthouse/data/Villaume2017a/') + try: os.mkdir(data_path.parent) @@ -33,9 +33,9 @@ def get_polynomial_coefficients_villaume2017a(): continue - r = requests.get(base_name.format(regime)) + r = requests.get(base_name.format(regime)) with open(file_path, 'w') as f: - f.write(r.text) + f.write(r.text) C = {} C['Cool_Dwarfs'] = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [0,2,0], [2,0,0], [0,0,2], [1,1,0], [1,0,1], [0,1,1], [0,3,0], [3,0,0], [0,0,3], [2,1,0], [1,2,0], [2,0,1], [4,0,0], [0,4,0], [2,2,0], [3,1,0], [5,0,0]] @@ -44,21 +44,27 @@ def get_polynomial_coefficients_villaume2017a(): C['Warm_Giants'] = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [2,0,0], [0,0,2], [0,2,0], [1,1,0], [1,0,1], [0,1,1], [3,0,0], [0,0,3], [0,3,0], [2,1,0], [1,2,0], [2,0,1], [1,0,2], [4,0,0], [0,4,0], [2,2,0], [2,0,2], [0,2,2], [5,0,0]] C['Hot_Stars'] = [[0,0,0], [1,0,0], [0,1,0], [0,0,1], [2,0,0], [0,2,0], [0,0,2], [1,0,1], [1,1,0], [0,1,1], [3,0,0], [0,0,3], [0,3,0], [1,1,1], [2,1,0], [2,0,1], [1,2,0], [0,2,1], [1,0,2], [0,1,2], [4,0,0]] - with open(os.path.join(data_path, 'polynomial_powers.dat'), "w") as f: + with open(os.path.join(data_path, 'polynomial_powers.dat'), "w") as f: f.write(str(C)) B = { - "Cool_Dwarfs": {"surface_gravity": (4.0, 6.0), "effective_temperature": (2500,4000)}, - "Cool_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (2500,4000)}, - "Warm_Dwarfs": {"surface_gravity": (4.0, 6.0), "effective_temperature": (4000,6000)}, - "Warm_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (4000,6000)}, - "Hot_Stars": {"surface_gravity": (-0.5, 6), "effective_temperature": (6000,12000)}, + "Cool_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (2500,4000)}, + "Warm_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (4500,5500)}, + "Hottish_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (5500,6500)}, + "Coolish_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (3500,4500)}, + + "Cool_Dwarfs": {"surface_gravity": (4.0, 10.0), "effective_temperature": (2500,3000)}, + "Warm_Dwarfs": {"surface_gravity": (4.0, 10.0), "effective_temperature": (5500,6000)}, + "Coolish_Dwarfs": {"surface_gravity": (4.0, 10.0), "effective_temperature": (3000,5500)}, + + "Hot_Giants": {"surface_gravity": (-0.5, 4.0), "effective_temperature": (6500,12000)}, + "Hot_Dwarfs": {"surface_gravity": (4.0, 10.0), "effective_temperature": (6000,12000)}, } - with open(os.path.join(data_path, 'bounds.dat'), "w") as f: + with open(os.path.join(data_path, 'bounds.dat'), "w") as f: f.write(str(B)) - -if __name__=='__main__': + +if __name__=='__main__': get_polynomial_coefficients_villaume2017a() diff --git a/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py b/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py index 6f6a472..a88d6ba 100644 --- a/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py +++ b/lighthouse/stellar_atmosphere_spectrum/polynomial_evaluator.py @@ -17,8 +17,8 @@ class PolynomialEvaluator(Stellar_Atmosphere_Spectrum): def __init__(self): - directory_path = Path(__file__) - data_path = Path(directory_path.parent, 'data/Villaume2017a/') + data_path = Path(os.environ['LightHouse_HOME'], 'lighthouse/data/Villaume2017a/') + self.coefficients = {} self.reference = {} @@ -38,12 +38,30 @@ def __init__(self): self.wavelength = torch.tensor(coeffs.to_numpy()[:,0], dtype = torch.float64) self.reference[name] = torch.tensor(coeffs.to_numpy()[:,1], dtype = torch.float64) self.coefficients[name] = torch.tensor(coeffs.to_numpy()[:,2:], dtype = torch.float64) - + def get_spectrum(self, teff, logg, feh) -> torch.Tensor: + """ - Setting up some boundaries + These weights are used later to ensure + smooth behavior. """ + # Overlap of cool dwarf and warm dwarf training sets + d_teff_overlap = torch.linspace(3000, 5500, steps=100) + d_weights = torch.linspace(1, 0, steps=100) + + # Overlap of warm giant and hot star training sets + gh_teff_overlap = torch.linspace(5500, 6500, steps=100) + gh_weights = torch.linspace(1, 0, steps=100) + # Overlap of warm giant and cool giant training sets + gc_teff_overlap = torch.linspace(3500, 4500, steps=100) + gc_weights = torch.linspace(1, 0, steps=100) + + + + """ + Setting up some boundaries + """ teff2 = teff logg2 = logg if teff2 <= 2800.: @@ -52,26 +70,92 @@ def get_spectrum(self, teff, logg, feh) -> torch.Tensor: logg2 = torch.tensor(-0.5) # Normalizing to solar values - # logt = np.log10(teff2) - 3.7617 - logt = np.log10(teff) - 3.7617 + logt = np.log10(teff2) - 3.7617 logg = logg - 4.44 - print(logg2, teff2) for key, ranges in self.bounds.items(): if ranges["surface_gravity"][0] <= logg2 <= ranges["surface_gravity"][1] and ranges["effective_temperature"][0] <= teff2 <= ranges["effective_temperature"][1]: - stellar_type = key + + if key is 'Hot_Giants' or key is 'Hot_Dwarfs': + stellar_type = 'Hot_Stars' + else: + stellar_type = key break - else: - stellar_type = "Cool_Giants" - - K = torch.stack((torch.as_tensor(logt, dtype = torch.float64), - torch.as_tensor(feh, dtype = torch.float64), + + + K = torch.stack((torch.as_tensor(logt, dtype = torch.float64), + torch.as_tensor(feh, dtype = torch.float64), torch.as_tensor(logg, dtype = torch.float64))) - PP = torch.as_tensor(self.polynomial_powers[stellar_type], dtype = torch.float64) - X = torch.prod(K**PP, dim = -1) - flux = np.exp(self.coefficients[stellar_type] @ X) - flux *= self.reference[stellar_type] + easy_types = ['Cool_Giants', 'Warm_Giants', 'Cool_Dwarfs', 'Warm_Dwarfs', 'Hot_Stars'] + + if stellar_type in easy_types: + + PP = torch.as_tensor(self.polynomial_powers[stellar_type], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux = np.exp(self.coefficients[stellar_type] @ X) + flux *= self.reference[stellar_type] + + elif stellar_type is 'Hottish_Giants': + + PP = torch.as_tensor(self.polynomial_powers['Warm_Giants'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux1 = np.exp(self.coefficients['Warm_Giants'] @ X) + flux1 *= self.reference['Warm_Giants'] + + PP = torch.as_tensor(self.polynomial_powers['Hot_Stars'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux2 = np.exp(self.coefficients['Hot_Stars'] @ X) + flux2 *= self.reference['Hot_Stars'] + + t_index = (np.abs(gh_teff_overlap - teff2)).argmin() + weight = gh_weights[t_index] + flux = (flux1*weight + flux2*(1-weight)) + + elif stellar_type is 'Coolish_Giants': + + PP = torch.as_tensor(self.polynomial_powers['Warm_Giants'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux1 = np.exp(self.coefficients['Warm_Giants'] @ X) + flux1 *= self.reference['Warm_Giants'] + + PP = torch.as_tensor(self.polynomial_powers['Cool_Giants'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux2 = np.exp(self.coefficients['Cool_Giants'] @ X) + flux2 *= self.reference['Cool_Giants'] + + t_index = (np.abs(gh_teff_overlap - teff2)).argmin() + weight = gc_weights[t_index] + flux = (flux1*weight + flux2*(1-weight)) + + elif stellar_type is 'Coolish_Dwarfs': + + PP = torch.as_tensor(self.polynomial_powers['Warm_Dwarfs'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux1 = np.exp(self.coefficients['Warm_Dwarfs'] @ X) + flux1 *= self.reference['Warm_Dwarfs'] + + PP = torch.as_tensor(self.polynomial_powers['Cool_Dwarfs'], dtype = torch.float64) + X = torch.prod(K**PP, dim = -1) + + flux2 = np.exp(self.coefficients['Cool_Dwarfs'] @ X) + flux2 *= self.reference['Warm_Dwarfs'] + + t_index = (np.abs(gh_teff_overlap - teff2)).argmin() + weight = d_weights[t_index] + flux = (flux1*weight + flux2*(1-weight)) + + else: + error = ('Parameter out of bounds:' + 'teff = {0}, logg {1}') + raise ValueError(error.format(teff2, logg)) + return flux diff --git a/lighthouse/utils.py b/lighthouse/utils.py index c637533..7459ff1 100644 --- a/lighthouse/utils.py +++ b/lighthouse/utils.py @@ -1 +1,2 @@ -light_speed = 2.998e14 # micron s^-1 +light_speed_cgs = 29979245800.0 +light_speed_micron = 2.998e14 diff --git a/tests/fix_sas_tests.py b/tests/fix_sas_tests.py deleted file mode 100644 index 4e97028..0000000 --- a/tests/fix_sas_tests.py +++ /dev/null @@ -1,32 +0,0 @@ - -from time import process_time as time - -import numpy as np - -import spigen - - - -from stellar_atmosphere_spectrum import PolynomialEvaluator - - -if __name__ == "__main__": - - import matplotlib.pyplot as plt - - teff = 2500.1831261806437 - logg = -0.6 - feh = 0.0 - - P = PolynomialEvaluator() - sas = P.get_spectrum(teff, logg, feh) - - spec = spigen.Spectrum() - spec = spec.from_coefficients(teff, logg, feh) - - i = (P.wavelength >= 0.36) - plt.plot(spec['wave'], spec['flux'], color='k', lw=3, label='SPI_Utils') - plt.plot(P.wavelength[i], sas[i], label='polynomial_evaluator') - - plt.legend() - plt.show() \ No newline at end of file diff --git a/tests/test_imf.py b/tests/test_imf.py index fd3a1e4..7a43e58 100644 --- a/tests/test_imf.py +++ b/tests/test_imf.py @@ -6,11 +6,53 @@ sys.path.insert(0, os.path.split(os.path.split(__file__)[0])[0]) import lighthouse as lh + +class TestIMF(unittest.TestCase): + + def test_t0_normalization(self): + """ + Check that the IMFs are valid PDFs + """ + + N = 100000 + shift = (100 - 0.08)/N + masses = torch.linspace(0.08 + shift/2, 100 - shift/2, N) + dm = masses[1] - masses[0] + + for imf in lh.initial_mass_function.Initial_Mass_Function.__subclasses__(): + + weights = imf().get_weight(masses, mass_weighted=True) + + self.assertAlmostEqual(torch.sum(weights*dm).item(), 1, 3, "Check your IMF normalization!!!") + + + + def test_continuity(self): + + N = 1000000 + shift = (100 - 0.08)/N + masses = torch.linspace(0.08 + shift/2, 100 - shift/2, N) + + for imf in lh.initial_mass_function.Initial_Mass_Function.__subclasses__(): + + print(imf) + + weights = imf().get_weight(masses, mass_weighted=True) + + delta_weight = weights[1:] - weights[:-1] + + print(delta_weight.abs()/weights[1:]) + + self.assertTrue(torch.all(delta_weight.abs()/weights[1:] < 1e-3), "Check your piecewise functions!!") + class TestKroupa(unittest.TestCase): + + def test_kroupa(self): + K = lh.initial_mass_function.Kroupa() - weight = K.get_weight(torch.tensor(30., dtype = torch.float64), torch.tensor([1.3, 2.3, 2.7])) + weight = K.get_imf(torch.tensor(30., dtype = torch.float64)) self.assertAlmostEqual(weight.detach().cpu().numpy(), 5.1374e-5, 3, "Kroupa IMF test value not equal") diff --git a/tests/test_install.py b/tests/test_install.py new file mode 100644 index 0000000..e3373df --- /dev/null +++ b/tests/test_install.py @@ -0,0 +1,10 @@ +import unittest +import sys +import os + +sys.path.insert(0, os.path.split(os.path.split(__file__)[0])[0]) +import lighthouse as lh + +class TestInstall(unittest.TestCase): + def test_home_variable(self): + self.assertTrue("LightHouse_Home" in os.environ, "Need to set up LightHouse_HOME environment variable")