diff --git a/src/in_silico_fate_mapping/_reader.py b/src/in_silico_fate_mapping/_reader.py index cfc564a..b657d93 100644 --- a/src/in_silico_fate_mapping/_reader.py +++ b/src/in_silico_fate_mapping/_reader.py @@ -3,7 +3,13 @@ import numpy as np import pandas as pd -TRACKS_HEADER = ("TrackID", "t", "z", "y", "x") +TRACKS_HEADER = ( + ("track_id", "TrackID"), + ("t", "T"), + ("z", "Z"), + ("y", "Z"), + ("x", "X"), +) def napari_get_reader(path): @@ -17,8 +23,8 @@ def napari_get_reader(path): return None header = pd.read_csv(path, nrows=0).columns.tolist() - for colname in TRACKS_HEADER: - if colname != "z" and colname not in header: + for colnames in TRACKS_HEADER: + if all(c not in header for c in colnames) and colnames[0] != "z": return None return reader_function @@ -28,12 +34,15 @@ def read_csv(path: str): df = pd.read_csv(path) data = [] - for colname in TRACKS_HEADER: - try: - data.append(df[colname]) - except KeyError: - if colname != "z": - raise KeyError(f"{colname} not found in .csv header.") + for colnames in TRACKS_HEADER: + found = False + for c in colnames: + if c in df.columns: + data.append(df[c]) + found = True + break + if not found and colnames[0] != "z": + raise KeyError(f"{colnames[0]} not found in .csv header.") data = np.stack(data).T diff --git a/src/in_silico_fate_mapping/_tests/test_io.py b/src/in_silico_fate_mapping/_tests/test_io.py index 7918937..9898a30 100644 --- a/src/in_silico_fate_mapping/_tests/test_io.py +++ b/src/in_silico_fate_mapping/_tests/test_io.py @@ -25,10 +25,14 @@ def tracks(n_nodes: int = 10) -> pd.DataFrame: return pd.DataFrame(tracks_data, columns=["TrackID", "t", "z", "y", "x"]) -def test_get_reader(tmp_path: Path, tracks: pd.DataFrame) -> None: +@pytest.mark.parametrize("track_id_col", ["TrackID", "track_id"]) +def test_get_reader( + tmp_path: Path, tracks: pd.DataFrame, track_id_col: str +) -> None: path = tmp_path / "good_tracks.csv" tracks["NodeID"] = np.arange(len(tracks)) + 1 tracks["Labels"] = np.random.randint(2, size=len(tracks)) + tracks.rename(columns={"TrackID": track_id_col}, inplace=True) tracks.to_csv(path, index=False) reader = napari_get_reader(path) @@ -41,7 +45,7 @@ def test_get_reader(tmp_path: Path, tracks: pd.DataFrame) -> None: assert np.allclose(props["NodeID"], tracks["NodeID"]) assert np.allclose(props["Labels"], tracks["Labels"]) - assert np.allclose(data, tracks[["TrackID", "t", "z", "y", "x"]]) + assert np.allclose(data, tracks[[track_id_col, "t", "z", "y", "x"]]) def test_napari_read( diff --git a/src/in_silico_fate_mapping/_writer.py b/src/in_silico_fate_mapping/_writer.py index b17234e..bc97fa7 100644 --- a/src/in_silico_fate_mapping/_writer.py +++ b/src/in_silico_fate_mapping/_writer.py @@ -8,7 +8,8 @@ def napari_write_tracks(path: str, data: np.ndarray, meta: dict) -> List[str]: - header = list(TRACKS_HEADER) + # first position are the default + header = list(c[0] for c in TRACKS_HEADER) if data.shape[1] == 4: header.remove("z")