Skip to content

Commit

Permalink
train_diagnostics_fig: add var
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 29, 2023
1 parent 6120b20 commit 546285f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/mpol/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,8 @@ def split_diagnostics_fig(splitter, channel=0, save_prefix=None):
return fig, axes


def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None, old_model_image=None,
def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None,
old_model_image=None, old_model_epoch=None,
kfold=None, epoch=None,
channel=0, save_prefix=None):
"""
Expand All @@ -369,7 +370,9 @@ def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None, old
fluxes : list
Total flux in model image at each epoch in the training loop
old_model_image : 2D image array, default=None
Model image of a previous epoch for comparison to current image
Model image of a previous epoch for comparison to current image
old_model_epoch : int
Epoch of `old_model_image`
kfold : int, default=None
Current cross-validation k-fold
epoch : int, default=None
Expand All @@ -389,7 +392,7 @@ def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None, old
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 8))
axes[1][1].remove()

fig.suptitle(f"Pixel size {model.coords.cell_size * 1e3:.2f} mas, Npix {model.coords.npix}\nk-fold {kfold}, epoch {epoch}", fontsize=10)
fig.suptitle(f"Pixel size {model.coords.cell_size * 1e3:.2f} mas, N_pix {model.coords.npix}\nk-fold {kfold}, epoch {epoch}", fontsize=10)

mod_im = torch2npy(model.icube.sky_cube[channel])
mod_grad = torch2npy(packed_cube_to_sky_cube(model.bcube.base_cube.grad)[channel])
Expand All @@ -416,7 +419,7 @@ def train_diagnostics_fig(model, losses=None, learn_rates=None, fluxes=None, old
diff_image = mod_im - old_model_image
diff_im_norm = get_image_cmap_norm(diff_image, symmetric=True)
plot_image(diff_image, extent, cmap='RdBu_r', ax=ax, xlab='', ylab='', norm=diff_im_norm)
ax.set_title("Difference image", fontsize=10)
ax.set_title(f"Difference (epoch {epoch} - {old_model_epoch})", fontsize=10)

if losses is not None:
# loss function
Expand Down

0 comments on commit 546285f

Please sign in to comment.