Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/tutorial_and_fixes' into tutoria…
Browse files Browse the repository at this point in the history
…l_and_fixes
  • Loading branch information
v1docq committed Oct 24, 2024
2 parents 4705b83 + e1aabfc commit 96e0324
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ def _get_parameters_from_trial(self, graph: OptGraph, trial: Trial) -> dict:
sampling_scope = parameter_properties.get('sampling-scope')
if parameter_type == 'discrete':
new_parameters.update({node_op_parameter_name:
trial.suggest_int(node_op_parameter_name, *sampling_scope)})
trial.suggest_int(node_op_parameter_name, *sampling_scope)})
elif parameter_type == 'continuous':
new_parameters.update({node_op_parameter_name:
trial.suggest_float(node_op_parameter_name, *sampling_scope)})
trial.suggest_float(node_op_parameter_name, *sampling_scope)})
elif parameter_type == 'categorical':
new_parameters.update({node_op_parameter_name:
trial.suggest_categorical(node_op_parameter_name, *sampling_scope)})
trial.suggest_categorical(node_op_parameter_name, *sampling_scope)})
return new_parameters

def _get_initial_point(self, graph: OptGraph) -> Tuple[dict, bool]:
Expand All @@ -125,7 +125,7 @@ def _get_initial_point(self, graph: OptGraph) -> Tuple[dict, bool]:
if tunable_node_params:
has_parameters_to_optimize = True
tunable_initial_params = {get_node_operation_parameter_label(node_id, operation_name, p):
node.parameters[p] for p in node.parameters if p in tunable_node_params}
node.parameters[p] for p in node.parameters if p in tunable_node_params}
if tunable_initial_params:
initial_parameters.update(tunable_initial_params)
return initial_parameters, has_parameters_to_optimize
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -622,4 +622,3 @@ def has_no_data_flow_conflicts_in_industrial_pipeline(pipeline: Pipeline):
def _crossover_by_type(self, crossover_type: CrossoverTypesEnum) -> None:
IndustrialCrossover()
return None

0 comments on commit 96e0324

Please sign in to comment.