Skip to content

Commit

Permalink
beginning of adding tests to derive-annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
LoannPeurey committed Mar 12, 2024
1 parent 69329dc commit ce5385f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
13 changes: 0 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,4 @@
import os
import shutil

TO_REMOVE = [os.path.join('examples','valid_raw_data','metadata','annotations.csv'), os.path.join('output')]

def pytest_sessionstart(session):
"""
Called after the Session object has been created and
before performing collection and entering the run test loop.
"""
for path in TO_REMOVE:
if os.path.exists(path):
if os.path.isdir(path):
shutil.rmtree(path)
else:
os.remove(path)

34 changes: 21 additions & 13 deletions tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,26 +298,35 @@ def dv_func(a, b, x, type):
if type == 'number':
return 1
elif type == 'columns':
return pd.DataFrame([],columns=['segment_onset','segment_offset'])
return pd.DataFrame([], columns=['segment_onset','segment_offset'])
elif type == 'normal':
return x


# function used for derivation but does not hav correct signature
def bad_func(a, b):
return b
@pytest.mark.parametrize("input_set,function,output_set,ow,error",
[("missing", partial(dv_func,type='normal'), "output", False, AssertionError),
("derivation", partial(dv_func,type='number'), "output", False, None),
("derivation", partial(dv_func,type='columns'), "output", False, None),
("derivation", bad_func, "output", False, None),
("derivation", partial(dv_func,type='normal'), "derivation", False, AssertionError),
("input_reimport.csv", True, "imp_reimport_ow.csv", None, None),
("input_importoverlaps.csv", False, "imp_overlap.csv", "err_overlap.csv", None),
("input_import_duration_overflow.csv", False, None, None, AssertionError),
@pytest.mark.parametrize("input_set,function,output_set,exists,ow,error",
[("missing", partial(dv_func, type='normal'), "output", False, False, AssertionError),
("vtc_present", partial(dv_func, type='number'), "output", False, False, None),
("vtc_present", partial(dv_func, type='columns'), "output", False, False, None),
("vtc_present", bad_func, "output", False, False, None),
("vtc_present", partial(dv_func, type='normal'), "vtc_present", False, False, AssertionError),
("vtc_present", partial(dv_func, type='normal'), "output", True, False, AssertionError),
("vtc_present", partial(dv_func, type='normal'), "output", True, True, AssertionError),
])
def test_derive_inputs(project, am, input_set, function, output_set, ow, error):
pass
def test_derive_inputs(project, am, input_set, function, output_set, exists, ow, error):
am.read()
# copy the input set to act as an existing output_set
if exists:
shutil.copytree(src=PATH / 'annotations' / 'vtc_present', dst=PATH / 'annotations' / output_set)
additions = am.annotations[am.annotations['set'] == input_set].copy()
additions['set'] = output_set
am.annotations = pd.concat([am.annotations, additions])

if error:
with pytest.raises(error):
am.rename_set(old, new)

def test_intersect(project, am):
input_annotations = pd.read_csv("examples/valid_raw_data/annotations/intersect.csv")
Expand Down Expand Up @@ -589,7 +598,6 @@ def test_rename(project, am, old, new, error, mf, index):
tg_count = am.annotations[am.annotations["set"] == "textgrid"].shape[0]

if error:
print(am.annotations[am.annotations["set"] == new].shape[0])
with pytest.raises(error):
am.rename_set(old, new)
else:
Expand Down

0 comments on commit ce5385f

Please sign in to comment.