diff --git a/src/certfuzz/campaign/campaign_base.py b/src/certfuzz/campaign/campaign_base.py index ab01c07..d823ccd 100644 --- a/src/certfuzz/campaign/campaign_base.py +++ b/src/certfuzz/campaign/campaign_base.py @@ -11,7 +11,6 @@ import shutil import tempfile import traceback -import cPickle as pickle import signal from certfuzz.campaign.errors import CampaignError @@ -25,6 +24,10 @@ import gc from certfuzz.config.simple_loader import load_and_fix_config from certfuzz.helpers.misc import import_module_by_name +from certfuzz.fuzztools.object_caching import dump_obj_to_file,\ + load_obj_from_file +import json +from certfuzz.fuzztools.filetools import write_file logger = logging.getLogger(__name__) @@ -120,7 +123,7 @@ def __init__(self, config_file, result_dir=None, debug=False): self.sf_set_out = os.path.join(self.outdir, 'seedfiles') if not self.cached_state_file: - cachefile = 'campaign_%s.pkl' % _campaign_id_with_underscores + cachefile = 'campaign_%s.json' % _campaign_id_with_underscores self.cached_state_file = os.path.join( self.work_dir_base, cachefile) if not self.seed_interval: @@ -165,7 +168,6 @@ def __enter__(self): if _result is not None: self = _result - self._read_state() self._check_prog() self._setup_workdir() self._set_fuzzer() @@ -173,6 +175,7 @@ def __enter__(self): self._check_runner() self._setup_output() self._create_seedfile_set() + self._read_state() _result = self._post_enter() if _result is not None: @@ -327,50 +330,133 @@ def _create_seedfile_set(self): outputpath=self.sf_set_out) as sfset: self.seedfile_set = sfset - @abc.abstractmethod - def __getstate__(self): - raise NotImplementedError - - @abc.abstractmethod - def __setstate__(self): - raise NotImplementedError + def _read_cached_data(self, cachefile): + try: + with open(cachefile, 'rb') as fp: + cached_data = json.load(fp) + except (IOError, ValueError) as e: + logger.info( + 'No cached campaign data found, will proceed as new campaign: %s', e) + return + return cached_data - def _read_state(self, cache_file=None): - if not cache_file: - cache_file = self.cached_state_file + def _restore_seedfile_scores(self, sf_scores): + for sf_md5, sf_score in sf_scores.iteritems(): + # is this seedfile still around? + try: + arm_to_update = self.seedfile_set.arms[sf_md5] + except KeyError: + # if not, just skip it + logger.warning( + 'Skipping seedfile score recovery for %s: maybe seedfile was removed?', sf_md5) + continue - if not os.path.exists(cache_file): - logger.info('No cached campaign found, using new campaign') - return + cached_successes = sf_score['successes'] + cached_trials = sf_score['trials'] - try: - with open(cache_file, 'rb') as fp: - campaign = pickle.load(fp) - except Exception, e: - logger.warning( - 'Unable to read %s, will use new campaign instead: %s', cache_file, e) - return + arm_to_update.update( + successes=cached_successes, trials=cached_trials) - if campaign: + def _restore_rangefinder_scores(self, rf_scores): + for sf_md5, rangelist in rf_scores.iteritems(): + # is this seedfile still around? try: - if self.config['config_timestamp'] != campaign.__dict__['config_timestamp']: - logger.warning( - 'Config file modified. Discarding cached campaign') - else: - self.__dict__.update(campaign.__dict__) - logger.info('Reloaded campaign from %s', cache_file) + sf_to_update = self.seedfile_set.things[sf_md5] except KeyError: logger.warning( - 'No config date detected. Discarding cached campaign') - else: + 'Skipping rangefinder score recovery for %s: maybe seedfile was removed?', sf_md5) + continue + + # if you got here, you have a seedfile to update + # we're going to need its rangefinder + rangefinder = sf_to_update.rangefinder + + # construct a rangefinder key lookup table + rf_lookup = {} + for key, item in rangefinder.things.iteritems(): + lookup_key = (item.min, item.max) + rf_lookup[lookup_key] = key + + for r in rangelist: + # is this range still correct? + cached_rmin = r['range_key']['range_min'] + cached_rmax = r['range_key']['range_max'] + lkey = (cached_rmin, cached_rmax) + try: + rk = rf_lookup[lkey] + except KeyError: + logger.warning( + 'Skipping rangefinder score recovery for %s range %s: range not found', sf_md5, lkey) + continue + + # if you got here you have a matching range to update + # fyi: .arms and .things have the same keys + arm_to_update = rangefinder.arms[rk] + cached_successes = r['range_score']['successes'] + cached_trials = r['range_score']['trials'] + + arm_to_update.update( + successes=cached_successes, trials=cached_trials) + + def _restore_campaign_from_cache(self, cached_data): + self.current_seed = cached_data['current_seed'] + self._restore_seedfile_scores(cached_data['seedfile_scores']) + self._restore_rangefinder_scores(cached_data['rangefinder_scores']) + logger.info('Restoring cached campaign data done') + + def _read_state(self, cachefile=None): + if not cachefile: + cachefile = self.cached_state_file + + cached_data = self._read_cached_data(cachefile) + if cached_data is None: + return + + # check the timestamp + # if the cache is older than the current config file, we should + # ignore the cached data and just start fresh + cached_cfg_ts = cached_data['config_timestamp'] + if self.config['config_timestamp'] != cached_cfg_ts: logger.warning( - 'Unable to reload campaign from %s, will use new campaign instead', cache_file) + 'Config file modified since campaign data cache was created. Discarding cached campaign data. Will proceed as new campaign.') + return 2 + + # if you got here, the cached file is ok to use + + self._restore_campaign_from_cache(cached_data) + + def _get_state_as_dict(self): + state = {'current_seed': self.current_seed, + 'config_timestamp': self.config['config_timestamp'], + 'seedfile_scores': self.seedfile_set.arms_as_dict(), + 'rangefinder_scores': None + } + + # add rangefinder scores from each seedfile + d = {} + for k, sf in self.seedfile_set.things.iteritems(): + d[k] = [] + + for rk, rf in sf.rangefinder.things.iteritems(): + arm = sf.rangefinder.arms[rk] + rkey = {'range_min': rf.min, 'range_max': rf.max} + rdata = {'range_key': rkey, + 'range_score': dict(arm.__dict__)} + d[k].append(rdata) + + state['rangefinder_scores'] = d + + return state + + def _get_state_as_json(self): + state = self._get_state_as_dict() + return json.dumps(state, indent=4, sort_keys=True) def _save_state(self, cachefile=None): if not cachefile: cachefile = self.cached_state_file - # FIXME - # dump_obj_to_file(cachefile, self) + state_as_json = self._get_state_as_json() + write_file(state_as_json, cachefile) def _testcase_is_unique(self, testcase_id, exploitability='UNKNOWN'): ''' @@ -405,10 +491,9 @@ def _do_interval(self): sf = self.seedfile_set.next_item() logger.info('Selected seedfile: %s', sf.basename) -# TODO: restore this -# if self.current_seed % self.status_interval == 0: -# # cache our current state -# self._save_state() + if (self.current_seed > 0) and (self.current_seed % self.status_interval == 0): + # cache our current state + self._save_state() r = sf.rangefinder.next_item() diff --git a/src/certfuzz/campaign/campaign_linux.py b/src/certfuzz/campaign/campaign_linux.py index fa2b20b..10e5469 100644 --- a/src/certfuzz/campaign/campaign_linux.py +++ b/src/certfuzz/campaign/campaign_linux.py @@ -182,30 +182,6 @@ def _set_debugger(self): ''' pass - def __setstate__(self): - ''' - Overrides parent class - ''' - pass - - def _read_state(self): - ''' - Overrides parent class - ''' - pass - - def __getstate__(self): - ''' - Overrides parent class - ''' - pass - - def _save_state(self): - ''' - Overrides parent class - ''' - pass - def _do_iteration(self, seedfile, range_obj, seednum): # Prevent watchdog from rebooting VM. # If /tmp/fuzzing exists and is stale, the machine will reboot diff --git a/src/certfuzz/campaign/campaign_windows.py b/src/certfuzz/campaign/campaign_windows.py index 2c8d388..74d8f54 100644 --- a/src/certfuzz/campaign/campaign_windows.py +++ b/src/certfuzz/campaign/campaign_windows.py @@ -36,39 +36,6 @@ def __init__(self, config_file, result_dir=None, debug=False): self.debugger_module_name = 'certfuzz.debuggers.gdb' TWDF.disable() - def __getstate__(self): - state = self.__dict__.copy() - - state['testcases_seen'] = list(state['testcases_seen']) - if state['seedfile_set']: - state['seedfile_set'] = state['seedfile_set'].__getstate__() - - # for attributes that are modules, - # we can safely delete them as they will be - # reconstituted when we __enter__ a context - for key in ['fuzzer_module', 'fuzzer_cls', - 'runner_module', 'runner_cls', - 'debugger_module' - ]: - if key in state: - del state[key] - return state - - def __setstate__(self, state): - # turn the list into a set - state['testcases_seen'] = set(state['testcases_seen']) - - # reconstitute the seedfile set - with SeedfileSet(state['campaign_id'], state['seed_dir_in'], state['seed_dir_local'], - state['sf_set_out']) as sfset: - new_sfset = sfset - - new_sfset.__setstate__(state['seedfile_set']) - state['seedfile_set'] = new_sfset - - # update yourself - self.__dict__.update(state) - def _pre_enter(self): # check to see if the platform supports winrun # set runner module to none otherwise diff --git a/src/certfuzz/file_handlers/seedfile.py b/src/certfuzz/file_handlers/seedfile.py index e2a812a..65d87ac 100644 --- a/src/certfuzz/file_handlers/seedfile.py +++ b/src/certfuzz/file_handlers/seedfile.py @@ -36,7 +36,8 @@ def __init__(self, output_base_dir, path): BasicFile.__init__(self, path) if not self.len > 0: - raise SeedFileError('You cannot do bitwise fuzzing on a zero-length file: %s' % self.path) + raise SeedFileError( + 'You cannot do bitwise fuzzing on a zero-length file: %s' % self.path) # use len for bytewise, bitlen for bitwise if self.len > 1: @@ -50,29 +51,6 @@ def __init__(self, output_base_dir, path): self.rangefinder = RangeFinder(self.range_min, self.range_max) - def __getstate__(self): - ''' - Pickle a SeedFile object - @return a dict representation of the pickled object - ''' - state = self.__dict__.copy() - state['rangefinder'] = self.rangefinder.__getstate__() - return state - - def __setstate__(self, state): - old_rf = state.pop('rangefinder') - - # rebuild the rangefinder - new_rf = self._get_rangefinder() - old_ranges = old_rf['things'] - for k, old_range in old_ranges.iteritems(): - if k in new_rf.things: - # things = ranges - new_range = new_rf.things[k] - for attr in ['a', 'b', 'probability', 'seen', 'successes', 'tries']: - setattr(new_range, attr, old_range[attr]) - self.rangefinder = new_rf - def cache_key(self): return 'seedfile-%s' % self.md5 @@ -81,5 +59,6 @@ def pkl_file(self): def to_json(self, sort_keys=True, indent=None): state = self.__dict__.copy() - state['rangefinder'] = state['rangefinder'].to_json(sort_keys=sort_keys, indent=indent) + state['rangefinder'] = state['rangefinder'].to_json( + sort_keys=sort_keys, indent=indent) return json.dumps(state, sort_keys=sort_keys, indent=indent) diff --git a/src/certfuzz/file_handlers/seedfile_set.py b/src/certfuzz/file_handlers/seedfile_set.py index 2006d28..afb9ff4 100644 --- a/src/certfuzz/file_handlers/seedfile_set.py +++ b/src/certfuzz/file_handlers/seedfile_set.py @@ -11,7 +11,8 @@ from certfuzz.file_handlers.seedfile import SeedFile from certfuzz.fuzztools import filetools -# Using a generic name here so we can easily swap out other MAB implementations if we want to +# Using a generic name here so we can easily swap out other MAB +# implementations if we want to from certfuzz.scoring.multiarmed_bandit.bayesian_bandit import BayesianMultiArmedBandit as MultiArmedBandit logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ class SeedfileSet(MultiArmedBandit): ''' classdocs ''' + def __init__(self, campaign_id=None, originpath=None, localpath=None, outputpath='.', logfile=None): ''' @@ -43,7 +45,8 @@ def __init__(self, campaign_id=None, originpath=None, localpath=None, hdlr = logging.FileHandler(logfile) logger.addHandler(hdlr) - logger.debug('SeedfileSet output_dir: %s', self.seedfile_output_base_dir) + logger.debug( + 'SeedfileSet output_dir: %s', self.seedfile_output_base_dir) def __enter__(self): self._setup() @@ -94,7 +97,8 @@ def copy_file_from_origin(self, f): # convert the local filenames from . to . basename = 'sf_' + f.md5 + f.ext - targets = [os.path.join(d, basename) for d in (self.localpath, self.outputpath)] + targets = [os.path.join(d, basename) + for d in (self.localpath, self.outputpath)] filetools.copy_file(f.path, *targets) for target in targets: filetools.make_writable(target) @@ -121,39 +125,6 @@ def next_item(self): return sf else: # it doesn't exist, remove it from the set - logger.warning('Seedfile no longer exists, removing from set: %s', sf.path) + logger.warning( + 'Seedfile no longer exists, removing from set: %s', sf.path) self.del_item(sf.md5) - -# def __setstate__(self, state): -# newstate = state.copy() -# -# # copy out old things and replace with an empty dict -# oldthings = newstate.pop('things') -# newstate['things'] = {} -# -# # refresh the directories -# self.__dict__.update(newstate) -# self._setup() -# -# # clean up things that no longer exist -# self.sfcount = 0 -# self.sfdel = 0 -# for k, old_sf in oldthings.iteritems(): -# # update the seedfiles for ones that are still present -# if k in self.things: -# # print "%s in things..." % k -# self.things[k].__setstate__(old_sf) -# self.sfcount += 1 - -# def __getstate__(self): -# state = ScorableSet3.__getstate__(self) -# -# # remove things we can recreate -# try: -# for k in ('origindir', 'localdir', 'outputdir'): -# del state[k] -# except KeyError: -# # it's ok if they don't exist -# pass -# -# return state diff --git a/src/certfuzz/fuzztools/object_caching.py b/src/certfuzz/fuzztools/object_caching.py index df027a2..63f5a7c 100644 --- a/src/certfuzz/fuzztools/object_caching.py +++ b/src/certfuzz/fuzztools/object_caching.py @@ -16,7 +16,8 @@ def dump_obj_to_file(cachefile, obj): pickle.dump(obj, fd) logger.debug('Wrote %s to %s', obj.__class__.__name__, cachefile) except (IOError, TypeError) as e: - logger.warning('Unable to write %s to cache file %s: %s', obj.__class__.__name__, cachefile, e) + logger.warning( + 'Unable to write %s to cache file %s: %s', obj.__class__.__name__, cachefile, e) def load_obj_from_file(cachefile): @@ -28,12 +29,3 @@ def load_obj_from_file(cachefile): except StandardError, e: logger.debug("Unable to read from %s: %s", cachefile, e) return obj - - -def cache_state(key_prefix, key_suffix, obj, cachefile): - dump_obj_to_file(cachefile, obj) - - -def get_cached_state(key_suffix, key_prefix, cachefile): - obj = load_obj_from_file(cachefile) - return obj diff --git a/src/certfuzz/fuzztools/rangefinder.py b/src/certfuzz/fuzztools/rangefinder.py index 9cda110..030e8dc 100644 --- a/src/certfuzz/fuzztools/rangefinder.py +++ b/src/certfuzz/fuzztools/rangefinder.py @@ -23,6 +23,7 @@ class RangeFinder(MultiArmedBandit): 3. a probability distribution across all ranges as well as a picker method to randomly choose a range based on the probability distribution. ''' + def __init__(self, low, high): MultiArmedBandit.__init__(self) @@ -36,20 +37,6 @@ def __init__(self, low, high): self._set_ranges() -# def __getstate__(self): -# # we can't pickle the logger. -# # But that's okay. We can get it back in __setstate__ -# state = ScorableSet2.__getstate__(self) -# del state['logger'] -# return state -# -# def __setstate__(self, d): -# self.__dict__.update(d) -# for thing in self.things.iteritems(): -# assert type(thing) == Range, 'Type is %s' % type(thing) -# # recover the logger we had to drop in __getstate__ -# self._set_logger() - def _exp_range(self, low, factor): high = low * factor # don't overshoot the high diff --git a/src/certfuzz/scoring/multiarmed_bandit/multiarmed_bandit_base.py b/src/certfuzz/scoring/multiarmed_bandit/multiarmed_bandit_base.py index 08e68bf..31786d8 100644 --- a/src/certfuzz/scoring/multiarmed_bandit/multiarmed_bandit_base.py +++ b/src/certfuzz/scoring/multiarmed_bandit/multiarmed_bandit_base.py @@ -21,6 +21,9 @@ def __init__(self): self.things = {} self.arms = {} + def arms_as_dict(self): + return {k: dict(arm.__dict__) for k, arm in self.arms.iteritems()} + def add_item(self, key=None, obj=None): if key is None: raise MultiArmedBanditError('unspecified key for arm') @@ -38,7 +41,6 @@ def add_item(self, key=None, obj=None): # but don't trust those averages too strongly new_arm.doubt() - # add the new arm to the set self.arms[key] = new_arm @@ -54,7 +56,8 @@ def del_item(self, key=None): pass def record_result(self, key, successes=0, trials=0): - logger.debug('Recording result: key=%s successes=%d trials=%d', key, successes, trials) + logger.debug( + 'Recording result: key=%s successes=%d trials=%d', key, successes, trials) arm = self.arms[key] arm.update(successes, trials) diff --git a/src/test_certfuzz/campaign/test_campaign_base.py b/src/test_certfuzz/campaign/test_campaign_base.py index c54c36e..c12f5ce 100644 --- a/src/test_certfuzz/campaign/test_campaign_base.py +++ b/src/test_certfuzz/campaign/test_campaign_base.py @@ -14,6 +14,8 @@ from certfuzz.campaign.errors import CampaignError from test_certfuzz.mocks import MockCfg import yaml +from certfuzz.file_handlers.seedfile_set import SeedfileSet +import json class UnimplementedCampaign(CampaignBase): @@ -21,11 +23,6 @@ class UnimplementedCampaign(CampaignBase): class ImplementedCampaign(CampaignBase): - def __getstate__(self): - pass - - def __setstate__(self): - pass def _do_interval(self): pass @@ -51,15 +48,18 @@ def _set_fuzzer(self): def _set_runner(self): pass + class Test(unittest.TestCase): + def setUp(self): self.tmpdir = mkdtemp() - fd, cfgfile = tempfile.mkstemp(prefix='config_', suffix=".yaml", dir=self.tmpdir) + fd, cfgfile = tempfile.mkstemp( + prefix='config_', suffix=".yaml", dir=self.tmpdir) os.close(fd) - + cfg = MockCfg(templated=False) - with open(cfgfile,'wb') as f: - yaml.dump(cfg,f) + with open(cfgfile, 'wb') as f: + yaml.dump(cfg, f) self.campaign = ImplementedCampaign(cfgfile) def tearDown(self): @@ -110,7 +110,8 @@ def test_setup_workdir(self): self.campaign._setup_workdir() self.assertTrue(os.path.isdir(self.campaign.work_dir_base)) self.assertTrue(os.path.isdir(self.campaign.working_dir)) - self.assertTrue(self.campaign.seed_dir_local.startswith(self.campaign.working_dir)) + self.assertTrue( + self.campaign.seed_dir_local.startswith(self.campaign.working_dir)) self.assertTrue(self.campaign.seed_dir_local.endswith('seedfiles')) def test_cleanup_workdir(self): @@ -138,6 +139,149 @@ def test_keep_going(self): for _x in range(100): self.assertTrue(self.campaign._keep_going()) + def _check_data_structure(self, x): + for k in ['current_seed', 'config_timestamp', 'seedfile_scores', 'rangefinder_scores']: + self.assertTrue(k in x) + + self.assertEqual(x['current_seed'], self.campaign.current_seed) + self.assertEqual( + x['config_timestamp'], self.campaign.config['config_timestamp']) + + self.assertEqual( + len(x['rangefinder_scores']), len(self.campaign.seedfile_set.arms)) + + self.assertEqual( + len(x['seedfile_scores']), len(self.campaign.seedfile_set.arms)) + + # verify the data structures + for score in x['seedfile_scores'].values(): + for k in ['successes', 'trials', 'probability']: + self.assertTrue(k in score) + + for items in x['rangefinder_scores'].values(): + for item in items: + for k in ['range_key', 'range_score']: + self.assertTrue(k in item) + score = item['range_score'] + for k in ['successes', 'trials', 'probability']: + self.assertTrue(k in score) + + def _populate_sf_set(self): + self.campaign.seedfile_set = SeedfileSet() + + files = [] + for x in xrange(10): + _fd, _fname = tempfile.mkstemp(prefix='seedfile_', dir=self.tmpdir) + os.write(_fd, str(x)) + os.close(_fd) + files.append(_fname) + + self.campaign.seedfile_set.add_file(*files) + + def test_get_state_as_dict(self): + self._populate_sf_set() + x = self.campaign._get_state_as_dict() + self._check_data_structure(x) + + def test_get_state_as_json(self): + self._populate_sf_set() + j = self.campaign._get_state_as_json() + x = json.loads(j) + self._check_data_structure(x) + + def test_save_state(self): + fd, fpath = tempfile.mkstemp( + suffix=".json", prefix="campaign_state_", dir=self.tmpdir) + os.close(fd) + os.remove(fpath) + self.assertFalse(os.path.exists(fpath)) + + self._populate_sf_set() + self.campaign._save_state(fpath) + self.assertTrue(os.path.exists(fpath)) + self.assertTrue(os.path.getsize(fpath) > 0) + + with open(fpath, 'rb') as f: + x = json.load(f) + + self._check_data_structure(x) + + for k, v in self.campaign._get_state_as_dict().iteritems(): + self.assertTrue(k in x) + self.assertEqual(x[k], v) + + def test_read_state(self): + fd, fpath = tempfile.mkstemp( + suffix=".json", prefix="campaign_state_", dir=self.tmpdir) + os.close(fd) + self._populate_sf_set() + d = self.campaign._get_state_as_dict() + + d['current_seed'] = 1000 + for score in d['seedfile_scores'].itervalues(): + score['successes'] = 10 + score['trials'] = 100 + + for sf in d['rangefinder_scores'].values(): + for r in sf: + r['range_score']['successes'] = 5 + r['range_score']['trials'] = 50 + + with open(fpath, 'wb') as f: + json.dump(d, f) + + self.assertNotEqual(self.campaign.current_seed, d['current_seed']) + successes = [x['successes'] + for x in self.campaign.seedfile_set.arms_as_dict().values()] + for _score in successes: + self.assertEqual(0, _score) + trials = [x['trials'] + for x in self.campaign.seedfile_set.arms_as_dict().values()] + for _score in trials: + self.assertEqual(0, _score) + + for sf in self.campaign.seedfile_set.things.values(): + for r in sf.rangefinder.arms.values(): + self.assertEqual(0, r.successes) + self.assertEqual(0, r.trials) + + self.campaign._read_state(fpath) + + self.assertEqual(self.campaign.current_seed, d['current_seed']) + successes = [x['successes'] + for x in self.campaign.seedfile_set.arms_as_dict().values()] + for _score in successes: + self.assertEqual(10, _score) + trials = [x['trials'] + for x in self.campaign.seedfile_set.arms_as_dict().values()] + for _score in trials: + self.assertEqual(100, _score) + + for sf in self.campaign.seedfile_set.things.values(): + for r in sf.rangefinder.arms.values(): + self.assertEqual(5, r.successes) + self.assertEqual(50, r.trials) + + def test_reject_cached_data_if_newer_config(self): + fd, fpath = tempfile.mkstemp( + suffix=".json", prefix="campaign_state_", dir=self.tmpdir) + os.close(fd) + self._populate_sf_set() + d = self.campaign._get_state_as_dict() + d['config_timestamp'] = d['config_timestamp'] - 1000.0 + with open(fpath, 'wb') as f: + json.dump(d, f) + + self.assertEqual(2, self.campaign._read_state(fpath)) + + def test_reject_cached_data_if_no_file(self): + fd, fpath = tempfile.mkstemp( + suffix=".json", prefix="campaign_state_", dir=self.tmpdir) + os.close(fd) + self.assertEqual(None, self.campaign._read_cached_data(fpath)) + os.remove(fpath) + self.assertEqual(None, self.campaign._read_cached_data(fpath)) + if __name__ == "__main__": # import sys;sys.argv = ['', 'Test.testName'] diff --git a/src/test_certfuzz/file_handlers/test_seedfile.py b/src/test_certfuzz/file_handlers/test_seedfile.py index bc740a4..e322a75 100644 --- a/src/test_certfuzz/file_handlers/test_seedfile.py +++ b/src/test_certfuzz/file_handlers/test_seedfile.py @@ -8,7 +8,6 @@ import tempfile import os from certfuzz.file_handlers.seedfile import SeedFile -from certfuzz.fuzztools.rangefinder import RangeFinder class Test(unittest.TestCase): @@ -28,18 +27,6 @@ def tearDown(self): def test_init(self): pass -# def test_getstate(self): -# self.assertEqual(RangeFinder, type(self.sf.rangefinder)) -# state = self.sf.__getstate__() -# self.assertEqual(dict, type(state)) -# self.assertEqual(dict, type(state['rangefinder'])) -# -# def test_setstate(self): -# state = self.sf.__getstate__() -# self.sf.__setstate__(state) -# # make sure we restore rangefinder -# self.assertEqual(RangeFinder, type(self.sf.rangefinder)) - if __name__ == "__main__": # import sys;sys.argv = ['', 'Test.testName'] diff --git a/src/test_certfuzz/file_handlers/test_seedfile_set.py b/src/test_certfuzz/file_handlers/test_seedfile_set.py index 4df125e..2ebce05 100644 --- a/src/test_certfuzz/file_handlers/test_seedfile_set.py +++ b/src/test_certfuzz/file_handlers/test_seedfile_set.py @@ -33,7 +33,8 @@ def setUp(self): self.files.append(f) # create a set - self.sfs = SeedfileSet(campaign_id, self.origindir, self.localdir, self.outputdir) + self.sfs = SeedfileSet( + campaign_id, self.origindir, self.localdir, self.outputdir) def tearDown(self): for f in self.files: @@ -101,105 +102,9 @@ def test_init(self): self.assertEqual(self.outputdir, self.sfs.seedfile_output_base_dir) self.assertEqual(0, len(self.sfs.things)) -# def test_getstate_is_pickle_friendly(self): -# # getstate should return a pickleable object -# import pickle -# state = self.sfs.__getstate__() -# try: -# pickle.dumps(state) -# except Exception, e: -# self.fail('Failed to pickle state: %s' % e) -# -# def test_getstate(self): -# state = self.sfs.__getstate__() -# self.assertEqual(dict, type(state)) -# -# for k in self.sfs.__dict__.iterkeys(): -# # make sure we're deleting what we need to -# if k in ['localdir', 'origindir', 'outputdir']: -# self.assertFalse(k in state) -# else: -# self.assertTrue(k in state, '%s not found' % k) - -# def test_setstate(self): -# self.sfs.__enter__() -# state_before = self.sfs.__getstate__() -# self.sfs.__setstate__(state_before) -# self.assertEqual(self.file_count, self.sfs.sfcount) -# state_after = self.sfs.__getstate__() -# -# for k, v in state_before.iteritems(): -# self.assertTrue(k in state_after) -# if not k == 'things': -# self.assertEqual(v, state_after[k]) -# -# for k, thing in state_before['things'].iteritems(): -# # is there a corresponding thing in sfs? -# self.assertTrue(k in self.sfs.things) -# -# for x in thing.iterkeys(): -# # was it set correctly? -# self.assertEqual(thing[x], self.sfs.things[k].__dict__[x]) -# -# self.assertEqual(self.file_count, self.sfs.sfcount) - -# def test_setstate_with_changed_files(self): -# # refresh the sfs -# self.sfs.__enter__() -# -# # get the original state -# state_before = self.sfs.__getstate__() -# self.assertEqual(len(state_before['things']), self.file_count) -# -# # delete one of the files -# file_to_remove = self.files.pop() -# localfile_md5 = hashlib.md5(open(file_to_remove, 'rb').read()).hexdigest() -# localfilename = "sf_%s" % localfile_md5 -# -# # remove it from origin -# os.remove(file_to_remove) -# self.assertFalse(file_to_remove in self.files) -# self.assertFalse(os.path.exists(file_to_remove)) -## print "removed %s" % file_to_remove -# -## # remove it from localdir -# localfile_to_remove = os.path.join(self.localdir, localfilename) -# os.remove(localfile_to_remove) -# self.assertFalse(os.path.exists(localfile_to_remove)) -# -# # create a new sfs -# new_sfs = SeedfileSet() -# new_sfs.__setstate__(state_before) -# -# self.assertEqual(len(new_sfs.things), (self.file_count - 1)) -# -## print "Newthings: %s" % new_sfs.things.keys() -# for k, thing in state_before['things'].iteritems(): -## print "k: %s" % k -# if k == localfile_md5: -# self.assertFalse(k in new_sfs.things) -# continue -# else: -# # is there a corresponding thing in sfs? -# self.assertTrue(k in new_sfs.things) -# -# for x, y in thing.iteritems(): -# # was it set correctly? -# sfsthing = new_sfs.things[k].__dict__[x] -# if hasattr(sfsthing, '__dict__'): -# # some things are complex objects themselves -# # so we have to compare their __dict__ versions -# self._same_dict(y, sfsthing.__dict__) -# else: -# # others are just simple objects and we can -# # compare them directly -# self.assertEqual(y, sfsthing) -# -# self.assertEqual(self.file_count - 1, new_sfs.sfcount) - def _same_dict(self, d1, d2): for k, v in d1.iteritems(): -# print k + # print k self.assertTrue(k in d2) self.assertEqual(v, d2[k]) diff --git a/src/test_certfuzz/fuzztools/test_range.py b/src/test_certfuzz/fuzztools/test_range.py index 3f57622..db5d33d 100644 --- a/src/test_certfuzz/fuzztools/test_range.py +++ b/src/test_certfuzz/fuzztools/test_range.py @@ -6,7 +6,9 @@ import unittest from certfuzz.fuzztools.range import Range + class Test(unittest.TestCase): + def setUp(self): self.r = Range(0, 1) @@ -22,45 +24,6 @@ def test_init(self): def test_repr(self): self.assertEqual(self.r.__repr__(), '0.000000-1.000000') -# def test_getstate_is_pickle_friendly(self): -# # getstate should return a pickleable object -# import pickle -# state = self.r.__getstate__() -# try: -# pickle.dumps(state) -# except Exception, e: -# self.fail('Failed to pickle state: %s' % e) -# -# def test_getstate_has_all_expected_items(self): -# state = self.r.__getstate__() -# for k, v in self.r.__dict__.iteritems(): -# # make sure we're deleting what we need to -# if k in ['logger']: -# self.assertFalse(k in state) -# else: -# self.assertTrue(k in state, '%s not found' % k) -# self.assertEqual(state[k], v) -# -# def test_getstate(self): -# state = self.r.__getstate__() -# self.assertEqual(dict, type(state)) -# print 'as dict...' -# pprint.pprint(state) -# -# def test_to_json(self): -# as_json = self.r.to_json(indent=4) -# -# print 'as JSON...' -# for l in as_json.splitlines(): -# print l -# -# from_json = json.loads(as_json) -# -# # make sure we can round-trip it -# for k, v in self.r.__getstate__().iteritems(): -# self.assertTrue(k in from_json) -# self.assertEqual(from_json[k], v) - if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] unittest.main() diff --git a/src/test_certfuzz/fuzztools/test_rangefinder.py b/src/test_certfuzz/fuzztools/test_rangefinder.py index 591d800..c44c097 100644 --- a/src/test_certfuzz/fuzztools/test_rangefinder.py +++ b/src/test_certfuzz/fuzztools/test_rangefinder.py @@ -12,6 +12,7 @@ class Test(unittest.TestCase): + def delete_file(self, f): os.remove(f) self.assertFalse(os.path.exists(f)) @@ -74,28 +75,6 @@ def test_range_mean(self): for x in self.r.things.values(): self.assertAlmostEqual(x.mean, ((x.max + x.min) / 2)) -# def test_getstate_is_pickle_friendly(self): -# # getstate should return a pickleable object -# import pickle -# state = self.r.__getstate__() -# try: -# pickle.dumps(state) -# except Exception, e: -# self.fail('Failed to pickle state: %s' % e) -# -# def test_getstate_has_all_expected_items(self): -# state = self.r.__getstate__() -# for k, v in self.r.__dict__.iteritems(): -# # make sure we're deleting what we need to -# if k in ['logger']: -# self.assertFalse(k in state) -# else: -# self.assertTrue(k in state, '%s not found' % k) -# self.assertEqual(type(state[k]), type(v)) -# -# def test_getstate(self): -# state = self.r.__getstate__() -# self.assertEqual(dict, type(state)) if __name__ == "__main__": # import sys;sys.argv = ['', 'Test.testName'] diff --git a/src/test_certfuzz/scoring/multiarmed_bandit/test_multiarmed_bandit_base.py b/src/test_certfuzz/scoring/multiarmed_bandit/test_multiarmed_bandit_base.py index 1ffdccc..f3b3fd2 100644 --- a/src/test_certfuzz/scoring/multiarmed_bandit/test_multiarmed_bandit_base.py +++ b/src/test_certfuzz/scoring/multiarmed_bandit/test_multiarmed_bandit_base.py @@ -21,8 +21,10 @@ def tearDown(self): def test_add(self): self.assertRaises(MultiArmedBanditError, self.mab.add_item) - self.assertRaises(MultiArmedBanditError, self.mab.add_item, key=None, obj='obj') - self.assertRaises(MultiArmedBanditError, self.mab.add_item, key='key', obj=None) + self.assertRaises( + MultiArmedBanditError, self.mab.add_item, key=None, obj='obj') + self.assertRaises( + MultiArmedBanditError, self.mab.add_item, key='key', obj=None) self.assertEqual(len(self.keys), len(self.mab.things)) self.assertEqual(len(self.keys), len(self.mab.arms)) @@ -92,6 +94,17 @@ def test_next(self): # empty set raises StopIteration self.assertRaises(StopIteration, self.mab.next) + def test_arms_as_dict(self): + d = self.mab.arms_as_dict() + + self.assertTrue(isinstance(d, dict)) + + for k, arm in self.mab.arms.iteritems(): + self.assertTrue(isinstance(d[k], dict)) + for attrname in ['successes', 'probability', 'trials']: + self.assertTrue(attrname in d[k]) + self.assertEqual(d[k][attrname], getattr(arm, attrname)) + if __name__ == "__main__": # import sys;sys.argv = ['', 'Test.testName'] unittest.main()