Skip to content

Commit

Permalink
finished simple tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Olender committed Jul 26, 2024
1 parent 1b38d3a commit f843412
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 23 deletions.
28 changes: 14 additions & 14 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Attach 0",
"type": "python",
"request": "attach",
"port": 3000,
"host": "localhost",
},
{
"name": "Python Attach 1",
"type": "python",
"request": "attach",
"port": 3001,
"host": "localhost"
},
// {
// "name": "Python Attach 0",
// "type": "python",
// "request": "attach",
// "port": 3000,
// "host": "localhost",
// },
// {
// "name": "Python Attach 1",
// "type": "python",
// "request": "attach",
// "port": 3001,
// "host": "localhost"
// },
{
"name": "Python Debugger: Current File",
"type": "debugpy",
Expand Down
15 changes: 13 additions & 2 deletions spyro/solvers/inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ..plots import plot_model as spyro_plot_model
from ..io.basicio import ensemble_shot_record
from ..io.basicio import switch_serial_shot
from ..io.basicio import load_shots, save_shots


try:
Expand Down Expand Up @@ -169,6 +170,9 @@ def __init__(self, dictionary=None, comm=None):
self.functional_history = []
self.control_out = fire.File("results/control.pvd")
self.gradient_out = fire.File("results/gradient.pvd")
self.real_shot_record_files = dictionary["inversion"].get("real_shot_record_files", None)
if self.real_shot_record_files is not None:
self.load_real_shot_record(filename=self.real_shot_record_files)

def calculate_misfit(self, c=None):
"""
Expand Down Expand Up @@ -199,7 +203,7 @@ def calculate_misfit(self, c=None):
self.misfit = self.real_shot_record - self.guess_shot_record
return self.misfit

def generate_real_shot_record(self, plot_model=False, filename="model.png", abc_points=None):
def generate_real_shot_record(self, plot_model=False, model_filename="model.png", abc_points=None, save_shot_record=True, shot_filename="shots/shot_record_"):
"""
Generates the real synthetic shot record. Only for use in synthetic test cases.
"""
Expand All @@ -210,9 +214,11 @@ def generate_real_shot_record(self, plot_model=False, filename="model.png", abc_
Wave_obj_real_velocity.initial_velocity_model = self.real_velocity_model

if plot_model and Wave_obj_real_velocity.comm.comm.rank == 0 and Wave_obj_real_velocity.comm.ensemble_comm.rank == 0:
spyro_plot_model(Wave_obj_real_velocity, filename=filename, abc_points=abc_points)
spyro_plot_model(Wave_obj_real_velocity, filename=model_filename, abc_points=abc_points)

Wave_obj_real_velocity.forward_solve()
if save_shot_record:
save_shots(Wave_obj_real_velocity, file_name=shot_filename)
self.real_shot_record = Wave_obj_real_velocity.real_shot_record
self.quadrature_rule = Wave_obj_real_velocity.quadrature_rule

Expand Down Expand Up @@ -560,6 +566,11 @@ def _apply_gradient_mask(self):
else:
pass

def load_real_shot_record(self, filename="shots/shot_record_"):
load_shots(self, file_name=filename)
self.real_shot_record = self.forward_solution_receivers
self.forward_solution_receivers = None


class SyntheticRealAcousticWave(AcousticWave):
"""
Expand Down
24 changes: 17 additions & 7 deletions test_polygon_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_polygon_vp():
assert test


def test_real_shot_record_generation_for_polygon():
def test_real_shot_record_generation_for_polygon_and_save_and_load():
dictionary = {}
dictionary["absorving_boundary_conditions"] = {
"pad_length": 2.0, # True or false
Expand All @@ -77,20 +77,30 @@ def test_real_shot_record_generation_for_polygon():
}
dictionary["time_axis"] = {
"final_time": 1.0, # Final time for event
"dt": 0.0001, # timestep size
"dt": 0.0005, # timestep size
"amplitude": 1, # the Ricker has an amplitude of 1.
"output_frequency": 500, # how frequently to output solution to pvds
# how frequently to save solution to RAM
"gradient_sampling_frequency": 1,
}
fwi = spyro.examples.Polygon_acoustic_FWI(dictionary=dictionary, periodic=True)
fwi.generate_real_shot_record(plot_model=True)
spyro.io.save_shots(fwi)
fwi.generate_real_shot_record(plot_model=True, save_shot_record=True)

print("END")
dictionary["inversion"] = {
"real_shot_record_files": "shots/shot_record_",
}
fwi2 = spyro.examples.Polygon_acoustic_FWI(dictionary=dictionary, periodic=True)

test1 = np.isclose(np.max(fwi2.real_shot_record[:, 0]), 0.18, atol=1e-2)
test2 = np.isclose(np.max(fwi2.real_shot_record[:, -1]), 0.0243, atol=1e-3)

test = all([test1, test2])

print(f"Correctly loaded shots: {test}")

assert test


if __name__ == "__main__":
# test_polygon_vp()
test_real_shot_record_generation_for_polygon()
test_polygon_vp()
test_real_shot_record_generation_for_polygon_and_save_and_load()

0 comments on commit f843412

Please sign in to comment.