Skip to content

Commit

Permalink
Merge pull request #15 from gregstarr/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
gregstarr authored Jun 14, 2022
2 parents 5ec89cf + 51cef5a commit 795a4a7
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 77 deletions.
11 changes: 10 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

def pytest_addoption(parser):
    """Register custom command-line flags for the test session.

    Both flags are plain booleans that default to off; ``--run-slow``
    additionally carries help text shown by ``pytest --help``.
    """
    # (flag, extra keyword arguments) pairs registered below
    option_specs = [
        ("--skip-ftp", {}),
        ("--run-slow", {"help": "Run slow tests"}),
    ]
    for flag, extra in option_specs:
        parser.addoption(flag, action="store_true", default=False, **extra)


def pytest_collection_modifyitems(config, items):
    """Mark every test carrying the 'slow' keyword as skipped unless
    the session was started with ``--run-slow``.
    """
    if config.getoption("--run-slow"):
        # slow tests were explicitly requested — leave collection untouched
        return
    skip_slow = pytest.mark.skip(reason="Only run when --run-slow is given")
    for collected in items:
        if "slow" in collected.keywords:
            collected.add_marker(skip_slow)


@pytest.fixture(scope='session', autouse=True)
Expand All @@ -22,7 +31,7 @@ def skip_ftp(request):
def setup_cfg(skip_ftp):
logger.info("setting up config")
with TemporaryDirectory() as tempdir:
config.set_base_dir(tempdir)
config.base_dir = tempdir
config.madrigal_user_affil = 'bu'
config.madrigal_user_email = '[email protected]'
config.madrigal_user_name = 'gregstarr'
Expand Down
8 changes: 4 additions & 4 deletions test/test_arb.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_download_arb(test_dates, download_dir):
['dt', 'mlt_vals'],
itertools.product(
[np.timedelta64(30, 'm'), np.timedelta64(1, 'h'), np.timedelta64(2, 'h')],
[config.get_mlt_vals(), np.arange(10)]
[config.mlt_vals, np.arange(10)]
)
)
def test_process_arb(download_dir, processed_dir, test_dates, dt, mlt_vals):
Expand All @@ -94,13 +94,13 @@ def test_process_arb_out_of_range(download_dir, processed_dir, test_dates):
start, end = [date - timedelta(days=100) for date in test_dates]
processed_file = Path(processed_dir) / 'arb_test.nc'
with pytest.raises(InvalidProcessDates):
process_interval(start, end, 'north', processed_file, download_dir, config.get_mlt_vals(), dt)
process_interval(start, end, 'north', processed_file, download_dir, config.mlt_vals, dt)


def test_get_arb_data(download_dir, processed_dir, test_dates):
start, end = test_dates
dt = np.timedelta64(1, 'h')
mlt = config.get_mlt_vals()
mlt = config.mlt_vals
correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt)
processed_file = get_arb_paths(start, end, 'north', processed_dir)[0]
process_interval(start, end, 'north', processed_file, download_dir, mlt, dt)
Expand All @@ -124,7 +124,7 @@ def test_scripts(test_dates):
data = get_arb_data(start, end, 'north', cfg.processed_arb_dir)
data.load()
dt = np.timedelta64(1, 'h')
mlt = config.get_mlt_vals()
mlt = config.mlt_vals
correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt)
assert data.shape == (correct_times.shape[0], mlt.shape[0])
assert (data.mlt == mlt).all().item()
Expand Down
14 changes: 7 additions & 7 deletions test/test_tec.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,8 @@ def test_download_tec(test_dates, download_dir):
['dt', 'mlt_bins', 'mlat_bins'],
itertools.product(
[np.timedelta64(10, 'm'), np.timedelta64(30, 'm'), np.timedelta64(1, 'h'), np.timedelta64(2, 'h')],
[config.get_mlt_bins(), np.arange(10)],
[config.get_mlat_bins(), np.arange(10)]
[config.mlt_bins, np.arange(10)],
[config.mlat_bins, np.arange(10)]
)
)
def test_process_tec(download_dir, process_dir, test_dates, dt, mlt_bins, mlat_bins):
Expand Down Expand Up @@ -153,14 +153,14 @@ def test_process_tec_out_of_range(download_dir, process_dir, test_dates):
start, end = [date - timedelta(days=100) for date in test_dates]
processed_file = Path(process_dir) / 'tec_test.nc'
with pytest.raises(InvalidProcessDates):
process_interval(start, end, 'north', processed_file, download_dir, dt, config.get_mlat_bins(), config.get_mlt_bins())
process_interval(start, end, 'north', processed_file, download_dir, dt, config.mlat_bins, config.mlt_bins)


def test_get_tec_data(download_dir, process_dir, test_dates):
start, end = test_dates
dt = np.timedelta64(1, 'h')
mlt_bins = config.get_mlt_bins()
mlat_bins = config.get_mlat_bins()
mlt_bins = config.mlt_bins
mlat_bins = config.mlat_bins
mlt_vals = (mlt_bins[:-1] + mlt_bins[1:]) / 2
mlat_vals = (mlat_bins[:-1] + mlat_bins[1:]) / 2
correct_times = np.arange(np.datetime64(start), np.datetime64(end) + dt, dt)
Expand Down Expand Up @@ -191,8 +191,8 @@ def test_scripts(test_dates):
data = get_tec_data(start, end, hemisphere, cfg.processed_tec_dir)
data.load()
dt = np.timedelta64(1, 'h')
mlt_vals = config.get_mlt_vals()
mlat_vals = config.get_mlat_vals()
mlt_vals = config.mlt_vals
mlat_vals = config.mlat_vals
correct_times = np.arange(np.datetime64(test_dates[0]), np.datetime64(test_dates[-1]) + dt, dt)
h = 1 if hemisphere == 'north' else -1
assert data.shape == (correct_times.shape[0], mlat_vals.shape[0], mlt_vals.shape[0])
Expand Down
55 changes: 49 additions & 6 deletions test/test_trough.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_preprocess_interval():


def test_model_artificial_example():
mlt_grid, mlat_grid = np.meshgrid(config.get_mlt_vals(), config.get_mlat_vals())
mlt_grid, mlat_grid = np.meshgrid(config.mlt_vals, config.mlat_vals)
times = np.datetime64("2000") + np.arange(4) * np.timedelta64(1, 'h')
# nominal labels
labels1 = abs(mlat_grid - 65) <= 2
Expand All @@ -46,7 +46,7 @@ def test_model_artificial_example():
shp = (4,) + mlat_grid.shape
det_log_tec = det_log_tec.reshape(shp)
det_log_tec += np.random.randn(*shp) * .1
coords = {'time': times, 'mlat': config.get_mlat_vals(), 'mlt': config.get_mlt_vals()}
coords = {'time': times, 'mlat': config.mlat_vals, 'mlt': config.mlt_vals}
data = xr.Dataset({
'x': xr.DataArray(
det_log_tec,
Expand All @@ -55,7 +55,7 @@ def test_model_artificial_example():
),
'model': xr.DataArray(
np.ones((times.shape[0], mlt_grid.shape[1])) * 65,
coords={'time': times, 'mlt': config.get_mlt_vals()},
coords={'time': times, 'mlt': config.mlt_vals},
dims=['time', 'mlt']
),
})
Expand All @@ -72,7 +72,7 @@ def test_postprocess():
"""verify that small troughs are rejected, verify that troughs that wrap around the border are not incorrectly
rejected
"""
mlt_grid, mlat_grid = np.meshgrid(config.get_mlt_vals(), config.get_mlat_vals())
mlt_grid, mlat_grid = np.meshgrid(config.mlt_vals, config.mlat_vals)
good_labels = (abs(mlat_grid - 65) < 3) * (abs(mlt_grid) < 2)
small_reject = (abs(mlat_grid - 52) <= 1) * (abs(mlt_grid - 4) <= .5)
boundary_good_labels = (abs(mlat_grid - 65) < 3) * (abs(mlt_grid) >= 10.2)
Expand All @@ -83,7 +83,7 @@ def test_postprocess():
high_labels = (abs(mlat_grid - 80) < 2) * (abs(mlt_grid + 6) <= 3)
arb = np.ones((1, 180)) * 70
initial_labels = good_labels + small_reject + boundary_good_labels + boundary_bad_labels + weird_good_labels + high_labels
coords = {'time': [0], 'mlat': config.get_mlat_vals(), 'mlt': config.get_mlt_vals()}
coords = {'time': [0], 'mlat': config.mlat_vals, 'mlt': config.mlt_vals}
data = xr.Dataset({
'labels': xr.DataArray(
initial_labels[None],
Expand All @@ -92,7 +92,7 @@ def test_postprocess():
),
'arb': xr.DataArray(
arb,
coords={'time': [0], 'mlt': config.get_mlt_vals()},
coords={'time': [0], 'mlt': config.mlt_vals},
dims=['time', 'mlt']
),
})
Expand Down Expand Up @@ -217,3 +217,46 @@ def test_script(dates):
data.load()
assert data.time.shape[0] == n_times
data.close()


def test_script_multiple_config():
    """End-to-end run with two configs sharing one base directory.

    Runs the full pipeline twice over overlapping intervals — once at the
    default 1 h time resolution and once at 30 min — then re-enters each
    config and verifies that it sees its own processed label files and that
    `get_data` returns the expected number of time steps for that resolution.
    """
    start_date = datetime(2015, 3, 17, 5, 0, 0)
    end_date = datetime(2015, 3, 17, 15, 0, 0)
    offset = timedelta(hours=5)
    with TemporaryDirectory() as tempdir:
        with config.temp_config(base_dir=tempdir):
            scripts.full_run(start_date, end_date)
        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
            scripts.full_run(start_date + offset, end_date + offset)
        with config.temp_config(base_dir=tempdir):
            # count lazily instead of materializing a throwaway list
            n_files = sum(1 for _ in Path(config.processed_labels_dir).glob('labels*.nc'))
            n_times = 1 + ((end_date - start_date) / timedelta(hours=1))
            # one file per hemisphere per year covered by the interval
            assert n_files == (end_date.year - start_date.year + 1) * 2
            data = get_data(start_date, end_date, 'north')
            data.load()
            assert data.time.shape[0] == n_times
            data.close()
        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
            n_files = sum(1 for _ in Path(config.processed_labels_dir).glob('labels*.nc'))
            n_times = 1 + ((end_date - start_date) / timedelta(minutes=30))
            assert n_files == (end_date.year - start_date.year + 1) * 2
            data = get_data(start_date + offset, end_date + offset, 'north')
            data.load()
            assert data.time.shape[0] == n_times
            data.close()


@pytest.mark.slow
def test_date_error():
    """Full pipeline run at 30 min resolution across a whole day.

    Regression test for date-interval handling; skipped unless the session
    is started with --run-slow (see pytest_collection_modifyitems).
    """
    start_date = datetime(2015, 3, 17)
    end_date = datetime(2015, 3, 18)
    n_times = 1 + ((end_date - start_date) / timedelta(minutes=30))
    with TemporaryDirectory() as tempdir:
        with config.temp_config(base_dir=tempdir, time_res_n=30, time_res_unit='m'):
            scripts.full_run(start_date, end_date)
            # count lazily instead of materializing a throwaway list
            n_files = sum(1 for _ in Path(config.processed_labels_dir).glob('labels*.nc'))
            # one file per hemisphere per year covered by the interval
            assert n_files == (end_date.year - start_date.year + 1) * 2
            data = get_data(start_date, end_date, 'north')
            data.load()
            assert data.time.shape[0] == n_times
            data.close()
