Skip to content

Commit

Permalink
Add AD versions of these.
Browse files Browse the repository at this point in the history
  • Loading branch information
athas committed Oct 18, 2024
1 parent 6805553 commit b87d5ed
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
21 changes: 21 additions & 0 deletions rsbench/rsbench.fut
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
-- input @ data/small.in.gz output { 880018i64 }
-- input @ data/large.in.gz output { 358389i64 }

-- ==
-- entry: diff
-- input @ data/small.in.gz
-- input @ data/large.in.gz

type input =
{ lookups: i64,
doppler: i32
Expand Down Expand Up @@ -287,3 +292,19 @@ def main lookups doppler
let (input, sd) = unpack lookups doppler
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs
in #[unsafe] verification (run_event_based_simulation input.lookups input.doppler sd)

entry diff lookups doppler
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs =
let (input, sd) = unpack lookups doppler
n_windows poles_ls poles_cs windows_f64s windows_i32s pseudo_K0RS num_nucs mats concs
let diff_res = (vjp (run_event_based_simulation input.lookups input.doppler)
sd
(replicate input.lookups (1,1,1,1))).poles
in (map (map (.l_value)) diff_res,
map (map (.mp_ea.i)) diff_res,
map (map (.mp_ea.r)) diff_res,
map (map (.mp_ra.i)) diff_res,
map (map (.mp_ra.r)) diff_res,
map (map (.mp_rf.i)) diff_res,
map (map (.mp_rt.r)) diff_res,
)
30 changes: 28 additions & 2 deletions xsbench/xsbench.fut
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
-- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz
-- output { 952131i64 }

-- ==
-- entry: diff
-- input @ data/small.in.gz
-- no_rtx2080 no_k40 no_gtx780 input @ data/large.in.gz


type nuclide_grid_point =
{ energy: f64,
total_xs: f64,
Expand Down Expand Up @@ -225,9 +231,29 @@ def unpack n_isotopes n_gridpoints grid_type hash_bins lookups
{num_nucs, concs, mats, nuclide_grid, index_grid, unionized_energy_array}
in (inputs, sd)

def main n_isotopes n_gridpoints grid_type hash_bins lookups
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
entry main n_isotopes n_gridpoints grid_type hash_bins lookups
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
let (inputs, sd) =
unpack n_isotopes n_gridpoints grid_type hash_bins lookups
num_nucs concs mats nuclide_grid index_grid unionized_energy_array
in #[unsafe] verification (run_event_based_simulation inputs sd)

-- Performs a single vjp pass with an all-unit seed vector. This is
-- unlikely to produce a useful gradient, but does show the overhead
-- of a single jvp invocation.
entry diff n_isotopes n_gridpoints grid_type hash_bins lookups
num_nucs concs mats nuclide_grid index_grid unionized_energy_array =
let (inputs, sd) =
unpack n_isotopes n_gridpoints grid_type hash_bins lookups
num_nucs concs mats nuclide_grid index_grid unionized_energy_array
let diff_res = #[unsafe]
(vjp (run_event_based_simulation inputs)
sd
(replicate inputs.lookups (1,1,1,1,1))).nuclide_grid
in (map (map (.absorbtion_xs)) diff_res,
map (map (.elastic_xs)) diff_res,
map (map (.energy)) diff_res,
map (map (.fission_xs)) diff_res,
map (map (.nu_fission_xs)) diff_res,
map (map (.total_xs)) diff_res,
)

0 comments on commit b87d5ed

Please sign in to comment.