Skip to content

Commit

Permalink
Expose model params (#185)
Browse files Browse the repository at this point in the history
* restructure high level code to allow modification of model after creation

* remove comment

* fix conftext import

* apply black

* add docstrings

* apply black (again)

* move create_radiation_field() function
  • Loading branch information
jvshields authored Apr 15, 2024
1 parent feeb2db commit 99bc5dc
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 105 deletions.
147 changes: 44 additions & 103 deletions stardis/base.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,13 @@
import numpy as np

from tardis.io.atom_data import AtomData
from tardis.io.configuration.config_validator import validate_yaml
from tardis.io.configuration.config_reader import Configuration

from astropy import units as u
from pathlib import Path
import numba

from stardis.io.base import parse_config_to_model
from stardis.plasma import create_stellar_plasma
from stardis.radiation_field.opacities.opacities_solvers import calc_alphas
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field import RadiationField
from stardis.io.model.marcs import read_marcs_model
from stardis.io.model.mesa import read_mesa_model
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu
import logging

from stardis.radiation_field.base import create_stellar_radiation_field
from astropy import units as u

BASE_DIR = Path(__file__).parent
SCHEMA_PATH = BASE_DIR / "config_schema.yml"
import logging


###TODO: Make a function that parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python
def run_stardis(config_fname, tracing_lambdas_or_nus):
"""
Runs a STARDIS simulation.
Expand All @@ -44,99 +29,55 @@ def run_stardis(config_fname, tracing_lambdas_or_nus):

tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral())

try:
config_dict = validate_yaml(config_fname, schemapath=SCHEMA_PATH)
config = Configuration(config_dict)
except:
raise ValueError("Config failed to validate. Check the config file.")
config, adata, stellar_model = parse_config_to_model(config_fname)
set_num_threads(config.n_threads)

# Set multithreading as specified by the config
if config.n_threads == 1:
logging.info("Running in serial mode")
elif config.n_threads == -99:
logging.info("Running with max threads")
elif config.n_threads > 1:
logging.info(f"Running with {config.n_threads} threads")
numba.set_num_threads(config.n_threads)
else:
raise ValueError(
"n_threads must be a positive integer less than the number of available threads, or -99 to run with max threads."
)

adata = AtomData.from_hdf(config.atom_data)

# model
logging.info("Reading model")
if config.model.type == "marcs":
raw_marcs_model = read_marcs_model(
Path(config.model.fname), gzipped=config.model.gzipped
)
stellar_model = raw_marcs_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

elif config.model.type == "mesa":
raw_mesa_model = read_mesa_model(Path(config.model.fname))
if config.model.truncate_to_shell != -99:
raw_mesa_model.truncate_model(config.model.truncate_to_shell)
elif config.model.truncate_to_shell < 0:
raise ValueError(
f"{config.model.truncate_to_shell} shells were requested for mesa model truncation. -99 is default for no truncation."
)

stellar_model = raw_mesa_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

else:
raise ValueError("Model type not recognized. Must be either 'marcs' or 'mesa'")

# Handle case of when there are fewer elements requested vs. elements in the atomic mass fraction table.
adata.prepare_atom_data(
np.arange(
1,
np.min(
[
len(
stellar_model.composition.atomic_mass_fraction.columns.tolist()
),
config.model.final_atomic_number,
]
)
+ 1,
),
line_interaction_type="macroatom",
nlte_species=[],
continuum_interaction_species=[],
)
# plasma
logging.info("Creating plasma")
stellar_plasma = create_stellar_plasma(stellar_model, adata, config)

stellar_radiation_field = RadiationField(
tracing_nus, blackbody_flux_at_nu, stellar_model
)
logging.info("Calculating alphas")
calc_alphas(
stellar_plasma=stellar_plasma,
stellar_model=stellar_model,
stellar_radiation_field=stellar_radiation_field,
opacity_config=config.opacity,
n_threads=config.n_threads,
)
logging.info("Raytracing")
raytrace(
stellar_model,
stellar_radiation_field,
no_of_thetas=config.no_of_thetas,
n_threads=config.n_threads,
stellar_radiation_field = create_stellar_radiation_field(
tracing_nus, stellar_model, stellar_plasma, config
)

return STARDISOutput(
config.result_options, stellar_model, stellar_plasma, stellar_radiation_field
)


def set_num_threads(n_threads):
"""
Set the number of threads for multithreading.
This function sets the number of threads to be used for multithreading based on the
input argument `n_threads`. It uses Numba's `set_num_threads` function to set the
number of threads.
Parameters
----------
n_threads : int
The number of threads to use. If `n_threads` is 1, the function will run in
serial mode. If `n_threads` is -99, the function will run with the maximum
number of available threads. If `n_threads` is greater than 1, the function
will run with `n_threads` threads.
Raises
------
ValueError
If `n_threads` is not a positive integer less than the number of available
threads, and it's not -99.
"""
if n_threads == 1:
logging.info("Running in serial mode")
elif n_threads == -99:
logging.info("Running with max threads")
elif n_threads > 1:
logging.info(f"Running with {n_threads} threads")
numba.set_num_threads(n_threads)
else:
raise ValueError(
"n_threads must be a positive integer less than the number of available threads, or -99 to run with max threads."
)


class STARDISOutput:
"""
Class containing all the key outputs of a STARDIS simulation.
Expand Down
3 changes: 2 additions & 1 deletion stardis/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field import RadiationField
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu
from stardis import SCHEMA_PATH, STARDISOutput
from stardis import STARDISOutput
from stardis.io.base import SCHEMA_PATH

EXAMPLE_CONF_PATH = Path(__file__).parent / "tests" / "stardis_test_config.yml"
EXAMPLE_CONF_PATH_BROADENING = (
Expand Down
89 changes: 89 additions & 0 deletions stardis/io/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from pathlib import Path
import logging
import numpy as np

from tardis.io.atom_data import AtomData
from tardis.io.configuration.config_validator import validate_yaml
from tardis.io.configuration.config_reader import Configuration

from stardis.io.model.marcs import read_marcs_model
from stardis.io.model.mesa import read_mesa_model


BASE_DIR = Path(__file__).parent.parent
SCHEMA_PATH = BASE_DIR / "config_schema.yml"


def parse_config_to_model(config_fname):
"""
Parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python.
Parameters
----------
config_fname : str
Filepath to the STARDIS configuration. Must be a YAML file.
Returns
-------
config : stardis.io.configuration.config_reader.Configuration
Configuration object.
adata : tardis.io.atom_data.AtomData
AtomData object.
stellar_model : stardis.io.model.marcs.MarcsModel or stardis.io.model.mesa.MesaModel
Stellar model object.
"""

try:
config_dict = validate_yaml(config_fname, schemapath=SCHEMA_PATH)
config = Configuration(config_dict)
except:
raise ValueError("Config failed to validate. Check the config file.")

adata = AtomData.from_hdf(config.atom_data)

# model
logging.info("Reading model")
if config.model.type == "marcs":
raw_marcs_model = read_marcs_model(
Path(config.model.fname), gzipped=config.model.gzipped
)
stellar_model = raw_marcs_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

elif config.model.type == "mesa":
raw_mesa_model = read_mesa_model(Path(config.model.fname))
if config.model.truncate_to_shell != -99:
raw_mesa_model.truncate_model(config.model.truncate_to_shell)
elif config.model.truncate_to_shell < 0:
raise ValueError(
f"{config.model.truncate_to_shell} shells were requested for mesa model truncation. -99 is default for no truncation."
)

stellar_model = raw_mesa_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
)

else:
raise ValueError("Model type not recognized. Must be either 'marcs' or 'mesa'")

# Handle case of when there are fewer elements requested vs. elements in the atomic mass fraction table.
adata.prepare_atom_data(
np.arange(
1,
np.min(
[
len(
stellar_model.composition.atomic_mass_fraction.columns.tolist()
),
config.model.final_atomic_number,
]
)
+ 1,
),
line_interaction_type="macroatom",
nlte_species=[],
continuum_interaction_species=[],
)

return config, adata, stellar_model
3 changes: 3 additions & 0 deletions stardis/plasma/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import pandas as pd
import logging

from astropy import constants as const, units as u
from scipy.interpolate import interp1d
Expand Down Expand Up @@ -502,6 +503,8 @@ def create_stellar_plasma(
tardis.plasma.base.BasePlasma
"""

logging.info("Creating plasma")

# basic_properties.remove(tardis.plasma.properties.general.NumberDensity)
plasma_modules = []
plasma_modules += basic_inputs
Expand Down
54 changes: 53 additions & 1 deletion stardis/radiation_field/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from stardis.radiation_field.opacities import Opacities
import numpy as np
import logging
from stardis.radiation_field.opacities import Opacities
from stardis.radiation_field.opacities.opacities_solvers import calc_alphas
from stardis.radiation_field.radiation_field_solvers import raytrace
from stardis.radiation_field.source_functions.blackbody import blackbody_flux_at_nu


class RadiationField:
Expand Down Expand Up @@ -31,3 +35,51 @@ def __init__(self, frequencies, source_function, stellar_model):
self.source_function = source_function
self.opacities = Opacities(frequencies, stellar_model)
self.F_nu = np.zeros((stellar_model.no_of_depth_points, len(frequencies)))


def create_stellar_radiation_field(tracing_nus, stellar_model, stellar_plasma, config):
"""
Create a stellar radiation field.
This function creates a stellar radiation field by initializing a `RadiationField`
object and calculating the alpha values for the stellar plasma. It then performs
raytracing on the stellar model.
Parameters
----------
tracing_nus : array
The frequencies at which to trace the radiation field.
stellar_model : StellarModel
The stellar model to use for the radiation field.
stellar_plasma : StellarPlasma
The stellar plasma to use for the radiation field.
config : Config
The configuration object containing the opacity and threading settings.
Returns
-------
stellar_radiation_field : RadiationField
The created stellar radiation field.
"""

stellar_radiation_field = RadiationField(
tracing_nus, blackbody_flux_at_nu, stellar_model
)
logging.info("Calculating alphas")
calc_alphas(
stellar_plasma=stellar_plasma,
stellar_model=stellar_model,
stellar_radiation_field=stellar_radiation_field,
opacity_config=config.opacity,
n_threads=config.n_threads,
)
logging.info("Raytracing")
raytrace(
stellar_model,
stellar_radiation_field,
no_of_thetas=config.no_of_thetas,
n_threads=config.n_threads,
)

return stellar_radiation_field

0 comments on commit 99bc5dc

Please sign in to comment.