4 changes: 2 additions & 2 deletions trough/_arb.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def process_auroral_boundary_dataset(start_date, end_date, download_dir=None, pr
if process_dir is None:
process_dir = config.processed_arb_dir
if mlt_vals is None:
mlt_vals = config.get_mlt_vals()
mlt_vals = config.mlt_vals
if dt is None:
dt = np.timedelta64(1, 'h')
dt = config.sample_dt
Path(process_dir).mkdir(exist_ok=True, parents=True)

for year in range(start_date.year, end_date.year + 1):
Expand Down
103 changes: 51 additions & 52 deletions trough/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,6 @@ def dict(self):
return dataclasses.asdict(self)


def _get_default_directory_structure(base_dir):
base = Path(base_dir)
download_base = base / 'download'
processed_base = base / 'processed'
download_tec_dir = download_base / 'tec'
download_arb_dir = download_base / 'arb'
download_omni_dir = download_base / 'omni'
processed_tec_dir = processed_base / 'tec'
processed_arb_dir = processed_base / 'arb'
processed_omni_file = processed_base / 'omni.nc'
processed_labels_dir = processed_base / 'labels'
return {
'download_tec_dir': str(download_tec_dir),
'download_arb_dir': str(download_arb_dir),
'download_omni_dir': str(download_omni_dir),
'processed_tec_dir': str(processed_tec_dir),
'processed_arb_dir': str(processed_arb_dir),
'processed_omni_file': str(processed_omni_file),
'processed_labels_dir': str(processed_labels_dir),
}


trough_dirs = appdirs.AppDirs(appname='trough')


Expand All @@ -66,14 +44,7 @@ def parse_date(date_str):
class Config:

def __init__(self, config_path=None):
default_dirs = _get_default_directory_structure(trough_dirs.user_data_dir)
self.download_tec_dir = default_dirs['download_tec_dir']
self.download_arb_dir = default_dirs['download_arb_dir']
self.download_omni_dir = default_dirs['download_omni_dir']
self.processed_tec_dir = default_dirs['processed_tec_dir']
self.processed_arb_dir = default_dirs['processed_arb_dir']
self.processed_omni_file = default_dirs['processed_omni_file']
self.processed_labels_dir = default_dirs['processed_labels_dir']
self.base_dir = trough_dirs.user_data_dir
self.trough_id_params = TroughIdParams()
self.madrigal_user_name = None
self.madrigal_user_email = None
Expand All @@ -95,40 +66,74 @@ def __init__(self, config_path=None):
if config_path is not None:
self.load_json(config_path)

def get_config_name(self):
@property
def config_name(self):
    """File name used when saving this config: <script>_<start>_<end>_config.json."""
    params = self.dict()
    return f"{params['script_name']}_{params['start_date']}_{params['end_date']}_config.json"

def get_mlat_bins(self):
@property
def mlat_bins(self):
    """Magnetic latitude bin edges: half a step below mlat_min up to 90."""
    half_step = self.lat_res / 2
    return np.arange(self.mlat_min - half_step, 90, self.lat_res)

def get_mlat_vals(self):
@property
def mlat_vals(self):
    """Magnetic latitude bin centers from mlat_min up to 90."""
    lower = self.mlat_min
    return np.arange(lower, 90, self.lat_res)

def get_mlt_bins(self):
@property
def mlt_bins(self):
    """MLT bin edges over [-12, 12] at the configured longitude resolution (converted to hours)."""
    step = self.lon_res * 24 / 360
    return np.arange(-12, 12 + 24 / 360, step)

def get_mlt_vals(self):
@property
def mlt_vals(self):
    """MLT bin centers: the edges shifted by half a step."""
    step = self.lon_res * 24 / 360
    return np.arange(-12 + .5 * step, 12 + 24 / 360, step)

def get_sample_dt(self):
@property
def sample_dt(self):
    """Configured time resolution as a numpy timedelta64 (e.g. 1 h, 30 m)."""
    resolution = np.timedelta64(self.time_res_n, self.time_res_unit)
    return resolution

@property
def download_base(self):
    """Root directory for all downloaded (raw) data, under base_dir."""
    return str(Path(self.base_dir, 'download'))

@property
def processed_base(self):
    """Root for processed data; the name encodes the grid/time resolution so
    multiple configurations can coexist under one base_dir."""
    resolution_tag = f'processed_{self.lat_res}_{self.lon_res}_{self.time_res_n}{self.time_res_unit}'
    return str(Path(self.base_dir) / resolution_tag)

@property
def download_tec_dir(self):
    """Directory for downloaded TEC data."""
    return str(Path(self.download_base, 'tec'))

@property
def download_arb_dir(self):
    """Directory for downloaded auroral boundary (arb) data."""
    return str(Path(self.download_base, 'arb'))

@property
def download_omni_dir(self):
    """Directory for downloaded OMNI data."""
    return str(Path(self.download_base, 'omni'))

@property
def processed_tec_dir(self):
    """Directory for processed TEC data (resolution-specific)."""
    return str(Path(self.processed_base, 'tec'))

@property
def processed_arb_dir(self):
    """Directory for processed auroral boundary data (resolution-specific)."""
    return str(Path(self.processed_base, 'arb'))

@property
def processed_omni_file(self):
    """Path of the single processed OMNI netCDF file (resolution-specific)."""
    return str(Path(self.processed_base, 'omni.nc'))

@property
def processed_labels_dir(self):
    """Directory for processed trough label files (resolution-specific)."""
    return str(Path(self.processed_base, 'labels'))

def load_json(self, config_path):
    """Load configuration parameters from a JSON file and apply them.

    Directory paths are no longer stored in the file: they are derived
    from ``base_dir`` by the path properties, so the removed call to the
    deleted ``_get_default_directory_structure`` helper (which would now
    raise NameError) is no longer needed — ``load_dict`` picks up
    ``base_dir`` directly.
    """
    with open(config_path) as f:
        params = json.load(f)
    self.load_dict(params)

def load_dict(self, config_dict):
self.download_tec_dir = config_dict.get('download_tec_dir', self.download_tec_dir)
self.download_arb_dir = config_dict.get('download_arb_dir', self.download_arb_dir)
self.download_omni_dir = config_dict.get('download_omni_dir', self.download_omni_dir)
self.processed_tec_dir = config_dict.get('processed_tec_dir', self.processed_tec_dir)
self.processed_arb_dir = config_dict.get('processed_arb_dir', self.processed_arb_dir)
self.processed_omni_file = config_dict.get('processed_omni_file', self.processed_omni_file)
self.processed_labels_dir = config_dict.get('processed_labels_dir', self.processed_labels_dir)
self.base_dir = config_dict.get('base_dir', self.base_dir)
self.trough_id_params = config_dict.get('trough_id_params', self.trough_id_params)
if not isinstance(self.trough_id_params, TroughIdParams):
self.trough_id_params = TroughIdParams(**self.trough_id_params)
Expand All @@ -151,7 +156,7 @@ def load_dict(self, config_dict):

def save(self, config_path=None):
if config_path is None:
config_path = Path(trough_dirs.user_config_dir) / self.get_config_name()
config_path = Path(trough_dirs.user_config_dir) / self.config_name
save_dict = self.dict().copy()
Path(config_path).parent.mkdir(exist_ok=True, parents=True)
with open(config_path, 'w') as f:
Expand All @@ -160,10 +165,6 @@ def save(self, config_path=None):
cfg_pointer.write_text(str(config_path))
print(f"Saved config and setting default: {config_path}")

def set_base_dir(self, base_dir):
data_dirs = _get_default_directory_structure(base_dir)
self.load_dict(data_dirs)

def dict(self):
param_dict = self.__dict__.copy()
param_dict['trough_id_params'] = self.trough_id_params.dict()
Expand All @@ -178,8 +179,6 @@ def temp_config(self, **kwargs):
original_params = self.dict().copy()
new_params = original_params.copy()
new_params.update(**kwargs)
if 'base_dir' in kwargs:
new_params.update(**_get_default_directory_structure(kwargs['base_dir']))
try:
self.load_dict(new_params)
yield self
Expand Down
6 changes: 3 additions & 3 deletions trough/_tec.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,11 @@ def process_tec_dataset(start_date, end_date, download_dir=None, process_dir=Non
if process_dir is None:
process_dir = config.processed_tec_dir
if mlt_bins is None:
mlt_bins = config.get_mlt_bins()
mlt_bins = config.mlt_bins
if mlat_bins is None:
mlat_bins = config.get_mlat_bins()
mlat_bins = config.mlat_bins
if dt is None:
dt = np.timedelta64(1, 'h')
dt = config.sample_dt
Path(process_dir).mkdir(exist_ok=True, parents=True)

logger.info(f"processing tec dataset over interval {start_date=} {end_date=}")
Expand Down
Loading

0 comments on commit 795a4a7

Please sign in to comment.