Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiang Song committed Oct 1, 2024
1 parent 205ffa7 commit 32ee866
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/graphstorm/config/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1726,7 +1726,7 @@ def wd_l2norm(self):
try:
wd_l2norm = float(self._wd_l2norm)
except:
raise ValueError("wd-l2norm must be a floating point " \
raise ValueError("wd_l2norm must be a floating point " \
f"but get {self._wd_l2norm}")
return wd_l2norm
return 0
Expand All @@ -1740,7 +1740,12 @@ def alpha_l2norm(self):
"""
# pylint: disable=no-member
if hasattr(self, "_alpha_l2norm"):
return self._alpha_l2norm
try:
alpha_l2norm = float(self._alpha_l2norm)
except:
raise ValueError("alpha_l2norm must be a floating point " \
f"but get {self._alpha_l2norm}")
return alpha_l2norm
return .0

@property
Expand Down
8 changes: 8 additions & 0 deletions tests/unit-tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def create_train_config(tmp_path, file_name):
"topk_model_to_save": 4,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": 5e-5,
"alpha_l2norm": 5e-5,
}
with open(os.path.join(tmp_path, file_name+"1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand All @@ -263,6 +264,7 @@ def create_train_config(tmp_path, file_name):
"topk_model_to_save": 5,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "1e-3",
"alpha_l2norm": "1e-3",
}
with open(os.path.join(tmp_path, file_name+"2.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -294,6 +296,7 @@ def create_train_config(tmp_path, file_name):
"early_stop_burnin_rounds": -1,
"early_stop_rounds": 0,
"wd_l2norm": "NA",
"alpha_l2norm": "NA",
}

with open(os.path.join(tmp_path, file_name+"_fail.yaml"), "w") as f:
Expand All @@ -305,6 +308,7 @@ def create_train_config(tmp_path, file_name):
"topk_model_to_save": 3,
"save_model_path": os.path.join(tmp_path, "save"),
"wd_l2norm": "",
"alpha_l2norm": "",
}
with open(os.path.join(tmp_path, file_name+"_fail1.yaml"), "w") as f:
yaml.dump(yaml_object, f)
Expand Down Expand Up @@ -355,13 +359,15 @@ def test_train_info():
config = GSConfig(args)
assert config.topk_model_to_save == 4
assert config.wd_l2norm == 5e-5
assert config.alpha_l2norm == 5e-5

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test2.yaml'), local_rank=0)
config = GSConfig(args)
assert config.eval_frequency == 1000
assert config.save_model_frequency == 2000
assert config.topk_model_to_save == 5
assert config.wd_l2norm == 1e-3
assert config.alpha_l2norm == 1e-3

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test3.yaml'), local_rank=0)
config = GSConfig(args)
Expand All @@ -387,13 +393,15 @@ def test_train_info():
check_failure(config, "early_stop_burnin_rounds")
check_failure(config, "early_stop_rounds")
check_failure(config, "wd_l2norm")
check_failure(config, "alpha_l2norm")

args = Namespace(yaml_config_file=os.path.join(Path(tmpdirname), 'train_test_fail1.yaml'), local_rank=0)
config = GSConfig(args)
# in PR # 893 we loosen the constraints of model saving frequency and eval frequency
# so here we do not check failure, but check the topk model argument
assert config.topk_model_to_save == 3
check_failure(config, "wd_l2norm")
check_failure(config, "alpha_l2norm")

def create_rgcn_config(tmp_path, file_name):
yaml_object = create_dummpy_config_obj()
Expand Down

0 comments on commit 32ee866

Please sign in to comment.