Skip to content

Commit

Permalink
more fixes to the montecarlo globals
Browse files Browse the repository at this point in the history
  • Loading branch information
wkerzendorf committed Jul 12, 2024
1 parent b3b470a commit 61bea81
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 72 deletions.
4 changes: 2 additions & 2 deletions tardis/simulation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from tardis.model import SimulationState
from tardis.plasma.standard_plasmas import assemble_plasma
from tardis.simulation.convergence import ConvergenceSolver
from tardis.transport.montecarlo import montecarlo_configuration
from tardis.transport.montecarlo.base import MonteCarloTransportSolver
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.util.base import is_notebook
from tardis.visualization import ConvergencePlots

Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(
self._callbacks = OrderedDict()
self._cb_next_id = 0

montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED = (
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED = (
not self.plasma.continuum_interaction_species.empty
)

Expand Down
14 changes: 7 additions & 7 deletions tardis/transport/montecarlo/formal_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
)
from tardis.spectrum import TARDISSpectrum
from tardis.transport.montecarlo import (
montecarlo_configuration,
njit_dict,
njit_dict_no_parallel,
)
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON
from tardis.transport.montecarlo.formal_integral_cuda import (
CudaFormalIntegrator,
)
from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON
from tardis.transport.montecarlo.numba_interface import (
opacity_state_initialize,
)
Expand Down Expand Up @@ -283,8 +283,8 @@ def __init__(self, simulation_state, plasma, transport, points=1000):
self.plasma = opacity_state_initialize(
plasma,
transport.line_interaction_type,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.DISABLE_LINE_SCATTERING,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)
self.atomic_data = plasma.atomic_data
self.original_plasma = plasma
Expand All @@ -307,8 +307,8 @@ def generate_numba_objects(self):
self.opacity_state = opacity_state_initialize(
self.original_plasma,
self.transport.line_interaction_type,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.DISABLE_LINE_SCATTERING,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)
if self.transport.use_gpu:
self.integrator = CudaFormalIntegrator(
Expand Down Expand Up @@ -360,7 +360,7 @@ def raise_or_return(message):
'and line_interaction_type == "macroatom"'
)

if montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
return raise_or_return(
"The FormalIntegrator currently does not work for "
"continuum interactions."
Expand Down
18 changes: 6 additions & 12 deletions tardis/transport/montecarlo/formal_integral_cuda.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import sys
import numpy as np
from astropy import units as u
from numba import float64, int64, cuda
import math

from tardis.transport.montecarlo.numba_config import SIGMA_THOMSON
import numpy as np
from numba import cuda

from tardis.transport.montecarlo.configuration.constants import SIGMA_THOMSON

C_INV = 3.33564e-11
M_PI = np.arccos(-1)
Expand Down Expand Up @@ -65,8 +64,7 @@ def cuda_formal_integral(
z : array(float64, 2d, C)
shell_id : array(int64, 2d, C)
"""

# todo: add all the original todos
# TODO: add all the original todos

# global read-only values
size_line, size_shell = tau_sobolev.shape
Expand Down Expand Up @@ -210,7 +208,7 @@ def cuda_formal_integral(
)


class CudaFormalIntegrator(object):
class CudaFormalIntegrator:
"""
Helper class for performing the formal integral
with CUDA.
Expand Down Expand Up @@ -376,8 +374,6 @@ class BoundsError(IndexError):
binary search
"""

pass


@cuda.jit(device=True)
def line_search_cuda(nu, nu_insert, number_of_lines):
Expand Down Expand Up @@ -465,7 +461,6 @@ def trapezoid_integration_cuda(arr, dx):
arr : (array(float64, 1d, C)
dx : np.float64
"""

result = arr[0] + arr[-1]

for x in range(1, len(arr) - 1):
Expand Down Expand Up @@ -510,5 +505,4 @@ def calculate_p_values(R_max, N):
-------
float64
"""

return np.arange(N).astype(np.float64) * R_max / (N - 1)
3 changes: 2 additions & 1 deletion tardis/transport/montecarlo/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
MacroAtomTransitionType,
macro_atom,
)
from tardis.transport.montecarlo.numba_config import (

from tardis.transport.montecarlo.configuration.constants import (
LineInteractionType,
)
from tardis.transport.montecarlo.r_packet import (
Expand Down
5 changes: 3 additions & 2 deletions tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def montecarlo_main_loop(
iteration,
show_progress_bars,
total_iterations,
montecarlo_configuration,
):
"""This is the main loop of the MonteCarlo routine that generates packets
and sends them through the ejecta.
Expand Down Expand Up @@ -168,7 +169,7 @@ def montecarlo_main_loop(
for sub_estimator in estimator_list:
estimators.increment(sub_estimator)

if montecarlo_configuration.ENABLE_VPACKET_TRACKING:
if montecarlo_globals.ENABLE_VPACKET_TRACKING:
vpacket_tracker = consolidate_vpacket_tracker(
vpacket_collections,
spectrum_frequency,
Expand All @@ -185,7 +186,7 @@ def montecarlo_main_loop(
1,
)

if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:
for rpacket_tracker in rpacket_trackers:
rpacket_tracker.finalize_array()

Expand Down
71 changes: 34 additions & 37 deletions tardis/transport/montecarlo/single_packet_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
get_doppler_factor,
get_inverse_doppler_factor,
)
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.estimators.radfield_estimator_calcs import (
update_bound_free_estimators,
)
Expand All @@ -21,16 +22,12 @@
InteractionType,
PacketStatus,
)
from tardis.transport.montecarlo.vpacket import trace_vpacket_volley
from tardis.transport.frame_transformations import (
get_doppler_factor,
get_inverse_doppler_factor,
)
from tardis.transport.montecarlo.r_packet_transport import (
move_packet_across_shell_boundary,
move_r_packet,
trace_packet,
)
from tardis.transport.montecarlo.vpacket import trace_vpacket_volley

C_SPEED_OF_LIGHT = const.c.to("cm/s").value

Expand Down Expand Up @@ -63,16 +60,16 @@ def single_packet_loop(
This function does not return anything but changes the r_packet object
and if virtual packets are requested - also updates the vpacket_collection
"""
line_interaction_type = montecarlo_configuration.LINE_INTERACTION_TYPE
line_interaction_type = montecarlo_globals.LINE_INTERACTION_TYPE

if montecarlo_configuration.ENABLE_FULL_RELATIVITY:
if montecarlo_globals.ENABLE_FULL_RELATIVITY:
set_packet_props_full_relativity(r_packet, time_explosion)
else:
set_packet_props_partial_relativity(r_packet, time_explosion)
r_packet.initialize_line_id(
opacity_state,
time_explosion,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)

trace_vpacket_volley(
Expand All @@ -81,13 +78,13 @@ def single_packet_loop(
numba_radial_1d_geometry,
time_explosion,
opacity_state,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.VPACKET_TAU_RUSSIAN,
montecarlo_configuration.SURVIVAL_PROBABILITY,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)

if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:
rpacket_tracker.track(r_packet)

# this part of the code is temporary and will be better incorporated
Expand All @@ -98,14 +95,14 @@ def single_packet_loop(
r_packet.r,
r_packet.mu,
time_explosion,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)

comov_nu = r_packet.nu * doppler_factor
chi_e = chi_electron_calculator(
opacity_state, comov_nu, r_packet.current_shell_id
)
if montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED:
if montecarlo_globals.CONTINUUM_PROCESSES_ENABLED:
(
chi_bf_tot,
chi_bf_contributions,
Expand All @@ -118,7 +115,7 @@ def single_packet_loop(
chi_continuum = chi_e + chi_bf_tot + chi_ff

escat_prob = chi_e / chi_continuum # probability of e-scatter
if montecarlo_configuration.ENABLE_FULL_RELATIVITY:
if montecarlo_globals.ENABLE_FULL_RELATIVITY:
chi_continuum *= doppler_factor
distance, interaction_type, delta_shell = trace_packet(
r_packet,
Expand All @@ -128,9 +125,9 @@ def single_packet_loop(
estimators,
chi_continuum,
escat_prob,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_globals.DISABLE_LINE_SCATTERING,
)
update_bound_free_estimators(
comov_nu,
Expand All @@ -146,7 +143,7 @@ def single_packet_loop(
else:
escat_prob = 1.0
chi_continuum = chi_e
if montecarlo_configuration.ENABLE_FULL_RELATIVITY:
if montecarlo_globals.ENABLE_FULL_RELATIVITY:
chi_continuum *= doppler_factor
distance, interaction_type, delta_shell = trace_packet(
r_packet,
Expand All @@ -156,9 +153,9 @@ def single_packet_loop(
estimators,
chi_continuum,
escat_prob,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.DISABLE_LINE_SCATTERING,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_globals.DISABLE_LINE_SCATTERING,
)

# If continuum processes: update continuum estimators
Expand All @@ -169,7 +166,7 @@ def single_packet_loop(
distance,
time_explosion,
estimators,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)
move_packet_across_shell_boundary(
r_packet, delta_shell, len(numba_radial_1d_geometry.r_inner)
Expand All @@ -182,26 +179,26 @@ def single_packet_loop(
distance,
time_explosion,
estimators,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)
line_scatter(
r_packet,
time_explosion,
line_interaction_type,
opacity_state,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)
trace_vpacket_volley(
r_packet,
vpacket_collection,
numba_radial_1d_geometry,
time_explosion,
opacity_state,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.VPACKET_TAU_RUSSIAN,
montecarlo_configuration.SURVIVAL_PROBABILITY,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)

elif interaction_type == InteractionType.ESCATTERING:
Expand All @@ -212,12 +209,12 @@ def single_packet_loop(
distance,
time_explosion,
estimators,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)
thomson_scatter(
r_packet,
time_explosion,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)

trace_vpacket_volley(
Expand All @@ -226,10 +223,10 @@ def single_packet_loop(
numba_radial_1d_geometry,
time_explosion,
opacity_state,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.VPACKET_TAU_RUSSIAN,
montecarlo_configuration.SURVIVAL_PROBABILITY,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)
elif (
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED
Expand All @@ -241,7 +238,7 @@ def single_packet_loop(
distance,
time_explosion,
estimators,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)
continuum_event(
r_packet,
Expand All @@ -251,8 +248,8 @@ def single_packet_loop(
chi_ff,
chi_bf_contributions,
current_continua,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
)

trace_vpacket_volley(
Expand All @@ -261,14 +258,14 @@ def single_packet_loop(
numba_radial_1d_geometry,
time_explosion,
opacity_state,
montecarlo_configuration.ENABLE_FULL_RELATIVITY,
montecarlo_globals.ENABLE_FULL_RELATIVITY,
montecarlo_configuration.VPACKET_TAU_RUSSIAN,
montecarlo_configuration.SURVIVAL_PROBABILITY,
montecarlo_configuration.CONTINUUM_PROCESSES_ENABLED,
montecarlo_globals.CONTINUUM_PROCESSES_ENABLED,
)
else:
pass
if montecarlo_configuration.ENABLE_RPACKET_TRACKING:
if montecarlo_globals.ENABLE_RPACKET_TRACKING:
rpacket_tracker.track(r_packet)


Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/tests/test_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy.testing as npt
import numpy as np
import tardis.transport.montecarlo.interaction as interaction
from tardis.transport.montecarlo.numba_config import (
from tardis.transport.montecarlo.configuration.constants import (
LineInteractionType,
)

Expand Down
Loading

0 comments on commit 61bea81

Please sign in to comment.