[BC] ModuleWorker and ModuleWorkerUtility (#387)
ModuleWorker class which distributes the ModuleExplorers over
multiple modules.
tvmarino authored Oct 29, 2024
1 parent 80b39a8 commit 1047803
Showing 2 changed files with 445 additions and 11 deletions.
294 changes: 287 additions & 7 deletions compiler_opt/rl/generate_bc_trajectories.py
@@ -14,16 +14,21 @@
# limitations under the License.
"""Module for running compilation and collect data for behavior cloning."""

from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator
import gin
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Generator, Union

from absl import logging
import bisect
import dataclasses
import os
import shutil
import timeit

import math
import numpy as np
import tensorflow as tf
from tf_agents import policies
from tf_agents.typing import types as tf_types
from tf_agents.trajectories import policy_step
from tf_agents.trajectories import time_step
from tf_agents.specs import tensor_spec
@@ -139,6 +144,46 @@ def add_feature_list(seq_example: tf.train.SequenceExample,
add_function(seq_example, feature, feature_name)


def policy_action_wrapper(tf_policy) -> Callable[[Any], np.ndarray]:
"""Return a wrapper for a loaded policy action.
The returned function maps from an (optional) state to an np.array
that represents the action.
Args:
tf_policy: a loaded policy (optionally a tf_agents TFPolicy)
Returns:
wrap_function: function mapping a state to an np.array action.
"""

def wrap_function(*args, **kwargs):
return np.array(tf_policy.action(*args, **kwargs).action)

return wrap_function


def policy_distr_wrapper(
tf_policy: policies.TFPolicy
) -> Callable[[time_step.TimeStep, Optional[tf_types.NestedTensor]],
policy_step.PolicyStep]:
"""Return a wrapper for a loaded tf policy distribution.
The returned function maps from a state to a distribution over all actions.
Args:
tf_policy: A loaded tf policy.
Returns:
wrap_function: function mapping a state to a distribution over all actions.
"""

def wrap_function(*args, **kwargs) -> policy_step.PolicyStep:
return tf_policy.distribution(*args, **kwargs)

return wrap_function
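

A brief usage sketch for the two wrappers above (the SavedModel path is hypothetical and the time step `ts` is assumed to come from the environment; the wrappers simply forward to the policy's action and distribution methods):

import tensorflow as tf

# Hypothetical path; ModuleWorker below loads its policies the same way.
loaded_policy = tf.saved_model.load('/tmp/policy0')
act_fn = policy_action_wrapper(loaded_policy)
distr_fn = policy_distr_wrapper(loaded_policy)
# act_fn(ts) returns np.array(loaded_policy.action(ts).action)
# distr_fn(ts) returns loaded_policy.distribution(ts), a policy_step.PolicyStep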


class ExplorationWithPolicy:
"""Policy which selects states for exploration.
@@ -231,13 +276,11 @@ def get_advice(self, state: time_step.TimeStep) -> np.ndarray:
return policy_action


class ExplorationWorker(worker.Worker):
class ModuleExplorer:
"""Class which implements the exploration for the given module.
Attributes:
loaded_module_spec: the module to be compiled and explored
use_greedy: indicates if the default/greedy policy is used to compile the
module
env: MLGO environment.
exploration_frac: how often to explore in a trajectory
max_exploration_steps: maximum number of exploration steps
@@ -274,7 +317,7 @@ def __init__(

if reward_key == '':
raise TypeError(
'reward_key not specified in ExplorationWorker initialization.')
'reward_key not specified in ModuleExplorer initialization.')
self._reward_key = reward_key
kwargs.pop('reward_key', None)
self._working_dir = None
@@ -370,10 +413,10 @@ def explore_function(
Returns:
seq_example_list: a tf.train.SequenceExample list containing all the
trajectories from exploration.
trajectories from exploration.
working_dir_names: the directories of the compiled binaries
loss_idx: idx of the smallest loss trajectory in the seq_example_list.
base_seq_loss: loss of the trajectory compiled with policy.
base_seq_loss: loss of the best trajectory compiled with policy.
"""
seq_example_list = []
working_dir_names = []
@@ -529,3 +572,240 @@ def _process_obs(self, curr_obs, sequence_example):
curr_obs_feature, dtype=obs_dtype, name=curr_obs_feature_name)
add_feature_list(sequence_example, curr_obs_feature,
curr_obs_feature_name)


class ModuleWorkerResultProcessor:
"""Utility class to process ModuleExplorer results for ModuleWorker."""

def __init__(self, base_path: Optional[str] = None):
self._base_path = base_path

def _partition_for_loss(self, seq_example: tf.train.SequenceExample,
partitions: List[float], label_name: str):
"""Adds a feature to seq_example to partition the examples into buckets.
Given a tuple of partition limits (a_1, a_2, ..., a_n) we create n+1
buckets with limits [0, a_1), [a_1, a_2), ..., [a_n-1, a_n), [a_n, +infty).
The i-th bucket contains all sequence examples with loss in [a_i-1, a_i).
Args:
seq_example: sequence example from the compiled module
partitions: a tuple of limits defining the buckets
label_name: name of the feature which will contain the bucket index."""
seq_loss = get_loss(seq_example)

label = bisect.bisect_right(partitions, seq_loss)
horizon = len(seq_example.feature_lists.feature_list['action'].feature)
label_list = [label for _ in range(horizon)]
add_feature_list(seq_example, label_list, label_name)
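
To make the bucketing concrete, a small sketch with made-up partition limits; bisect.bisect_right sends boundary values to the upper bucket:

import bisect

partitions = [0., 1.5, 3.0]            # hypothetical limits
bisect.bisect_right(partitions, -0.2)  # -> 0 (loss below the first limit)
bisect.bisect_right(partitions, 2.0)   # -> 2 (loss in [1.5, 3.0))
bisect.bisect_right(partitions, 3.0)   # -> 3 (boundary value, upper bucket)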

def process_succeeded(
self,
succeeded: List[Tuple[List, List[str], int, float]],
spec_name: str,
partitions: List[float],
label_name: str = 'label'
) -> Tuple[tf.train.SequenceExample, Dict[str, Union[str, float, int]], Dict[
str, Union[str, float, int]]]:
seq_example_list = [exploration_res[0] for exploration_res in succeeded]
working_dir_list = [(exploration_res[1], exploration_res[2])
for exploration_res in succeeded]
seq_example_losses = [exploration_res[3] for exploration_res in succeeded]

best_policy_idx = np.argmin(seq_example_losses)
best_exploration_idx = working_dir_list[best_policy_idx][1]

# comparator is the last policy in the policy_paths list
module_dict_pol = self._profiling_dict(spec_name, seq_example_list[-1][0])

module_dict_max = self._profiling_dict(
spec_name, seq_example_list[best_policy_idx][best_exploration_idx])

seq_example = seq_example_list[best_policy_idx][best_exploration_idx]
best_exploration_idxs = [
exploration_res[2] for exploration_res in succeeded
]
logging.info('best policy idx: %s, best exploration idxs %s',
best_policy_idx, best_exploration_idxs)

if self._base_path:
# as long as one process handles one module, this can stay here
temp_working_dir_idx = working_dir_list[best_policy_idx][1]
temp_working_dir_list = working_dir_list[best_policy_idx][0]
temp_working_dir = temp_working_dir_list[temp_working_dir_idx]
self._save_binary(self._base_path, spec_name, temp_working_dir)

self._partition_for_loss(seq_example, partitions, label_name)

return seq_example, module_dict_max, module_dict_pol

def _profiling_dict(
self, module_name: str, feature_list: tf.train.SequenceExample
) -> Dict[str, Union[str, float, int]]:
"""Return a dictionary for the module containing the name, loss and horizon.
Args:
module_name: name of module
feature_list: tf.train.SequenceExample of the compiled module
Returns:
per_module_dict: dictionary containing the name, loss and horizon of
the compiled module.
"""

per_module_dict = {
'module_name':
module_name,
'loss':
float(get_loss(feature_list)),
'horizon':
len(feature_list.feature_lists.feature_list['action'].feature),
}
return per_module_dict
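
For illustration (hypothetical values), the returned dictionary has the form:

{'module_name': 'foo/bar.o', 'loss': 12.0, 'horizon': 37}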

def _save_binary(self, base_path: str, save_path: str, binary_path: str):
path_head_tail = os.path.split(save_path)
path_head = path_head_tail[0]
path_tail = path_head_tail[1]
save_dir = os.path.join(base_path, path_head)
if not os.path.exists(save_dir):
os.makedirs(save_dir, exist_ok=True)
shutil.copy(
os.path.join(binary_path, 'comp_binary'),
os.path.join(save_dir, path_tail))


@gin.configurable
class ModuleWorker(worker.Worker):
"""Class which sets up an exploration worker and processes the results.
Given a list of policies and an exploration policy, the class processes
modules, one at a time, returning the maximum-reward trajectory, where the
maximum is taken over the list of policies together with exploration
facilitated by the exploration policy, if given.
Attributes:
clang_path: path to clang
mlgo_task: the type of compilation task
policy_paths: list of policies to load and use for forming the trajectories
exploration_policy_paths: list of policies to be used for exploration,
i-th policy in exploration_policy_paths explores when using i-th policy
in policy_paths for compilation
exploration_frac: how often to explore in a trajectory
max_exploration_steps: maximum number of exploration steps
tf_policy_action: list of the action/advice function from loaded policies
explore_on_features: dict of feature names and functions which specify
when to explore on the respective feature
obs_action_specs: optional observation and action spec annotating TimeStep
base_path: root path to save best compiled binaries for linking
partitions: a tuple of limits defining the buckets, see partition_for_loss
envargs: additional arguments to pass to the InliningTask, used in creating
the environment. These have to include the reward_key.
"""

def __init__(
# pylint: disable=dangerous-default-value
self,
clang_path: str = gin.REQUIRED,
mlgo_task: Type[env.MLGOTask] = gin.REQUIRED,
policy_paths: List[str] = gin.REQUIRED,
exploration_frac: float = gin.REQUIRED,
max_exploration_steps: int = 7,
exploration_policy_paths: Optional[str] = None,
explore_on_features: Optional[Dict[str, Callable[[tf.Tensor],
bool]]] = None,
obs_action_specs: Optional[Tuple[time_step.TimeStep,
tensor_spec.BoundedTensorSpec,]] = None,
base_path: Optional[str] = None,
partitions: List[float] = [
0.,
],
**envargs,
):
logging.info('Environment args: %s', envargs)
self._clang_path: str = clang_path
self._mlgo_task: Type[env.MLGOTask] = mlgo_task
self._policy_paths: List[str] = policy_paths
self._exploration_policy_paths: Optional[str] = exploration_policy_paths
self._exploration_frac: float = exploration_frac
self._max_exploration_steps: int = max_exploration_steps
self._tf_policy_action: List[Optional[Callable[[Any], np.ndarray]]] = []
self._exploration_policy_distrs: Optional[List[
Callable[[time_step.TimeStep, Optional[tf_types.NestedTensor]],
policy_step.PolicyStep]]] = None
self._explore_on_features: Optional[Dict[str, Callable[
[tf.Tensor], bool]]] = explore_on_features
self._obs_action_specs: Optional[Tuple[
time_step.TimeStep, tensor_spec.BoundedTensorSpec]] = obs_action_specs
self._mw_utility = ModuleWorkerResultProcessor(base_path)
self._partitions = partitions
self._envargs = envargs

for policy_path in policy_paths:
tf_policy = tf.saved_model.load(policy_path, tags=None, options=None)
self._tf_policy_action.append(policy_action_wrapper(tf_policy))
if exploration_policy_paths:
if len(exploration_policy_paths) != len(policy_paths):
raise AssertionError(
f'Number of exploration policies: {len(exploration_policy_paths)}, '
f'does not match number of policies: {len(policy_paths)}')
self._exploration_policy_distrs = []
for exploration_policy_path in exploration_policy_paths:
expl_policy = tf.saved_model.load(
exploration_policy_path, tags=None, options=None)
self._exploration_policy_distrs.append(
policy_distr_wrapper(expl_policy))

def select_best_exploration(
self,
loaded_module_spec: corpus.LoadedModuleSpec,
) -> Tuple[Tuple[int, Dict[str, Union[str, float, int]],
           Dict[str, Union[str, float, int]]], bytes]:

num_calls = len(self._tf_policy_action)
time_call_compiler = 0
logging.info('Processing module: %s', loaded_module_spec.name)
start = timeit.default_timer()
work = list(zip(self._tf_policy_action, self._exploration_policy_distrs))
exploration_worker = ModuleExplorer(
loaded_module_spec=loaded_module_spec,
clang_path=self._clang_path,
mlgo_task=self._mlgo_task,
exploration_frac=self._exploration_frac,
max_exploration_steps=self._max_exploration_steps,
explore_on_features=self._explore_on_features,
obs_action_specs=self._obs_action_specs,
**self._envargs)
succeeded = []
for policy_action, explore_policy in work:
exploration_res = None
try:
exploration_res = exploration_worker.explore_function(
policy_action, explore_policy)
except Exception as e: # pylint: disable=broad-except
logging.info('Compilation exception %s at %s', e,
loaded_module_spec.name)
if exploration_res is not None:
succeeded.append(exploration_res)

end = timeit.default_timer()
time_call_compiler += end - start
logging.info('Processed module %s in time %s', loaded_module_spec.name,
time_call_compiler)
(seq_example, module_dict_max,
module_dict_pol) = self._mw_utility.process_succeeded(
succeeded, loaded_module_spec.name, self._partitions)

working_dir_list = [exploration_res[1] for exploration_res in succeeded]

for temp_dirs in working_dir_list:
for temp_dir in temp_dirs:
temp_dir_head = os.path.split(temp_dir)[0]
shutil.rmtree(temp_dir_head)

return (
num_calls,
module_dict_max,
module_dict_pol,
), seq_example.SerializeToString()
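
A minimal end-to-end sketch of driving ModuleWorker, assuming hypothetical paths, a concrete env.MLGOTask subclass, and a corpus.LoadedModuleSpec already loaded; exploration policies are paired one-to-one with the compilation policies, since select_best_exploration zips the two lists:

# All paths, SomeMLGOTask and loaded_module_spec are hypothetical placeholders.
module_worker = ModuleWorker(
    clang_path='/usr/bin/clang',
    mlgo_task=SomeMLGOTask,  # a concrete env.MLGOTask subclass
    policy_paths=['/tmp/policy0'],
    exploration_policy_paths=['/tmp/expl_policy0'],
    exploration_frac=0.2,
    partitions=[0., 1.5, 3.0],
    reward_key='default',  # forwarded to ModuleExplorer via **envargs
)
(num_calls, module_dict_max,
 module_dict_pol), serialized = module_worker.select_best_exploration(
     loaded_module_spec)
seq_example = tf.train.SequenceExample.FromString(serialized)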