Skip to content

Commit

Permalink
Fix docstring formatting in compiler_opt/es
Browse files Browse the repository at this point in the history
This patch fixes some docstring issues in compiler_opt/es, such as
missing or extra parameters, and reflows some text to make the docstrings
more compliant with the style guide.
  • Loading branch information
boomanaiden154 committed Sep 26, 2024
1 parent e0643d7 commit aba2cfb
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 16 deletions.
1 change: 0 additions & 1 deletion compiler_opt/es/blackbox_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def __init__(self,
policy_saver_fn: function to save a policy to cns
model_weights: the weights of the current model
config: configuration for blackbox optimization.
stubs: grpc stubs to inlining/regalloc servers
initial_step: the initial step for learning.
deadline: the deadline in seconds for requests to the inlining server.
"""
Expand Down
29 changes: 19 additions & 10 deletions compiler_opt/es/blackbox_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def filter_top_directions(
For antithetic, the total number of perturbations will
be 2 * this number, because we count p, -p as a single
direction.
Returns:
A pair (perturbations, function_values) consisting of the top perturbations.
function_values[i] is the reward of perturbations[i]
Expand Down Expand Up @@ -199,8 +200,6 @@ def get_hyperparameters(self) -> List[float]:
Returns the list of hyperparameters for blackbox function runs that can be
updated on the fly.
Args:
Returns:
The set of hyperparameters for blackbox function runs.
"""
Expand All @@ -212,8 +211,6 @@ def get_state(self) -> List[float]:
Returns the state of the optimizer.
Args:
Returns:
The state of the optimizer.
"""
Expand All @@ -227,8 +224,6 @@ def update_state(self, evaluation_stats: SequenceOfFloats) -> None:
Args:
evaluation_stats: stats from evaluation used to update hyperparameters
Returns:
"""
raise NotImplementedError('Abstract method')

Expand All @@ -240,8 +235,6 @@ def set_state(self, state: SequenceOfFloats) -> None:
Args:
state: state to be set up
Returns:
"""
raise NotImplementedError('Abstract method')

Expand Down Expand Up @@ -520,8 +513,12 @@ def monte_carlo_gradient(precision_parameter: float,
function_values: reward from perturbations (possibly normalized)
current_value: estimated reward at current point (possibly normalized)
energy: optional, for softmax weighting of the average (default = 0)
Returns:
The Monte Carlo gradient estimate.
Raises:
ValueError: When an invalid estimator type is specified.
"""
dim = len(perturbations[0])
b_vector = None
Expand Down Expand Up @@ -558,8 +555,12 @@ def sklearn_regression_gradient(clf: LinearModel, est_type: EstimatorType,
perturbations: the simulated perturbations
function_values: reward from perturbations (possibly normalized)
current_value: estimated reward at current point (possibly normalized)
Returns:
The regression estimate of the gradient.
Raises:
ValueError: When an invalid estimator type is specified.
"""
matrix = None
b_vector = None
Expand Down Expand Up @@ -639,6 +640,7 @@ def f(self, x: FloatArray) -> float:
Args:
x: numpy vector
Returns:
Scalar f(x)
"""
Expand All @@ -649,6 +651,7 @@ def grad(self, x: FloatArray) -> FloatArray:
Args:
x: input vector
Returns:
A vector of the same dimension as x, the gradient of the quadratic at x.
"""
Expand Down Expand Up @@ -729,7 +732,8 @@ def make_projector(radius: float) -> Callable[[FloatArray], FloatArray]:
"""Makes an L2 projector function centered at origin.
Args:
radius: the radius to project on
radius: the radius to project on.
Returns:
A function of one argument that projects onto L2 ball.
"""
Expand Down Expand Up @@ -946,6 +950,7 @@ def trust_region_test(self, current_input: FloatArray,
Args:
current_input: the weights of current candidate point
current_value: the reward of the current point
Returns:
TRUE if the step is accepted
FALSE if the step is rejected
Expand Down Expand Up @@ -1015,7 +1020,7 @@ def update_hessian_part(self, perturbations: FloatArray2D,
See run_step() for a description of arguments.
Args:
perturbations:
perturbations: The perturbations to process.
function_values: (possibly normalized) function values
current_value: (possibly normalized) current value, used as the
Gaussian smoothing estimate if current_point_estimate
Expand Down Expand Up @@ -1071,6 +1076,7 @@ def hessv_func(x: FloatArray) -> FloatArray:
Args:
x: the direction to evaluate the product, i.e. Hx
Returns:
Hessian-vector product.
"""
Expand All @@ -1087,6 +1093,7 @@ def hessv_func(x: FloatArray) -> FloatArray:
Args:
x: the direction to evaluate the product, i.e. Hx
Returns:
Hessian-vector product.
"""
Expand Down Expand Up @@ -1126,6 +1133,7 @@ def update_quadratic_model(self, perturbations: FloatArray2D,
current_value: unnormalized reward of the current policy
is_update: whether the previous step was rejected and this is the same
as the last accepted policy
Returns:
A QuadraticModel object with the local quadratic model after the updates.
"""
Expand Down Expand Up @@ -1167,6 +1175,7 @@ def run_step(self, perturbations: FloatArray2D, function_values: FloatArray,
function_values: list of scalars, reward corresponding to perturbation
current_input: numpy vector, current model weights
current_value: scalar, reward of current model
Returns:
updated model weights
"""
Expand Down
18 changes: 13 additions & 5 deletions compiler_opt/es/policy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,11 @@ def get_vectorized_parameters_from_policy(
def set_vectorized_parameters_for_policy(
policy: 'tf_policy.TFPolicy | HasModelVariables',
parameters: npt.NDArray[np.float32]) -> None:
"""Separates values in parameters into the policy's shapes
and sets the policy variables to those values"""
"""Separates values in parameters.
Packs parameters into the policy's shapes and sets the policy variables to
those values.
"""
if isinstance(policy, tf_policy.TFPolicy):
variables = policy.variables()
elif hasattr(policy, 'model_variables'):
Expand All @@ -108,9 +111,14 @@ def set_vectorized_parameters_for_policy(
def save_policy(policy: 'tf_policy.TFPolicy | HasModelVariables',
parameters: npt.NDArray[np.float32], save_folder: str,
policy_name: str) -> None:
"""Assigns a policy the name policy_name
and saves it to the directory of save_folder
with the values in parameters."""
"""Assigns a policy a name and writes it to disk.
Args:
policy: The policy to save.
parameters: The model weights for the policy.
save_folder: The location to save the policy to.
policy_name: The value to name the policy.
"""
set_vectorized_parameters_for_policy(policy, parameters)
saver = policy_saver.PolicySaver({policy_name: policy})
saver.save(save_folder)

0 comments on commit aba2cfb

Please sign in to comment.