Skip to content

Commit

Permalink
[Bugfix] Fix the case when the input value of wd_l2norm is using scie…
Browse files Browse the repository at this point in the history
…ntific notation.
  • Loading branch information
Xiang Song committed Oct 1, 2024
1 parent 0b48f4e commit 205ffa7
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
7 changes: 6 additions & 1 deletion python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1723,7 +1723,12 @@ def wd_l2norm(self):
"""
# pylint: disable=no-member
if hasattr(self, "_wd_l2norm"):
return self._wd_l2norm
try:
wd_l2norm = float(self._wd_l2norm)
except:
raise ValueError("wd-l2norm must be a floating point " \
f"but get {self._wd_l2norm}")
return wd_l2norm
return 0

@property
Expand Down
8 changes: 8 additions & 0 deletions tests/unit-tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def create_train_config(tmp_path, file_name):
yaml_object["gsf"]["hyperparam"] = {
"topk_model_to_save": 4,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": 5e-5,
}
with open(os.path.join(tmp_path, file_name+"1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand All @@ -261,6 +262,7 @@ def create_train_config(tmp_path, file_name):
'save_model_frequency': 2000,
"topk_model_to_save": 5,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "1e-3",
}
with open(os.path.join(tmp_path, file_name+"2.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -291,6 +293,7 @@ def create_train_config(tmp_path, file_name):
"use_early_stop": True,
"early_stop_burnin_rounds": -1,
"early_stop_rounds": 0,
"wd_l2norm": "NA",
}

with open(os.path.join(tmp_path, file_name+"_fail.yaml"), "w") as f:
Expand All @@ -301,6 +304,7 @@ def create_train_config(tmp_path, file_name):
'save_model_frequency': 2000,
"topk_model_to_save": 3,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "",
}
with open(os.path.join(tmp_path, file_name+"_fail1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -350,12 +354,14 @@ def test_train_info():
args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test1.yaml'), local_rank=0)
config = GSConfig(args)
assert config.topk_model_to_save == 4
assert config.wd_l2norm == 5e-5

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test2.yaml'), local_rank=0)
config = GSConfig(args)
assert config.eval_frequency == 1000
assert config.save_model_frequency == 2000
assert config.topk_model_to_save == 5
assert config.wd_l2norm == 1e-3

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test3.yaml'), local_rank=0)
config = GSConfig(args)
Expand All @@ -380,12 +386,14 @@ def test_train_info():
check_failure(config, "topk_model_to_save")
check_failure(config, "early_stop_burnin_rounds")
check_failure(config, "early_stop_rounds")
check_failure(config, "wd_l2norm")

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test_fail1.yaml'), local_rank=0)
config = GSConfig(args)
# in PR # 893 we loose the constraints of model saving frequency and eval frequency
# so here we do not check failure, but check the topk model argument
assert config.topk_model_to_save == 3
check_failure(config, "wd_l2norm")

def create_rgcn_config(tmp_path, file_name):
yaml_object = create_dummpy_config_obj()
Expand Down

0 comments on commit 205ffa7

Please sign in to comment.