From 051e3814ea09de4382c8e62facbc331d8ea88d43 Mon Sep 17 00:00:00 2001 From: Alexander Barth Date: Wed, 11 Dec 2019 20:32:56 +0100 Subject: [PATCH] set seed --- DINCAE.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/DINCAE.py b/DINCAE.py index 61cb7e1..f8140c4 100644 --- a/DINCAE.py +++ b/DINCAE.py @@ -240,7 +240,8 @@ def reconstruct(lon,lat,mask,meandata, nvar = 10, enc_ksize_internal = [16,24,36,54], clip_grad = 5.0, - regularization_L2_beta = 0 + regularization_L2_beta = 0, + iseed = None ): """ Train a neural network to reconstruct missing data using the training data set @@ -281,6 +282,11 @@ def reconstruct(lon,lat,mask,meandata, * `regularization_L2_beta`: scalar to enforce L2 regularization on the weight """ + if iseed != None: + np.random.seed(iseed) + tf.set_random_seed(np.random.randint(0,2**32-1)) + random.seed(np.random.randint(0,2**32-1)) + print("regularization_L2_beta ",regularization_L2_beta) print("enc_ksize_internal ",enc_ksize_internal) enc_ksize = [nvar] + enc_ksize_internal @@ -566,6 +572,7 @@ def reconstruct_gridded_nc(filename,varname,outdir, `DINCAE.load_gridded_nc` for the NetCDF format. """ + lon,lat,time,data,missing,mask = load_gridded_nc(filename,varname) train_datagen,train_len,meandata = data_generator( lon,lat,time,data,missing,