Skip to content

Commit

Permalink
fix bug and test additional case
Browse files Browse the repository at this point in the history
  • Loading branch information
jvshields committed Jun 13, 2024
1 parent 7cfa55f commit 6c2ea1a
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
9 changes: 7 additions & 2 deletions stardis/io/model/tests/test_model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ def test_marcs_model(marcs_model):

def test_rescale_nuclide_mass_fraction(example_stellar_model):
rescaled = rescale_nuclide_mass_fractions(
example_stellar_model.composition.nuclide_mass_fraction, [5], [0.5]
example_stellar_model.composition.nuclide_mass_fraction, [4, 5], [1.1, 0.8]
)
assert np.allclose(
rescaled.loc[5].values,
example_stellar_model.composition.nuclide_mass_fraction.loc[5].values * 0.5,
example_stellar_model.composition.nuclide_mass_fraction.loc[5].values * 0.8,
)

assert np.allclose(
rescaled.loc[5].values,
example_stellar_model.composition.nuclide_mass_fraction.loc[4].values * 1.1,
)
2 changes: 1 addition & 1 deletion stardis/io/model/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def rescale_nuclide_mass_fractions(nuclide_mass_fraction, nuclides, scale_factor
if nuclide not in new_mass_fractions.columns:
raise ValueError(f"{nuclide} not available in the simulation")

new_mass_fractions[nuclides] = new_mass_fractions[nuclides] * scale_factor
new_mass_fractions[nuclide] = new_mass_fractions[nuclide] * scale_factor

return new_mass_fractions.T.div(
new_mass_fractions.T.sum(axis=0)
Expand Down

0 comments on commit 6c2ea1a

Please sign in to comment.