Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose model params #185

Merged
merged 7 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 44 additions & 103 deletions stardis/base.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,13 @@
import numpy as np

from tardis.io.atom_data import AtomData
from tardis.io.configuration.config_validator import validate_yaml
from tardis.io.configuration.config_reader import Configuration

from astropy import units as u
from pathlib import Path
import numba

from stardis.io.base import parse_config_to_model
from stardis.plasma import create_stellar_plasma
from stardis.radiation_field.opacities.opacities_solvers import calc_alphas
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field import RadiationField
from stardis.io.model.marcs import read_marcs_model
from stardis.io.model.mesa import read_mesa_model
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu
import logging

from stardis.radiation_field.base import create_stellar_radiation_field
from astropy import units as u

BASE_DIR = Path(__file__).parent
SCHEMA_PATH = BASE_DIR / "config_schema.yml"
import logging


###TODO: Make a function that parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python
def run_stardis(config_fname, tracing_lambdas_or_nus):
"""
Runs a STARDIS simulation.
Expand All @@ -44,99 +29,55 @@ def run_stardis(config_fname, tracing_lambdas_or_nus):

tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral())

try:
config_dict = validate_yaml(config_fname, schemapath=SCHEMA_PATH)
config = Configuration(config_dict)
except:
raise ValueError("Config failed to validate. Check the config file.")
config, adata, stellar_model = parse_config_to_model(config_fname)
set_num_threads(config.n_threads)

# Set multithreading as specified by the config
if config.n_threads == 1:
logging.info("Running in serial mode")
elif config.n_threads == -99:
logging.info("Running with max threads")
elif config.n_threads > 1:
logging.info(f"Running with {config.n_threads} threads")
numba.set_num_threads(config.n_threads)
else:
raise ValueError(
"n_threads must be a positive integer less than the number of available threads, or -99 to run with max threads."
)

adata = AtomData.from_hdf(config.atom_data)

# model
logging.info("Reading model")
if config.model.type == "marcs":
raw_marcs_model = read_marcs_model(
Path(config.model.fname), gzipped=config.model.gzipped
)
stellar_model = raw_marcs_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

elif config.model.type == "mesa":
raw_mesa_model = read_mesa_model(Path(config.model.fname))
if config.model.truncate_to_shell != -99:
raw_mesa_model.truncate_model(config.model.truncate_to_shell)
elif config.model.truncate_to_shell < 0:
raise ValueError(
f"{config.model.truncate_to_shell} shells were requested for mesa model truncation. -99 is default for no truncation."
)

stellar_model = raw_mesa_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

else:
raise ValueError("Model type not recognized. Must be either 'marcs' or 'mesa'")

# Handle case of when there are fewer elements requested vs. elements in the atomic mass fraction table.
adata.prepare_atom_data(
np.arange(
1,
np.min(
[
len(
stellar_model.composition.atomic_mass_fraction.columns.tolist()
),
config.model.final_atomic_number,
]
)
+ 1,
),
line_interaction_type="macroatom",
nlte_species=[],
continuum_interaction_species=[],
)
# plasma
logging.info("Creating plasma")
stellar_plasma = create_stellar_plasma(stellar_model, adata, config)

stellar_radiation_field = RadiationField(
tracing_nus, blackbody_flux_at_nu, stellar_model
)
logging.info("Calculating alphas")
calc_alphas(
stellar_plasma=stellar_plasma,
stellar_model=stellar_model,
stellar_radiation_field=stellar_radiation_field,
opacity_config=config.opacity,
n_threads=config.n_threads,
)
logging.info("Raytracing")
raytrace(
stellar_model,
stellar_radiation_field,
no_of_thetas=config.no_of_thetas,
n_threads=config.n_threads,
stellar_radiation_field = create_stellar_radiation_field(
tracing_nus, stellar_model, stellar_plasma, config
)

return STARDISOutput(
config.result_options, stellar_model, stellar_plasma, stellar_radiation_field
)


def set_num_threads(n_threads):
"""
Set the number of threads for multithreading.

This function sets the number of threads to be used for multithreading based on the
input argument `n_threads`. It uses Numba's `set_num_threads` function to set the
number of threads.

Parameters
----------
n_threads : int
The number of threads to use. If `n_threads` is 1, the function will run in
serial mode. If `n_threads` is -99, the function will run with the maximum
number of available threads. If `n_threads` is greater than 1, the function
will run with `n_threads` threads.

Raises
------
ValueError
If `n_threads` is not a positive integer less than the number of available
threads, and it's not -99.

"""
if n_threads == 1:
logging.info("Running in serial mode")
elif n_threads == -99:
logging.info("Running with max threads")
elif n_threads > 1:
logging.info(f"Running with {n_threads} threads")
numba.set_num_threads(n_threads)
else:
raise ValueError(
"n_threads must be a positive integer less than the number of available threads, or -99 to run with max threads."
)


class STARDISOutput:
"""
Class containing all the key outputs of a STARDIS simulation.
Expand Down
3 changes: 2 additions & 1 deletion stardis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field import RadiationField
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu
from stardis import SCHEMA_PATH, STARDISOutput
from stardis import STARDISOutput
from stardis.io.base import SCHEMA_PATH

