diff --git a/spyro/io/basicio.py b/spyro/io/basicio.py index 887742bf..c3d4bd74 100644 --- a/spyro/io/basicio.py +++ b/spyro/io/basicio.py @@ -94,7 +94,7 @@ def switch_serial_shot(wave, propagation_id): for array_i, array in enumerate(stacked_shot_arrays): wave.forward_solution[array_i].dat.data[:] = array wave.forward_solution_receivers = np.load(f"tmp_rec{propagation_id}_comm{spatialcomm}"+id_str+".npy") - + wave.receivers_output = wave.forward_solution_receivers def ensemble_gradient(func): """Decorator for gradient to distribute shots for ensemble parallelism""" diff --git a/temp_test_serialshots_grad.py b/temp_test_serialshots_grad.py index 8f2ea6e8..d23e93df 100644 --- a/temp_test_serialshots_grad.py +++ b/temp_test_serialshots_grad.py @@ -9,9 +9,12 @@ from firedrake import File import firedrake as fire import spyro +import warnings -def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): +warnings.filterwarnings("ignore") + +def check_gradient(Wave_obj_guess, dJ, rec_out_exact_list, Jm_list, plot=False): steps = [1e-3, 1e-4, 1e-5] # step length errors = [] @@ -25,14 +28,16 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): for step in steps: - Wave_obj_guess.reset_pressure() - c_guess = fire.Constant(2.0) + step*dm - Wave_obj_guess.initial_velocity_model = c_guess - Wave_obj_guess.forward_solve() - misfit_plusdm = rec_out_exact - Wave_obj_guess.receivers_output - J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm) + grad_fd = 0.0 + for snum in range(Wave_obj_guess.number_of_sources): + Wave_obj_guess.reset_pressure() + c_guess = fire.Constant(2.0) + step*dm + Wave_obj_guess.initial_velocity_model = c_guess + Wave_obj_guess.forward_solve() + misfit_plusdm = rec_out_exact_list[snum] - Wave_obj_guess.receivers_output + J_plusdm = spyro.utils.compute_functional(Wave_obj_guess, misfit_plusdm) - grad_fd = (J_plusdm - Jm) / (step) + grad_fd += (J_plusdm - Jm_list[snum]) / (step) projnorm = fire.assemble(dJ * dm * fire.dx(scheme=Wave_obj_guess.quadrature_rule)) error = 100 * ((grad_fd - projnorm) / projnorm) @@ -122,48 +127,55 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): } -def get_forward_model(load_true=False): - if load_true is False: - Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary) - Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1}) - # Wave_obj_exact.set_initial_velocity_model(constant=3.0) - cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5) - Wave_obj_exact.set_initial_velocity_model( - conditional=cond, - # output=True - ) - # spyro.plots.plot_model(Wave_obj_exact, abc_points=[(-1, 1), (-2, 1), (-2, 4), (-1, 2)]) - Wave_obj_exact.forward_solve() - # forward_solution_exact = Wave_obj_exact.forward_solution - rec_out_exact = Wave_obj_exact.receivers_output - # np.save("rec_out_exact", rec_out_exact) - - else: - rec_out_exact = np.load("rec_out_exact.npy") +def get_forward_model(): + + print(f"Calculating exact", flush=True) + Wave_obj_exact = spyro.AcousticWave(dictionary=dictionary) + Wave_obj_exact.set_mesh(mesh_parameters={"dx": 0.1}) + + cond = fire.conditional(Wave_obj_exact.mesh_z > -1.5, 1.5, 3.5) + Wave_obj_exact.set_initial_velocity_model( + conditional=cond, + ) + + Wave_obj_exact.forward_solve() + print(f"Calculating guess", flush=True) Wave_obj_guess = spyro.AcousticWave(dictionary=dictionary) Wave_obj_guess.set_mesh(mesh_parameters={"dx": 0.1}) Wave_obj_guess.set_initial_velocity_model(constant=2.0) Wave_obj_guess.forward_solve() - rec_out_guess = Wave_obj_guess.receivers_output - return rec_out_exact, rec_out_guess, Wave_obj_guess + rec_exact_list = [] + rec_guess_list = [] + print(f"Sending shot records and guess object", flush=True) + for propagation_id in range(Wave_obj_exact.number_of_sources): + spyro.io.switch_serial_shot(Wave_obj_exact, propagation_id) + rec_exact_list.append(Wave_obj_exact.receivers_output) + rec_guess_list.append(Wave_obj_guess.receivers_output) + return rec_exact_list, rec_guess_list, Wave_obj_guess -def test_gradient_supershot(): - rec_out_exact, rec_out_guess, Wave_obj_guess = get_forward_model(load_true=False) - misfit = rec_out_exact - rec_out_guess +def test_gradient_serialshots(): + print(f"Starting", flush=True) + rec_exact_list, rec_guess_list, Wave_obj_guess = get_forward_model() - Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit) - print(f"Cost functional : {Jm}", flush=True) + Jm_list = [] + print(f"Saving cost functionals", flush=True) + for propagation_id in range(Wave_obj_guess.number_of_sources): + misfit = rec_exact_list[propagation_id] - rec_guess_list[propagation_id] + Jm = spyro.utils.compute_functional(Wave_obj_guess, misfit) + print(f"Cost functional : {Jm}", flush=True) + Jm_list.append(Jm) # compute the gradient of the control (to be verified) + print(f"Gradient calculation", flush=True) dJ = Wave_obj_guess.gradient_solve() File("gradient.pvd").write(dJ) - check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=True) + check_gradient(Wave_obj_guess, dJ, rec_exact_list, Jm_list, plot=True) if __name__ == "__main__": - test_gradient_supershot() + test_gradient_serialshots() diff --git a/temp_test_supershot_grady.py b/temp_test_supershot_grady.py index e2a7dadd..3fb80444 100644 --- a/temp_test_supershot_grady.py +++ b/temp_test_supershot_grady.py @@ -1,7 +1,7 @@ -from mpi4py.MPI import COMM_WORLD -import debugpy -debugpy.listen(3000 + COMM_WORLD.rank) -debugpy.wait_for_client() +# from mpi4py.MPI import COMM_WORLD +# import debugpy +# debugpy.listen(3000 + COMM_WORLD.rank) +# debugpy.wait_for_client() import numpy as np import math import matplotlib.pyplot as plt @@ -17,10 +17,10 @@ def check_gradient(Wave_obj_guess, dJ, rec_out_exact, Jm, plot=False): errors = [] V_c = Wave_obj_guess.function_space dm = fire.Function(V_c) - # size, = np.shape(dm.dat.data[:]) - # dm_data = np.random.rand(size) + size, = np.shape(dm.dat.data[:]) + dm_data = np.random.rand(size) # np.save(f"dmdata{COMM_WORLD.rank}", dm_data) - dm_data = np.load(f"dmdata{COMM_WORLD.rank}.npy") + # dm_data = np.load(f"dmdata{COMM_WORLD.rank}.npy") dm.dat.data[:] = dm_data # dm.assign(dJ)