diff --git a/tests/conftest.py b/tests/conftest.py index 22153455..4cc25be7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,4 @@ import os import shutil -TO_REMOVE = [os.path.join('examples','valid_raw_data','metadata','annotations.csv'), os.path.join('output')] - -def pytest_sessionstart(session): - """ - Called after the Session object has been created and - before performing collection and entering the run test loop. - """ - for path in TO_REMOVE: - if os.path.exists(path): - if os.path.isdir(path): - shutil.rmtree(path) - else: - os.remove(path) diff --git a/tests/test_annotations.py b/tests/test_annotations.py index 568cc232..76142121 100644 --- a/tests/test_annotations.py +++ b/tests/test_annotations.py @@ -298,7 +298,7 @@ def dv_func(a, b, x, type): if type == 'number': return 1 elif type == 'columns': - return pd.DataFrame([],columns=['segment_onset','segment_offset']) + return pd.DataFrame([], columns=['segment_onset','segment_offset']) elif type == 'normal': return x @@ -306,18 +306,27 @@ def dv_func(a, b, x, type): # function used for derivation but does not hav correct signature def bad_func(a, b): return b -@pytest.mark.parametrize("input_set,function,output_set,ow,error", - [("missing", partial(dv_func,type='normal'), "output", False, AssertionError), - ("derivation", partial(dv_func,type='number'), "output", False, None), - ("derivation", partial(dv_func,type='columns'), "output", False, None), - ("derivation", bad_func, "output", False, None), - ("derivation", partial(dv_func,type='normal'), "derivation", False, AssertionError), - ("input_reimport.csv", True, "imp_reimport_ow.csv", None, None), - ("input_importoverlaps.csv", False, "imp_overlap.csv", "err_overlap.csv", None), - ("input_import_duration_overflow.csv", False, None, None, AssertionError), +@pytest.mark.parametrize("input_set,function,output_set,exists,ow,error", + [("missing", partial(dv_func, type='normal'), "output", False, False, AssertionError), + ("vtc_present", partial(dv_func, type='number'), "output", False, False, None), + ("vtc_present", partial(dv_func, type='columns'), "output", False, False, None), + ("vtc_present", bad_func, "output", False, False, None), + ("vtc_present", partial(dv_func, type='normal'), "vtc_present", False, False, AssertionError), + ("vtc_present", partial(dv_func, type='normal'), "output", True, False, AssertionError), + ("vtc_present", partial(dv_func, type='normal'), "output", True, True, AssertionError), ]) -def test_derive_inputs(project, am, input_set, function, output_set, ow, error): - pass +def test_derive_inputs(project, am, input_set, function, output_set, exists, ow, error): + am.read() + # copy the input set to act as an existing output_set + if exists: + shutil.copytree(src=PATH / 'annotations' / 'vtc_present', dst=PATH / 'annotations' / output_set) + additions = am.annotations[am.annotations['set'] == input_set].copy() + additions['set'] = output_set + am.annotations = pd.concat([am.annotations, additions]) + + if error: + with pytest.raises(error): + am.rename_set(old, new) def test_intersect(project, am): input_annotations = pd.read_csv("examples/valid_raw_data/annotations/intersect.csv") @@ -589,7 +598,6 @@ def test_rename(project, am, old, new, error, mf, index): tg_count = am.annotations[am.annotations["set"] == "textgrid"].shape[0] if error: - print(am.annotations[am.annotations["set"] == new].shape[0]) with pytest.raises(error): am.rename_set(old, new) else: