diff --git a/test/conftest.py b/test/conftest.py
index 9ac7bbf..587df88 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -11,6 +11,15 @@ def pytest_addoption(parser):
     parser.addoption("--skip-ftp", action='store_true', default=False)
+    parser.addoption("--run-slow", action="store_true", default=False, help="Run slow tests")
+
+
+def pytest_collection_modifyitems(config, items):
+    if not config.getoption("--run-slow"):
+        skipper = pytest.mark.skip(reason="Only run when --run-slow is given")
+        for item in items:
+            if "slow" in item.keywords:
+                item.add_marker(skipper)
 
 
 @pytest.fixture(scope='session', autouse=True)
@@ -22,7 +31,7 @@ def skip_ftp(request):
 def setup_cfg(skip_ftp):
     logger.info("setting up config")
     with TemporaryDirectory() as tempdir:
-        config.set_base_dir(tempdir)
+        config.base_dir = tempdir
         config.madrigal_user_affil = 'bu'
         config.madrigal_user_email = 'gstarr@bu.edu'
         config.madrigal_user_name = 'gregstarr'
diff --git a/test/test_arb.py b/test/test_arb.py
index beb7f8b..dfc0f24 100644
--- a/test/test_arb.py
+++ b/test/test_arb.py
@@ -70,7 +70,7 @@ def test_download_arb(test_dates, download_dir):
     ['dt', 'mlt_vals'],
     itertools.product(
         [np.timedelta64(30, 'm'), np.timedelta64(1, 'h'), np.timedelta64(2, 'h')],
-        [config.get_mlt_vals(), np.arange(10)]
+        [config.mlt_vals, np.arange(10)]
     )
 )
 def test_process_arb(download_dir, processed_dir, test_dates, dt, mlt_vals):
@@ -94,13 +94,13 @@ def test_process_arb_out_of_range(download_dir, processed_dir, test_dates):
     start, end = [date - timedelta(days=100) for date in test_dates]
     processed_file = Path(processed_dir) / 'arb_test.nc'
     with pytest.raises(InvalidProcessDates):
-        process_interval(start, end, 'north', processed_file, download_dir, config.get_mlt_vals(), dt)
+        process_interval(start, end, 'north', processed_file, download_dir, config.mlt_vals, dt)
 
 
 def test_get_arb_data(download_dir, processed_dir, test_dates):
     start, end = test_dates
     dt = np.timedelta64(1, 'h')
-    mlt = config.get_mlt_vals()
+    mlt = config.mlt_vals
     correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt)
     processed_file = get_arb_paths(start, end, 'north', processed_dir)[0]
     process_interval(start, end, 'north', processed_file, download_dir, mlt, dt)
@@ -124,7 +124,7 @@ def test_scripts(test_dates):
         data = get_arb_data(start, end, 'north', cfg.processed_arb_dir)
         data.load()
         dt = np.timedelta64(1, 'h')
-        mlt = config.get_mlt_vals()
+        mlt = config.mlt_vals
         correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt)
         assert data.shape == (correct_times.shape[0], mlt.shape[0])
         assert (data.mlt == mlt).all().item()
diff --git a/test/test_tec.py b/test/test_tec.py
index af89053..3d65f6a 100644
--- a/test/test_tec.py
+++ b/test/test_tec.py
@@ -122,8 +122,8 @@ def test_download_tec(test_dates, download_dir):
     ['dt', 'mlt_bins', 'mlat_bins'],
     itertools.product(
         [np.timedelta64(10, 'm'), np.timedelta64(30, 'm'), np.timedelta64(1, 'h'), np.timedelta64(2, 'h')],
-        [config.get_mlt_bins(), np.arange(10)],
-        [config.get_mlat_bins(), np.arange(10)]
+        [config.mlt_bins, np.arange(10)],
+        [config.mlat_bins, np.arange(10)]
     )
 )
 def test_process_tec(download_dir, process_dir, test_dates, dt, mlt_bins, mlat_bins):
@@ -153,14 +153,14 @@ def test_process_tec_out_of_range(download_dir, process_dir, test_dates):
     start, end = [date - timedelta(days=100) for date in test_dates]
     processed_file = Path(process_dir) / 'tec_test.nc'
     with pytest.raises(InvalidProcessDates):
