From 32ee8668e91d102f65f0effe6318ccec477ca3c7 Mon Sep 17 00:00:00 2001 From: Xiang Song Date: Tue, 1 Oct 2024 15:17:24 -0700 Subject: [PATCH] update --- python/graphstorm/config/argument.py | 9 +++++++-- tests/unit-tests/test_config.py | 8 ++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/python/graphstorm/config/argument.py b/python/graphstorm/config/argument.py index 728d3daad..0d7538bad 100644 --- a/python/graphstorm/config/argument.py +++ b/python/graphstorm/config/argument.py @@ -1726,7 +1726,7 @@ def wd_l2norm(self): try: wd_l2norm = float(self._wd_l2norm) except: - raise ValueError("wd-l2norm must be a floating point " \ + raise ValueError("wd_l2norm must be a floating point " \ f"but get {self._wd_l2norm}") return wd_l2norm return 0 @@ -1740,7 +1740,12 @@ def alpha_l2norm(self): """ # pylint: disable=no-member if hasattr(self, "_alpha_l2norm"): - return self._alpha_l2norm + try: + alpha_l2norm = float(self._alpha_l2norm) + except: + raise ValueError("alpha_l2norm must be a floating point " \ + f"but get {self._alpha_l2norm}") + return alpha_l2norm return .0 @property diff --git a/tests/unit-tests/test_config.py b/tests/unit-tests/test_config.py index 8e4650b93..c0f83203c 100644 --- a/tests/unit-tests/test_config.py +++ b/tests/unit-tests/test_config.py @@ -253,6 +253,7 @@ def create_train_config(tmp_path, file_name): "topk_model_to_save": 4, "save_model_path": os.path.join(tmp_path, "save"), "wd_l2norm": 5e-5, + "alpha_l2norm": 5e-5, } with open(os.path.join(tmp_path, file_name+"1.yaml"), "w") as f: yaml.dump(yaml_object, f) @@ -263,6 +264,7 @@ def create_train_config(tmp_path, file_name): "topk_model_to_save": 5, "save_model_path": os.path.join(tmp_path, "save"), "wd_l2norm": "1e-3", + "alpha_l2norm": "1e-3", } with open(os.path.join(tmp_path, file_name+"2.yaml"), "w") as f: yaml.dump(yaml_object, f) @@ -294,6 +296,7 @@ def create_train_config(tmp_path, file_name): "early_stop_burnin_rounds": -1, "early_stop_rounds": 0, "wd_l2norm": "NA", + "alpha_l2norm": "NA", } with open(os.path.join(tmp_path, file_name+"_fail.yaml"), "w") as f: @@ -305,6 +308,7 @@ def create_train_config(tmp_path, file_name): "topk_model_to_save": 3, "save_model_path": os.path.join(tmp_path, "save"), "wd_l2norm": "", + "alpha_l2norm": "", } with open(os.path.join(tmp_path, file_name+"_fail1.yaml"), "w") as f: yaml.dump(yaml_object, f) @@ -355,6 +359,7 @@ def test_train_info(): config = GSConfig(args) assert config.topk_model_to_save == 4 assert config.wd_l2norm == 5e-5 + assert config.alpha_l2norm == 5e-5 args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test2.yaml'), local_rank=0) config = GSConfig(args) @@ -362,6 +367,7 @@ def test_train_info(): assert config.save_model_frequency == 2000 assert config.topk_model_to_save == 5 assert config.wd_l2norm == 1e-3 + assert config.alpha_l2norm == 1e-3 args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test3.yaml'), local_rank=0) config = GSConfig(args) @@ -387,6 +393,7 @@ def test_train_info(): check_failure(config, "early_stop_burnin_rounds") check_failure(config, "early_stop_rounds") check_failure(config, "wd_l2norm") + check_failure(config, "alpha_l2norm") args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test_fail1.yaml'), local_rank=0) config = GSConfig(args) @@ -394,6 +401,7 @@ def test_train_info(): # so here we do not check failure, but check the topk model argument assert config.topk_model_to_save == 3 check_failure(config, "wd_l2norm") + check_failure(config, "alpha_l2norm") def create_rgcn_config(tmp_path, file_name): yaml_object = create_dummpy_config_obj()