EXAMPLE_CONF_PATH = Path(__file__).parent / "tests" / "stardis_test_config.yml"
EXAMPLE_CONF_PATH_BROADENING = (
Expand Down
89 changes: 89 additions & 0 deletions stardis/io/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pathlib import Path
import logging
import numpy as np

from tardis.io.atom_data import AtomData
from tardis.io.configuration.config_validator import validate_yaml
from tardis.io.configuration.config_reader import Configuration

from stardis.io.model.marcs import read_marcs_model
from stardis.io.model.mesa import read_mesa_model


BASE_DIR = Path(__file__).parent.parent
SCHEMA_PATH = BASE_DIR / "config_schema.yml"


def parse_config_to_model(config_fname):
"""
Parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python.

Parameters
----------
config_fname : str
Filepath to the STARDIS configuration. Must be a YAML file.

Returns
-------
config : stardis.io.configuration.config_reader.Configuration
Configuration object.
adata : tardis.io.atom_data.AtomData
AtomData object.
stellar_model : stardis.io.model.marcs.MarcsModel or stardis.io.model.mesa.MesaModel
Stellar model object.
"""

try:
config_dict = validate_yaml(config_fname, schemapath=SCHEMA_PATH)
config = Configuration(config_dict)
except:
raise ValueError("Config failed to validate. Check the config file.")

adata = AtomData.from_hdf(config.atom_data)

# model
logging.info("Reading model")
if config.model.type == "marcs":
raw_marcs_model = read_marcs_model(
Path(config.model.fname), gzipped=config.model.gzipped
)
stellar_model = raw_marcs_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

elif config.model.type == "mesa":
raw_mesa_model = read_mesa_model(Path(config.model.fname))
if config.model.truncate_to_shell != -99:
raw_mesa_model.truncate_model(config.model.truncate_to_shell)
elif config.model.truncate_to_shell < 0:
raise ValueError(
f"{config.model.truncate_to_shell} shells were requested for mesa model truncation. -99 is default for no truncation."
)

stellar_model = raw_mesa_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

else:
raise ValueError("Model type not recognized. Must be either 'marcs' or 'mesa'")

# Handle case of when there are fewer elements requested vs. elements in the atomic mass fraction table.
adata.prepare_atom_data(
np.arange(
1,
np.min(
[
len(
stellar_model.composition.atomic_mass_fraction.columns.tolist()
),
config.model.final_atomic_number,
]
)
+ 1,
),
line_interaction_type="macroatom",
nlte_species=[],
continuum_interaction_species=[],
)

return config, adata, stellar_model
3 changes: 3 additions & 0 deletions stardis/plasma/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import logging

from astropy import constants as const, units as u
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -502,6 +503,8 @@ def create_stellar_plasma(
tardis.plasma.base.BasePlasma
"""

logging.info("Creating plasma")

# basic_properties.remove(tardis.plasma.properties.general.NumberDensity)
plasma_modules = []
plasma_modules += basic_inputs
Expand Down
54 changes: 53 additions & 1 deletion stardis/radiation_field/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from stardis.radiation_field.opacities import Opacities
import numpy as np
import logging
from stardis.radiation_field.opacities import Opacities
from stardis.radiation_field.opacities.opacities_solvers import calc_alphas
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu


class RadiationField:
Expand Down Expand Up @@ -31,3 +35,51 @@ def __init__(self, frequencies, source_function, stellar_model):
self.source_function = source_function
self.opacities = Opacities(frequencies, stellar_model)
self.F_nu = np.zeros((stellar_model.no_of_depth_points, len(frequencies)))


def create_stellar_radiation_field(tracing_nus, stellar_model, stellar_plasma, config):
"""
Create a stellar radiation field.

This function creates a stellar radiation field by initializing a `RadiationField`
object and calculating the alpha values for the stellar plasma. It then performs
raytracing on the stellar model.

Parameters
----------
tracing_nus : array
The frequencies at which to trace the radiation field.
stellar_model : StellarModel
The stellar model to use for the radiation field.
stellar_plasma : StellarPlasma
The stellar plasma to use for the radiation field.
config : Config
The configuration object containing the opacity and threading settings.

Returns
-------
stellar_radiation_field : RadiationField
The created stellar radiation field.

"""

stellar_radiation_field = RadiationField(
tracing_nus, blackbody_flux_at_nu, stellar_model
)
logging.info("Calculating alphas")
calc_alphas(
stellar_plasma=stellar_plasma,
stellar_model=stellar_model,
stellar_radiation_field=stellar_radiation_field,
opacity_config=config.opacity,
n_threads=config.n_threads,
)
logging.info("Raytracing")
raytrace(
stellar_model,
stellar_radiation_field,
no_of_thetas=config.no_of_thetas,
n_threads=config.n_threads,
)

return stellar_radiation_field
Loading