Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow user to specify config values to overwrite in run_stardis function #191

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions stardis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import logging


def run_stardis(config_fname, tracing_lambdas_or_nus):
def run_stardis(
config_fname, tracing_lambdas_or_nus, add_config_keys=None, add_config_vals=None
):
"""
Runs a STARDIS simulation.

Expand All @@ -20,6 +22,10 @@ def run_stardis(config_fname, tracing_lambdas_or_nus):
Numpy array of the frequencies or wavelengths to calculate the
spectrum for. Must have units attached to it, with dimensions
of either length or inverse time.
add_config_keys : list, optional
List of additional keys to add or overwrite for the configuration file.
add_config_vals : list, optional
List of corresponding additional values to add to the configuration file.

Returns
-------
Expand All @@ -29,7 +35,9 @@ def run_stardis(config_fname, tracing_lambdas_or_nus):

tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral())

config, adata, stellar_model = parse_config_to_model(config_fname)
config, adata, stellar_model = parse_config_to_model(
config_fname, add_config_keys, add_config_vals
)
set_num_threads(config.n_threads)
stellar_plasma = create_stellar_plasma(stellar_model, adata, config)
stellar_radiation_field = create_stellar_radiation_field(
Expand Down
25 changes: 23 additions & 2 deletions stardis/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from tardis.io.atom_data import AtomData
from tardis.io.configuration.config_validator import validate_yaml
from tardis.io.configuration.config_validator import validate_yaml, validate_dict
from tardis.io.configuration.config_reader import Configuration

from stardis.io.model.marcs import read_marcs_model
Expand All @@ -15,14 +15,18 @@
SCHEMA_PATH = BASE_DIR / "config_schema.yml"


def parse_config_to_model(config_fname):
def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=None):
"""
Parses the config and model files and outputs python objects to be passed into run stardis so they can be individually modified in python.

Parameters
----------
config_fname : str
Filepath to the STARDIS configuration. Must be a YAML file.
add_config_keys : list, optional
List of additional keys to add or overwrite for the configuration file.
add_config_vals : list, optional
List of corresponding additional values to add to the configuration file.

Returns
-------
Expand All @@ -40,6 +44,23 @@ def parse_config_to_model(config_fname):
except:
raise ValueError("Config failed to validate. Check the config file.")

if (
not add_config_keys
): # If a dictionary was passed, update the config with the dictionary
pass
else:
print("Updating config with additional keys and values")
try:
for key, val in zip(add_config_keys, add_config_vals):
config.set_config_item(key, val)
except:
config.set_config_item(add_config_keys, add_config_vals)

try:
config_dict = validate_dict(config, schemapath=SCHEMA_PATH)
except:
raise ValueError("Additional config keys and values failed to validate.")

adata = AtomData.from_hdf(config.atom_data)

# model
Expand Down
2 changes: 1 addition & 1 deletion stardis/io/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from stardis.io.model.marcs import read_marcs_model
from stardis.io.model.marcs import *
Loading