diff --git a/DINCAE.py b/DINCAE.py index 87584cc..61cb7e1 100644 --- a/DINCAE.py +++ b/DINCAE.py @@ -265,7 +265,7 @@ def reconstruct(lon,lat,mask,meandata, * `resize_method`: one of the resize methods defined in [TensorFlow](https://www.tensorflow.org/api_docs/python/tf/image/resize_images) * `epochs`: number of epochs for training the neural network * `batch_size`: size of a mini-batch - * `save_each`: reconstruct the missing data every `save_each` epoch + * `save_each`: reconstruct the missing data every `save_each` epoch. Repeated saving is disabled if `save_each` is zero. The last epoch is always saved. * `save_model_each`: save a checkpoint of the neural network every `save_model_each` epoch * `skipconnections`: list of indices of convolutional layers with @@ -281,6 +281,8 @@ def reconstruct(lon,lat,mask,meandata, * `regularization_L2_beta`: scalar to enforce L2 regularization on the weight """ + print("regularization_L2_beta ",regularization_L2_beta) + print("enc_ksize_internal ",enc_ksize_internal) enc_ksize = [nvar] + enc_ksize_internal if not os.path.isdir(outdir): @@ -519,7 +521,7 @@ def reconstruct(lon,lat,mask,meandata, "Training loss: {:.4f}".format(batch_cost), "RMS: {:.4f}".format(batch_RMS)) - if e % save_each == 0: + if (e == epochs-1) or ((save_each > 0) and (e % save_each == 0)): print("Save output",e) timestr = datetime.now().strftime("%Y-%m-%dT%H%M%S") @@ -539,7 +541,7 @@ def reconstruct(lon,lat,mask,meandata, savesample(fname,batch_m_rec,batch_σ2_rec,meandata,lon,lat,e,ii, offset) - if e % save_model_each == 0: + if (save_model_each > 0) and (e % save_model_each == 0): save_path = saver.save(sess, os.path.join( outdir,"model-{:03d}.ckpt".format(e+1)))