Skip to content

Commit

Permalink
Refactor _make_iter method (#58)
Browse files Browse the repository at this point in the history
* Add type hints and refactor _make_iter method

* Refactor _make_iter method in test_experiment.py

* linting

* Update unittests.yml to trigger pull request workflows only on the main branch

* update lint
  • Loading branch information
rvhonorato authored Feb 26, 2024
1 parent af9b04d commit c25f260
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 170 deletions.
8 changes: 6 additions & 2 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name: python lint
name: lint

on:
['push', 'pull_request']
'pull_request':
branches:
- main

jobs:
build:
runs-on: ${{ matrix.platform }}
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
name: unittests

on:
['push', 'pull_request']
'pull_request':
branches:
- main

jobs:
build:
Expand Down
41 changes: 22 additions & 19 deletions src/fandas/modules/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import logging
import os
import sys
from typing import List, Tuple

from fandas.modules.chemical import (
ATOM_LIST,
ATOM_REF,
EXPERIMENT_CATALOG,
SPECIAL_CASES,
)
from fandas.modules.chemical_shift import ChemShift

log = logging.getLogger("fandaslog")

Expand Down Expand Up @@ -94,6 +96,7 @@ def execute(self, atoms, direction, exp_id, skip_iteration=False, filtered=False
atom_list = self.translate_atoms(atoms)
atom_iterable = list(itertools.product(*atom_list))

assert shifts
shift_iterable = self._make_iter(shifts, atom_iterable, direction)

results = []
Expand Down Expand Up @@ -127,31 +130,31 @@ def _make_line(self, data_t):

return line

def _make_iter(self, shifts, atom_list, direction):
def _make_iter(
self,
shifts: ChemShift,
atom_list: List[Tuple[str, str]],
direction: List[List[int]],
) -> List[Tuple[Tuple[ChemShift, ChemShift], str]]:
"""Make an iterable to be used by the experiment executor."""
resnum_list = list(shifts.residues.items())
first_resum = resnum_list[0][0]
first_resnum = resnum_list[0][0]
last_resnum = resnum_list[-1][0]

dimension = len(direction)

combinations = []
for resnum in range(first_resum, last_resnum + 1):
for resnum in range(first_resnum, last_resnum + 1):
for e in itertools.product(*direction):
try:
if dimension == 3:
residue_1 = shifts.residues[resnum - e[0]]
residue_2 = shifts.residues[resnum - e[1]]
residue_3 = shifts.residues[resnum - e[2]]
combinations.append((residue_1, residue_2, residue_3))
elif dimension == 2:
residue_1 = shifts.residues[resnum - e[0]]
residue_2 = shifts.residues[resnum - e[1]]
combinations.append((residue_1, residue_2))

except KeyError:
# there was no shift found in a given index, skip it
pass
# Fetch residues based on adjusted indices
residues = [
shifts.residues[resnum + offset]
for offset in e
if (resnum + offset) in shifts.residues
]

# Ensure we have the correct number of residues before appending
if len(residues) == len(e):
combinations.append(tuple(residues))
# print(e, residues)

return list(itertools.product(combinations, atom_list))

Expand Down
192 changes: 44 additions & 148 deletions tests/modules/test_experiment.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""Test the experiment module."""
import pytest
import tempfile

import os
import tempfile
from pathlib import Path
from fandas.modules.experiment import Experiment

import pytest

from fandas.modules.chemical_shift import ChemShift
from fandas.modules.experiment import Experiment
from fandas.modules.input import InputFile

from .. import TEST_INPUT_FILE, TEST_DISTANCE_FILE
from .. import TEST_DISTANCE_FILE, TEST_INPUT_FILE


@pytest.fixture
Expand Down Expand Up @@ -90,6 +93,19 @@ def test__make_line(experiment_class, chemical_shifts):

def test__make_iter(experiment_class, chemical_shifts):
"""Test the _make_iter method."""

observed_iter = experiment_class._make_iter(
shifts=chemical_shifts, atom_list=[("H", "H")], direction=[[0], [-1]]
)
expected_iter = [
((chemical_shifts.residues[2], chemical_shifts.residues[1]), ("H", "H")),
((chemical_shifts.residues[3], chemical_shifts.residues[2]), ("H", "H")),
((chemical_shifts.residues[4], chemical_shifts.residues[3]), ("H", "H")),
((chemical_shifts.residues[5], chemical_shifts.residues[4]), ("H", "H")),
]

assert observed_iter == expected_iter

observed_iter = experiment_class._make_iter(
shifts=chemical_shifts, atom_list=[("H", "H")], direction=[[0], [0]]
)
Expand Down Expand Up @@ -119,160 +135,40 @@ def test__make_iter(experiment_class, chemical_shifts):
((chemical_shifts.residues[5], chemical_shifts.residues[5]), ("H", "N")),
]

assert observed_iter == expected_iter

observed_iter = experiment_class._make_iter(
shifts=chemical_shifts,
atom_list=[("H", "H", "H"), ("H", "N", "H")],
direction=[[0], [0], [+1, 0]],
atom_list=[("H", "H", "H")],
direction=[[0], [0], [-1, 0, 1]],
)

expected_iter = [
(
(
chemical_shifts.residues[1],
chemical_shifts.residues[1],
chemical_shifts.residues[1],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[1],
chemical_shifts.residues[1],
chemical_shifts.residues[1],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[1],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[1],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[2],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[2],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[3],
chemical_shifts.residues[3],
chemical_shifts.residues[2],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[3],
chemical_shifts.residues[3],
chemical_shifts.residues[2],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[3],
chemical_shifts.residues[3],
chemical_shifts.residues[3],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[3],
chemical_shifts.residues[3],
chemical_shifts.residues[3],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[4],
chemical_shifts.residues[4],
chemical_shifts.residues[3],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[4],
chemical_shifts.residues[4],
chemical_shifts.residues[3],
),
("H", "N", "H"),
),
assert observed_iter[2] == (
(
(
chemical_shifts.residues[4],
chemical_shifts.residues[4],
chemical_shifts.residues[4],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[4],
chemical_shifts.residues[4],
chemical_shifts.residues[4],
),
("H", "N", "H"),
),
(
(
chemical_shifts.residues[5],
chemical_shifts.residues[5],
chemical_shifts.residues[4],
),
("H", "H", "H"),
),
(
(
chemical_shifts.residues[5],
chemical_shifts.residues[5],
chemical_shifts.residues[4],
),
("H", "N", "H"),
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[1],
),
("H", "H", "H"),
)

assert observed_iter[3] == (
(
(
chemical_shifts.residues[5],
chemical_shifts.residues[5],
chemical_shifts.residues[5],
),
("H", "H", "H"),
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[2],
),
("H", "H", "H"),
)

assert observed_iter[4] == (
(
(
chemical_shifts.residues[5],
chemical_shifts.residues[5],
chemical_shifts.residues[5],
),
("H", "N", "H"),
chemical_shifts.residues[2],
chemical_shifts.residues[2],
chemical_shifts.residues[3],
),
]

assert observed_iter == expected_iter
("H", "H", "H"),
)


def test_translate_atoms(experiment_class):
Expand Down

0 comments on commit c25f260

Please sign in to comment.