Skip to content

Commit

Permalink
add reward quantile monitoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Yundi Qian committed Jan 20, 2022
1 parent 3fed8a1 commit 26cbe68
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
18 changes: 17 additions & 1 deletion compiler_opt/rl/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

import abc
import time
from typing import Dict, Iterator, Tuple, Sequence

from typing import Iterator, Tuple, Dict
import numpy as np
from tf_agents.trajectories import trajectory

# Deadline for data collection.
Expand All @@ -31,6 +32,21 @@
# _DEADLINE_IN_SECONDS time.
WAIT_TERMINATION = ((0.9, 0), (0.8, 0.5), (0, 1))

DELTA = 0.01

REWARD_QUANTILE_MONITOR = (0.1, 0.5, 1, 2, 3, 4, 5, 6, 8, 10, 20, 30, 40, 50,
60, 70, 80, 90, 95, 99, 99.5, 99.9)


def build_distribution_monitor(data: Sequence[float]) -> Dict[str, float]:
quantiles = np.percentile(
data, REWARD_QUANTILE_MONITOR, interpolation='lower')
monitor_dict = {
f'p_{x}': y for (x, y) in zip(REWARD_QUANTILE_MONITOR, quantiles)
}
monitor_dict['mean'] = np.mean(data)
return monitor_dict


class DataCollector(metaclass=abc.ABCMeta):
"""Abstract class for data collection."""
Expand Down
5 changes: 5 additions & 0 deletions compiler_opt/rl/data_collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@

class DataCollectorTest(absltest.TestCase):

def test_build_distribution_monitor(self):
data = [3, 2, 1]
monitor_dict = data_collector.build_distribution_monitor(data)
self.assertDictContainsSubset({'mean': 2, 'p_0.1': 1}, monitor_dict)

@mock.patch('time.time')
def test_early_exit(self, mock_time):
mock_time.return_value = 0
Expand Down
7 changes: 7 additions & 0 deletions compiler_opt/rl/local_data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def get_num_finished_work():

monitor_dict = {}
monitor_dict['default'] = {'success_modules': len(finished_work)}
rewards = [
1 - (res.get()[3] + data_collector.DELTA) /
(res.get()[1] + data_collector.DELTA) for (_, res) in successful_work
]
monitor_dict[
'reward_distribution'] = data_collector.build_distribution_monitor(
rewards)

return self._parser(sequence_examples), monitor_dict

Expand Down
8 changes: 4 additions & 4 deletions compiler_opt/rl/local_data_collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,14 @@ def _test_iterator_fn(data_list):
data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
data = list(data_iterator)
self.assertEqual([1, 2, 3], data)
expected_monitor_dict = {'default': {'success_modules': 9}}
self.assertEqual(expected_monitor_dict, monitor_dict)
expected_monitor_dict_subset = {'default': {'success_modules': 9}}
self.assertDictContainsSubset(expected_monitor_dict_subset, monitor_dict)

data_iterator, monitor_dict = collector.collect_data(policy_path='policy')
data = list(data_iterator)
self.assertEqual([4, 5, 6], data)
expected_monitor_dict = {'default': {'success_modules': 9}}
self.assertEqual(expected_monitor_dict, monitor_dict)
expected_monitor_dict_subset = {'default': {'success_modules': 9}}
self.assertDictContainsSubset(expected_monitor_dict_subset, monitor_dict)

collector.close_pool()

Expand Down

0 comments on commit 26cbe68

Please sign in to comment.