From 7abe3f36039f9f1e31a527cbc4b1c0e252d8c9d5 Mon Sep 17 00:00:00 2001 From: thevindu-w Date: Fri, 17 May 2024 14:29:23 +0530 Subject: [PATCH] Fix Python lint issues --- src_python/models/supervised.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src_python/models/supervised.py b/src_python/models/supervised.py index 591278606..fa36ff521 100644 --- a/src_python/models/supervised.py +++ b/src_python/models/supervised.py @@ -59,16 +59,28 @@ def initialize(self, **hyper_params): """ if 'batch_size' not in hyper_params: batch_size = 20 + else: + batch_size = hyper_params['batch_size'] if 'layer_sizes' not in hyper_params: - num_samples = [20, 10] - if 'num_samples' not in hyper_params: layer_sizes = [10, 10] + else: + layer_sizes = hyper_params['layer_sizes'] + if 'num_samples' not in hyper_params: + num_samples = [20, 10] + else: + num_samples = hyper_params['num_samples'] if 'bias' not in hyper_params: bias = True + else: + bias = hyper_params['bias'] if 'dropout' not in hyper_params: dropout = 0.1 + else: + dropout = hyper_params['dropout'] if 'lr' not in hyper_params: lr = 1e-2 + else: + lr = hyper_params['lr'] graph = sg.StellarGraph(nodes=self.nodes, edges=self.edges)