Skip to content

Commit

Permalink
Benchmark raytrace function
Browse files Browse the repository at this point in the history
  • Loading branch information
light2802 committed Jul 11, 2023
1 parent e1f7d06 commit 5062605
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions benchmarks/config_schema.yml
59 changes: 59 additions & 0 deletions benchmarks/run_stardis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,18 @@
import os
import numpy as np
from stardis.base import run_stardis

from astropy import units as u

from tardis.io.atom_data import AtomData
from tardis.io.config_validator import validate_yaml
from tardis.io.config_reader import Configuration

from stardis.model import read_marcs_to_fv
from stardis.plasma import create_stellar_plasma
from stardis.opacities import calc_alphas
from stardis.transport import raytrace


class BenchmarkRunStardis:
"""
Expand All @@ -21,3 +31,52 @@ def setup(self):

def time_run_stardis(self):
run_stardis(self.config, self.tracing_lambdas)


class Benchmarkraytrace:
"""
Class to benchmark raytrace function.
"""

def setup(self):
base_dir = os.path.abspath(os.path.dirname(__file__))
os.chdir(base_dir)
schema = os.path.join(base_dir, "config_schema.yml")
config_fname = os.path.join(base_dir, "benchmark_config.yml")
tracing_lambdas_or_nus = np.arange(6540, 6590, 0.01) * u.Angstrom

tracing_nus = tracing_lambdas_or_nus.to(u.Hz, u.spectral())

config_dict = validate_yaml(config_fname, schemapath=schema)
config = Configuration(config_dict)

adata = AtomData.from_hdf(config.atom_data)

stellar_model = read_marcs_to_fv(
config.model.fname,
adata,
final_atomic_number=config.model.final_atomic_number,
)
adata.prepare_atom_data(stellar_model.abundances.index.tolist())

stellar_plasma = create_stellar_plasma(stellar_model, adata)

alphas, gammas, doppler_widths = calc_alphas(
stellar_plasma=stellar_plasma,
stellar_model=stellar_model,
tracing_nus=tracing_nus,
opacity_config=config.opacity,
)

self.stellar_model = stellar_model
self.alphas = alphas
self.tracing_nus = tracing_nus
self.config = config

def time_raytrace(self):
raytrace(
self.stellar_model,
self.alphas,
self.tracing_nus,
no_of_thetas=self.config.no_of_thetas,
)

0 comments on commit 5062605

Please sign in to comment.