-        process_interval(start, end, 'north', processed_file, download_dir, dt, config.get_mlat_bins(), config.get_mlt_bins())
+        process_interval(start, end, 'north', processed_file, download_dir, dt, config.mlat_bins, config.mlt_bins)
 
 
 def test_get_tec_data(download_dir, process_dir, test_dates):
     start, end = test_dates
     dt = np.timedelta64(1, 'h')
-    mlt_bins = config.get_mlt_bins()
-    mlat_bins = config.get_mlat_bins()
+    mlt_bins = config.mlt_bins
+    mlat_bins = config.mlat_bins
     mlt_vals = (mlt_bins[:-1] + mlt_bins[1:]) / 2
     mlat_vals = (mlat_bins[:-1] + mlat_bins[1:]) / 2
     correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt)
@@ -191,8 +191,8 @@ def test_scripts(test_dates):
         data = get_tec_data(start, end, hemisphere, cfg.processed_tec_dir)
         data.load()
         dt = np.timedelta64(1, 'h')
-        mlt_vals = config.get_mlt_vals()
-        mlat_vals = config.get_mlat_vals()
+        mlt_vals = config.mlt_vals
+        mlat_vals = config.mlat_vals
         correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt)
         h = 1 if hemisphere == 'north' else -1
         assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0])
diff --git a/test/test_trough.py b/test/test_trough.py
index 2c4e429..6899eeb 100644
--- a/test/test_trough.py
+++ b/test/test_trough.py
@@ -30,7 +30,7 @@ def test_preprocess_interval():
 
 
 def test_model_artificial_example():
-    mlt_grid, mlat_grid = np.meshgrid(config.get_mlt_vals(), config.get_mlat_vals())
+    mlt_grid, mlat_grid = np.meshgrid(config.mlt_vals, config.mlat_vals)
     times = np.datetime64("2000") + np.arange(4) * np.timedelta64(1, 'h')
     # nominal labels
     labels1 = abs(mlat_grid - 65) <= 2
@@ -46,7 +46,7 @@ def test_model_artificial_example():
     shp = (4,) + mlat_grid.shape
     det_log_tec = det_log_tec.reshape(shp)
     det_log_tec += np.random.randn(*shp) * .1
-    coords = {'time': times, 'mlat': config.get_mlat_vals(), 'mlt': config.get_mlt_vals()}
+    coords = {'time': times, 'mlat': config.mlat_vals, 'mlt': config.mlt_vals}
     data = xr.Dataset({
         'x': xr.DataArray(
             det_log_tec,
@@ -55,7 +55,7 @@ def test_model_artificial_example():
         ),
         'model': xr.DataArray(
             np.ones((times.shape[0], mlt_grid.shape[1])) * 65,
-            coords={'time': times, 'mlt': config.get_mlt_vals()},
+            coords={'time': times, 'mlt': config.mlt_vals},
             dims=['time', 'mlt']
         ),
     })
@@ -72,7 +72,7 @@ def test_postprocess():
     """verify that small troughs are rejected, verify that troughs that wrap around the border are not incorrectly
     rejected
     """
-    mlt_grid, mlat_grid = np.meshgrid(config.get_mlt_vals(), config.get_mlat_vals())
+    mlt_grid, mlat_grid = np.meshgrid(config.mlt_vals, config.mlat_vals)
     good_labels = (abs(mlat_grid - 65) < 3) * (abs(mlt_grid) < 2)
     small_reject = (abs(mlat_grid - 52) <= 1) * (abs(mlt_grid - 4) <= .5)
     boundary_good_labels = (abs(mlat_grid - 65) < 3) * (abs(mlt_grid) >= 10.2)
@@ -83,7 +83,7 @@ def test_postprocess():
     high_labels = (abs(mlat_grid - 80) < 2) * (abs(mlt_grid + 6) <= 3)
     arb = np.ones((1, 180)) * 70
     initial_labels = good_labels + small_reject + boundary_good_labels + boundary_bad_labels + weird_good_labels + high_labels
-    coords = {'time': [0], 'mlat': config.get_mlat_vals(), 'mlt': config.get_mlt_vals()}
+    coords = {'time': [0], 'mlat': config.mlat_vals, 'mlt': config.mlt_vals}
    data = xr.Dataset({
         'labels': xr.DataArray(
             initial_labels[None],
@@ -92,7 +92,7 @@ def test_postprocess():
         ),
         'arb': xr.DataArray(
             arb,
-            coords={'time': [0], 'mlt': config.get_mlt_vals()},
+            coords={'time': [0], 'mlt': config.mlt_vals},
             dims=['time', 'mlt']
         ),
     })
