Skip to content

Commit

Permalink
ENH: Adds selective hyperparameter optimization
Browse files Browse the repository at this point in the history
Also fixes double-dipping problems in model fitting
  • Loading branch information
itellaetxe committed Oct 16, 2024
1 parent 5bedacc commit b309449
Show file tree
Hide file tree
Showing 5 changed files with 261 additions and 162 deletions.
58 changes: 35 additions & 23 deletions src/ageml/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def configure_parser(self):
self.parser.add_argument(
"-ht",
"--hyperparameter_tuning",
nargs=1,
nargs="+",
default=["0"],
help=messages.hyperparameter_grid_description,
)
Expand Down Expand Up @@ -164,12 +164,28 @@ def configure_args(self, args):
else:
args.model_params = {}

# Set hyperparameter grid search value
if len(args.hyperparameter_tuning) > 1 or not args.hyperparameter_tuning[0].isdigit():
# Parse hyperparameter_tuning values
hyperparam_tuning = args.hyperparameter_tuning
if not hyperparam_tuning[0].isdigit():
raise ValueError("Hyperparameter grid points must be a non negative integer.")
else:
args.hyperparameter_tuning = args.hyperparameter_tuning[0]
args.hyperparameter_tuning = int(convert(args.hyperparameter_tuning))
args.hyperparameter_tuning = int(convert(hyperparam_tuning[0]))

hyperparameter_params = {}
if len(hyperparam_tuning) > 1:
for item in hyperparam_tuning[1:]:
if item.count("=") != 1:
err_msg = (
"Hyperparameter tuning parameters must be in the format "
"param1=value1_low,value1_high param2=value2_low,value2_high..."
)
raise ValueError(err_msg)
key, value = item.split("=")
low, high = value.split(",")
hyperparameter_params[key] = [convert(low), convert(high)]
# Add attribute to args
args.hyperparameter_params = hyperparameter_params

# Set polynomial feature extension value
if len(args.feature_extension) > 1 or not args.feature_extension[0].isdigit():
raise ValueError("Polynomial feature extension degree must be a non negative integer.")
Expand Down Expand Up @@ -227,13 +243,11 @@ def configure_parser(self):
help=messages.factors_long_description,
)

self.parser.add_argument("--covariates", metavar="FILE", required=False,
help=messages.covar_long_description)
self.parser.add_argument("--clinical", metavar="FILE", required=False,
help=messages.clinical_long_description)
self.parser.add_argument("--covcorr_mode", metavar="MODE", required=False,
choices=["cn", "each", "all"],
help=messages.covcorr_mode_long_description)
self.parser.add_argument("--covariates", metavar="FILE", required=False, help=messages.covar_long_description)
self.parser.add_argument("--clinical", metavar="FILE", required=False, help=messages.clinical_long_description)
self.parser.add_argument(
"--covcorr_mode", metavar="MODE", required=False, choices=["cn", "each", "all"], help=messages.covcorr_mode_long_description
)


class ClinicalGroups(Interface):
Expand Down Expand Up @@ -284,11 +298,10 @@ def configure_parser(self):
)

# Optional arguments
self.parser.add_argument("--covariates", metavar="FILE", required=False,
help=messages.covar_long_description)
self.parser.add_argument("--covcorr_mode", metavar="MODE", required=False,
choices=["cn", "each", "all"],
help=messages.covcorr_mode_long_description)
self.parser.add_argument("--covariates", metavar="FILE", required=False, help=messages.covar_long_description)
self.parser.add_argument(
"--covcorr_mode", metavar="MODE", required=False, choices=["cn", "each", "all"], help=messages.covcorr_mode_long_description
)


class ClinicalClassification(Interface):
Expand Down Expand Up @@ -372,12 +385,11 @@ def configure_parser(self):
)

# Optional arguments
self.parser.add_argument("--covariates", metavar="FILE", required=False,
help=messages.covar_long_description)
self.parser.add_argument("--covcorr_mode", metavar="MODE", required=False,
choices=["cn", "each", "all"],
help=messages.covcorr_mode_long_description)

self.parser.add_argument("--covariates", metavar="FILE", required=False, help=messages.covar_long_description)
self.parser.add_argument(
"--covcorr_mode", metavar="MODE", required=False, choices=["cn", "each", "all"], help=messages.covcorr_mode_long_description
)

def configure_args(self, args):
"""Configure arguments with required formatting for modelling.
Expand Down
4 changes: 2 additions & 2 deletions src/ageml/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@

hyperparameter_grid_description = (
"Number of points for which the hyperparameter optimization Grid Search will train\n"
"a model. The parameter ranges are predefined. An integer is required.\n"
"(e.g. -ht 100 / --hyperparameter_tuning 100)"
"a model, and parameter ranges to sample from. An integer is required, followed \n"
"by the parameters to optimize. (e.g. -ht 100 C=1,3 kernel=linear,rbf)"
)

thr_long_description = "Threshold for classification. Default: 0.5 \n" "The threshold is used for assigning hard labels. (e.g. --thr 0.5)"
Expand Down
Loading

0 comments on commit b309449

Please sign in to comment.