@@ -217,3 +217,46 @@ def test_script(dates):
         data.load()
         assert data.time.shape[0] == n_times
         data.close()
+
+
+def test_script_multiple_config():
+    start_date = datetime(2015, 3, 17, 5, 0, 0)
+    end_date = datetime(2015, 3, 17, 15, 0, 0)
+    offset = timedelta(hours=5)
+    with TemporaryDirectory() as tempdir:
+        with config.temp_config(base_dir=tempdir):
+            scripts.full_run(start_date, end_date)
+        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
+            scripts.full_run(start_date + offset, end_date + offset)
+        with config.temp_config(base_dir=tempdir):
+            n_files = len([p for p in Path(config.processed_labels_dir).glob('labels*.nc')])
+            n_times = 1 + ((end_date - start_date) / timedelta(hours=1))
+            assert n_files == (end_date.year - start_date.year + 1) * 2
+            data = get_data(start_date, end_date, 'north')
+            data.load()
+            assert data.time.shape[0] == n_times
+            data.close()
+        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
+            n_files = len([p for p in Path(config.processed_labels_dir).glob('labels*.nc')])
+            n_times = 1 + ((end_date - start_date) / timedelta(minutes=30))
+            assert n_files == (end_date.year - start_date.year + 1) * 2
+            data = get_data(start_date + offset, end_date + offset, 'north')
+            data.load()
+            assert data.time.shape[0] == n_times
+            data.close()
+
+
+@pytest.mark.slow
+def test_date_error():
+    start_date = datetime(2015, 3, 17)
+    end_date = datetime(2015, 3, 18)
+    n_times = 1 + ((end_date - start_date) / timedelta(minutes=30))
+    with TemporaryDirectory() as tempdir:
+        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
+            scripts.full_run(start_date, end_date)
+            n_files = len([p for p in Path(config.processed_labels_dir).glob('labels*.nc')])
+            assert n_files == (end_date.year - start_date.year + 1) * 2
+            data = get_data(start_date, end_date, 'north')
+            data.load()
+            assert data.time.shape[0] == n_times
+            data.close()
diff --git a/trough/_arb.py b/trough/_arb.py
index a0454d7..eed62b7 100644
--- a/trough/_arb.py
+++ b/trough/_arb.py
@@ -113,9 +113,9 @@ def process_auroral_boundary_dataset(start_date, end_date, download_dir=None, pr
     if process_dir is None:
         process_dir = config.processed_arb_dir
     if mlt_vals is None:
-        mlt_vals = config.get_mlt_vals()
+        mlt_vals = config.mlt_vals
     if dt is None:
-        dt = np.timedelta64(1, 'h')
+        dt = config.sample_dt
     Path(process_dir).mkdir(exist_ok=True, parents=True)
 
     for year in range(start_date.year, end_date.year + 1):
diff --git a/trough/_config.py b/trough/_config.py
index 57d76be..ebb18fc 100644
--- a/trough/_config.py
+++ b/trough/_config.py
@@ -26,28 +26,6 @@ def dict(self):
         return dataclasses.asdict(self)
 
 
-def _get_default_directory_structure(base_dir):
-    base = Path(base_dir)
-    download_base = base / 'download'
-    processed_base = base / 'processed'
-    download_tec_dir = download_base / 'tec'
-    download_arb_dir = download_base / 'arb'
-    download_omni_dir = download_base / 'omni'
-    processed_tec_dir = processed_base / 'tec'
-    processed_arb_dir = processed_base / 'arb'
-    processed_omni_file = processed_base / 'omni.nc'
-    processed_labels_dir = processed_base / 'labels'
-    return {
-        'download_tec_dir': str(download_tec_dir),
-        'download_arb_dir': str(download_arb_dir),
-        'download_omni_dir': str(download_omni_dir),
-        'processed_tec_dir': str(processed_tec_dir),
-        'processed_arb_dir': str(processed_arb_dir),
-        'processed_omni_file': str(processed_omni_file),
-        'processed_labels_dir': str(processed_labels_dir),
-    }
-
-
 trough_dirs = appdirs.AppDirs(appname='trough')
@@ -66,14 +44,7 @@ def parse_date(date_str):
 
 class Config:
     def __init__(self, config_path=None):
-        default_dirs = _get_default_directory_structure(trough_dirs.user_data_dir)
-        self.download_tec_dir = default_dirs['download_tec_dir']
-        self.download_arb_dir = default_dirs['download_arb_dir']
-        self.download_omni_dir = default_dirs['download_omni_dir']
-        self.processed_tec_dir = default_dirs['processed_tec_dir']
-        self.processed_arb_dir = default_dirs['processed_arb_dir']
-        self.processed_omni_file = default_dirs['processed_omni_file']
-        self.processed_labels_dir = default_dirs['processed_labels_dir']
+        self.base_dir = trough_dirs.user_data_dir
         self.trough_id_params = TroughIdParams()
         self.madrigal_user_name = None
         self.madrigal_user_email = None
@@ -95,40 +66,74 @@ def __init__(self, config_path=None):
         if config_path is not None:
             self.load_json(config_path)
 
-    def get_config_name(self):
+    @property
+    def config_name(self):
         cfg = self.dict()
         return f"{cfg['script_name']}_{cfg['start_date']}_{cfg['end_date']}_config.json"
 
-    def get_mlat_bins(self):
+    @property
+    def mlat_bins(self):
         return np.arange(self.mlat_min - self.lat_res / 2, 90, self.lat_res)
 
-    def get_mlat_vals(self):
+    @property
+    def mlat_vals(self):
         return np.arange(self.mlat_min, 90, self.lat_res)
 
-    def get_mlt_bins(self):
+    @property
+    def mlt_bins(self):
         return np.arange(-12, 12 + 24 / 360, self.lon_res * 24 / 360)
 
-    def get_mlt_vals(self):
+    @property
+    def mlt_vals(self):
         return np.arange(-12 + .5 * self.lon_res * 24 / 360, 12 + 24 / 360, self.lon_res * 24 / 360)
 
-    def get_sample_dt(self):
+    @property
+    def sample_dt(self):
         return np.timedelta64(self.time_res_n, self.time_res_unit)
 
+    @property
+    def download_base(self):
+        return str(Path(self.base_dir) / 'download')
+
+    @property
+    def processed_base(self):
+        return str(Path(self.base_dir) / f'processed_{self.lat_res}_{self.lon_res}_{self.time_res_n}{self.time_res_unit}')
+
+    @property
+    def download_tec_dir(self):
+        return str(Path(self.download_base) / 'tec')
+
+    @property
+    def download_arb_dir(self):
+        return str(Path(self.download_base) / 'arb')
+
+    @property
+    def download_omni_dir(self):
+        return str(Path(self.download_base) / 'omni')
+
+    @property
+    def processed_tec_dir(self):
+        return str(Path(self.processed_base) / 'tec')
+
+    @property
+    def processed_arb_dir(self):
+        return str(Path(self.processed_base) / 'arb')
+
+    @property
+    def processed_omni_file(self):
+        return str(Path(self.processed_base) / 'omni.nc')
+
+    @property
+    def processed_labels_dir(self):
+        return str(Path(self.processed_base) / 'labels')
+
     def load_json(self, config_path):
         with open(config_path) as f:
             params = json.load(f)
-        if 'base_dir' in params:
-            params.update(**_get_default_directory_structure(params['base_dir']))
         self.load_dict(params)
 
     def load_dict(self, config_dict):
-        self.download_tec_dir = config_dict.get('download_tec_dir', self.download_tec_dir)
-        self.download_arb_dir = config_dict.get('download_arb_dir', self.download_arb_dir)
-        self.download_omni_dir = config_dict.get('download_omni_dir', self.download_omni_dir)
-        self.processed_tec_dir = config_dict.get('processed_tec_dir', self.processed_tec_dir)
-        self.processed_arb_dir = config_dict.get('processed_arb_dir', self.processed_arb_dir)
-        self.processed_omni_file = config_dict.get('processed_omni_file', self.processed_omni_file)
-        self.processed_labels_dir = config_dict.get('processed_labels_dir', self.processed_labels_dir)
+        self.base_dir = config_dict.get('base_dir', self.base_dir)
         self.trough_id_params = config_dict.get('trough_id_params', self.trough_id_params)
         if not isinstance(self.trough_id_params, TroughIdParams):
             self.trough_id_params = TroughIdParams(**self.trough_id_params)
@@ -151,7 +156,7 @@ def load_dict(self, config_dict):
 
     def save(self, config_path=None):
         if config_path is None:
-            config_path = Path(trough_dirs.user_config_dir) / self.get_config_name()
+            config_path = Path(trough_dirs.user_config_dir) / self.config_name
         save_dict = self.dict().copy()
         Path(config_path).parent.mkdir(exist_ok=True, parents=True)
         with open(config_path, 'w') as f:
@@ -160,10 +165,6 @@ def save(self, config_path=None):
         cfg_pointer.write_text(str(config_path))
         print(f"Saved config and setting default: {config_path}")
 
-    def set_base_dir(self, base_dir):
-        data_dirs = _get_default_directory_structure(base_dir)
-        self.load_dict(data_dirs)
-
     def dict(self):
         param_dict = self.__dict__.copy()
         param_dict['trough_id_params'] = self.trough_id_params.dict()
@@ -178,8 +179,6 @@ def temp_config(self, **kwargs):
         original_params = self.dict().copy()
         new_params = original_params.copy()
         new_params.update(**kwargs)
-        if 'base_dir' in kwargs:
-            new_params.update(**_get_default_directory_structure(kwargs['base_dir']))
         try:
             self.load_dict(new_params)
             yield self
diff --git a/trough/_tec.py b/trough/_tec.py
index 64c94b7..11b5aa8 100644
--- a/trough/_tec.py
+++ b/trough/_tec.py
@@ -173,11 +173,11 @@ def process_tec_dataset(start_date, end_date, download_dir=None, process_dir=Non
     if process_dir is None:
         process_dir = config.processed_tec_dir
     if mlt_bins is None:
-        mlt_bins = config.get_mlt_bins()
+        mlt_bins = config.mlt_bins
     if mlat_bins is None:
-        mlat_bins = config.get_mlat_bins()
+        mlat_bins = config.mlat_bins
     if dt is None:
-        dt = np.timedelta64(1, 'h')
+        dt = config.sample_dt
     Path(process_dir).mkdir(exist_ok=True, parents=True)
 
     logger.info(f"processing tec dataset over interval {start_date=} {end_date=}")
diff --git a/trough/_trough.py b/trough/_trough.py
index 71c0245..9241aab 100644
--- a/trough/_trough.py
+++ b/trough/_trough.py
@@ -76,7 +76,12 @@ def _get_weighted_kp(times, omni_data, tau=.6, T=10):
     prehistory = np.column_stack([ap[T - i:ap.shape[0] - i] for i in range(T)])
     weight_factors = tau ** np.arange(T)
     ap_tau = np.sum((1 - tau) * prehistory * weight_factors, axis=1)
-    return 2.1 * np.log(.2 * ap_tau + 1)
+    values = 2.1 * np.log(.2 * ap_tau + 1)
+    if (times.values[1] - times.values[0]).astype('timedelta64[m]').astype(int) == 60:
+        return values
+    ut_initial = np.arange(times.values[0], times.values[-1] + np.timedelta64(1, 'h'), np.timedelta64(1, 'h')).astype('datetime64[s]').astype(int)
+    ut_final = times.values.astype('datetime64[s]').astype(int)
+    return np.interp(ut_final, ut_initial, values)
 
 
 def estimate_background(tec, patch_shape):
@@ -322,7 +327,10 @@ def get_data(start_date, end_date, hemisphere, tec_dir=None, omni_file=None, lab
         omni_file = config.processed_omni_file
     if labels_dir is None:
         labels_dir = config.processed_labels_dir
+    tec = _tec.get_tec_data(start_date, end_date, hemisphere, tec_dir)
     data = xr.open_dataset(omni_file).sel(time=slice(start_date, end_date))
-    data['tec'] = _tec.get_tec_data(start_date, end_date, hemisphere, tec_dir)
+    if (tec.time.values[1] - tec.time.values[0]).astype('timedelta64[m]').astype(int) != 60:
+        data = data.interp(time=tec.time)
+    data['tec'] = tec
     data['labels'] = get_trough_labels(start_date, end_date, hemisphere, labels_dir)
     return data