From 85dc76a9b414cc17f7fabe64f34b400b18d1c504 Mon Sep 17 00:00:00 2001
From: szhan <shing.zhan@gmail.com>
Date: Mon, 1 Jul 2024 17:17:25 +0100
Subject: [PATCH] Add argument to the pass functions to define emission
 probabilities

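The low-level haploid forward, backward, Viterbi, and path log-likelihood
routines now take the emission probability function as an explicit
`emission_func` argument instead of calling
`core.get_emission_probability_haploid` directly; the API layer passes the
core function through for the haploid code paths.

A minimal, illustrative sketch of a caller-supplied emission function. It
assumes the keyword signature visible at the call sites (`ref_allele`,
`query_allele`, `site`, plus the emission matrix); the indexing convention
below is a toy example, not the library's canonical logic:

    import numpy as np

    def toy_emission_func(ref_allele, query_allele, site, emission_matrix):
        # Index the per-site matrix by mismatch (0) vs match (1).
        return emission_matrix[site, np.int64(ref_allele == query_allele)]

    e = np.array([[0.01, 0.99], [0.02, 0.98]])  # toy mismatch/match probs
    p = toy_emission_func(ref_allele=1, query_allele=1, site=0, emission_matrix=e)
    print(p)  # 0.99

    # Such a function can then be supplied to the pass functions, e.g.
    # fbh.forwards_ls_hap(n=n, m=m, H=H, s=s, e=e, r=r,
    #                     emission_func=toy_emission_func, norm=True)
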
---
 lshmm/api.py                        | 106 +++++++++++-------
 lshmm/fb_haploid.py                 |  32 ++++--
 lshmm/vit_haploid.py                | 137 ++++++++++++++++++-----
 tests/test_api_fb_haploid.py        |   5 +-
 tests/test_api_fb_haploid_multi.py  |   3 +
 tests/test_api_vit_haploid.py       |   2 +
 tests/test_api_vit_haploid_multi.py |  23 ++--
 tests/test_nontree_fb_haploid.py    |  46 +++++++-
 tests/test_nontree_vit_haploid.py   | 166 ++++++++++++++++++++++++----
 9 files changed, 406 insertions(+), 114 deletions(-)

diff --git a/lshmm/api.py b/lshmm/api.py
index d6d662d..890a985 100644
--- a/lshmm/api.py
+++ b/lshmm/api.py
@@ -223,23 +223,34 @@ def forwards(
     )
 
     if ploidy == 1:
-        forward_function = forwards_ls_hap
+        (
+            forward_array,
+            normalisation_factor_from_forward,
+            log_lik,
+        ) = forwards_ls_hap(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            query_checked,
+            emission_matrix,
+            prob_recombination,
+            norm=normalise,
+            emission_func=core.get_emission_probability_haploid,
+        )
     else:
-        forward_function = forward_ls_dip_loop
-
-    (
-        forward_array,
-        normalisation_factor_from_forward,
-        log_lik,
-    ) = forward_function(
-        num_ref_haps,
-        num_sites,
-        ref_panel_checked,
-        query_checked,
-        emission_matrix,
-        prob_recombination,
-        norm=normalise,
-    )
+        (
+            forward_array,
+            normalisation_factor_from_forward,
+            log_lik,
+        ) = forward_ls_dip_loop(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            query_checked,
+            emission_matrix,
+            prob_recombination,
+            norm=normalise,
+        )
 
     return forward_array, normalisation_factor_from_forward, log_lik
 
@@ -267,19 +278,26 @@ def backwards(
     )
 
     if ploidy == 1:
-        backward_function = backwards_ls_hap
+        backwards_array = backwards_ls_hap(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            query_checked,
+            emission_matrix,
+            normalisation_factor_from_forward,
+            prob_recombination,
+            emission_func=core.get_emission_probability_haploid,
+        )
     else:
-        backward_function = backward_ls_dip_loop
-
-    backwards_array = backward_function(
-        num_ref_haps,
-        num_sites,
-        ref_panel_checked,
-        query_checked,
-        emission_matrix,
-        normalisation_factor_from_forward,
-        prob_recombination,
-    )
+        backwards_array = backward_ls_dip_loop(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            query_checked,
+            emission_matrix,
+            normalisation_factor_from_forward,
+            prob_recombination,
+        )
 
     return backwards_array
 
@@ -313,6 +331,7 @@ def viterbi(
             query_checked,
             emission_matrix,
             prob_recombination,
+            emission_func=core.get_emission_probability_haploid,
         )
         best_path = backwards_viterbi_hap(num_sites, V, P)
     else:
@@ -353,18 +372,25 @@ def path_loglik(
     )
 
     if ploidy == 1:
-        path_ll_function = path_ll_hap
+        log_lik = path_ll_hap(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            path,
+            query_checked,
+            emission_matrix,
+            prob_recombination,
+            emission_func=core.get_emission_probability_haploid,
+        )
     else:
-        path_ll_function = path_ll_dip
-
-    log_lik = path_ll_function(
-        num_ref_haps,
-        num_sites,
-        ref_panel_checked,
-        path,
-        query_checked,
-        emission_matrix,
-        prob_recombination,
-    )
+        log_lik = path_ll_dip(
+            num_ref_haps,
+            num_sites,
+            ref_panel_checked,
+            path,
+            query_checked,
+            emission_matrix,
+            prob_recombination,
+        )
 
     return log_lik
diff --git a/lshmm/fb_haploid.py b/lshmm/fb_haploid.py
index 7ae8460..fa70f3b 100644
--- a/lshmm/fb_haploid.py
+++ b/lshmm/fb_haploid.py
@@ -7,7 +7,16 @@
 
 
 @jit.numba_njit
-def forwards_ls_hap(n, m, H, s, e, r, norm=True):
+def forwards_ls_hap(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+    norm=True,
+):
     """
     A matrix-based implementation using Numpy.
 
@@ -20,7 +29,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
     if norm:
         c = np.zeros(m)
         for i in range(n):
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[0, i],
                 query_allele=s[0, 0],
                 site=0,
@@ -36,7 +45,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
         for l in range(1, m):
             for i in range(n):
                 F[l, i] = F[l - 1, i] * (1 - r[l]) + r_n[l]
-                emission_prob = core.get_emission_probability_haploid(
+                emission_prob = emission_func(
                     ref_allele=H[l, i],
                     query_allele=s[0, l],
                     site=l,
@@ -53,7 +62,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
     else:
         c = np.ones(m)
         for i in range(n):
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[0, i],
                 query_allele=s[0, 0],
                 site=0,
@@ -65,7 +74,7 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
         for l in range(1, m):
             for i in range(n):
                 F[l, i] = F[l - 1, i] * (1 - r[l]) + np.sum(F[l - 1, :]) * r_n[l]
-                emission_prob = core.get_emission_probability_haploid(
+                emission_prob = emission_func(
                     ref_allele=H[l, i],
                     query_allele=s[0, l],
                     site=l,
@@ -79,7 +88,16 @@ def forwards_ls_hap(n, m, H, s, e, r, norm=True):
 
 
 @jit.numba_njit
-def backwards_ls_hap(n, m, H, s, e, c, r):
+def backwards_ls_hap(
+    n,
+    m,
+    H,
+    s,
+    e,
+    c,
+    r,
+    emission_func,
+):
     """
     A matrix-based implementation using Numpy.
 
@@ -96,7 +114,7 @@ def backwards_ls_hap(n, m, H, s, e, c, r):
         tmp_B = np.zeros(n)
         tmp_B_sum = 0
         for i in range(n):
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[l + 1, i],
                 query_allele=s[0, l + 1],
                 site=l + 1,
diff --git a/lshmm/vit_haploid.py b/lshmm/vit_haploid.py
index 87dbb97..cdcf833 100644
--- a/lshmm/vit_haploid.py
+++ b/lshmm/vit_haploid.py
@@ -7,7 +7,15 @@
 
 
 @jit.numba_njit
-def viterbi_naive_init(n, m, H, s, e, r):
+def viterbi_naive_init(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """Initialise a naive implementation."""
     V = np.zeros((m, n))
     P = np.zeros((m, n), dtype=np.int64)
@@ -15,7 +23,7 @@ def viterbi_naive_init(n, m, H, s, e, r):
     r_n = r / num_copiable_entries
 
     for i in range(n):
-        emission_prob = core.get_emission_probability_haploid(
+        emission_prob = emission_func(
             ref_allele=H[0, i],
             query_allele=s[0, 0],
             site=0,
@@ -27,7 +35,15 @@ def viterbi_naive_init(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def viterbi_init(n, m, H, s, e, r):
+def viterbi_init(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """Initialise a naive, but more memory efficient, implementation."""
     V_prev = np.zeros(n)
     V = np.zeros(n)
@@ -36,7 +52,7 @@ def viterbi_init(n, m, H, s, e, r):
     r_n = r / num_copiable_entries
 
     for i in range(n):
-        emission_prob = core.get_emission_probability_haploid(
+        emission_prob = emission_func(
             ref_allele=H[0, i],
             query_allele=s[0, 0],
             site=0,
@@ -48,15 +64,23 @@ def viterbi_init(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_naive(n, m, H, s, e, r):
+def forwards_viterbi_hap_naive(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """A naive implementation of the forward pass."""
-    V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)
+    V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func)
 
     for j in range(1, m):
         for i in range(n):
             v = np.zeros(n)
             for k in range(n):
-                emission_prob = core.get_emission_probability_haploid(
+                emission_prob = emission_func(
                     ref_allele=H[j, i],
                     query_allele=s[0, j],
                     site=j,
@@ -76,16 +100,24 @@ def forwards_viterbi_hap_naive(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
+def forwards_viterbi_hap_naive_vec(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """A naive matrix-based implementation of the forward pass."""
-    V, P, r_n = viterbi_naive_init(n, m, H, s, e, r)
+    V, P, r_n = viterbi_naive_init(n, m, H, s, e, r, emission_func=emission_func)
 
     for j in range(1, m):
         v_tmp = V[j - 1, :] * r_n[j]
         for i in range(n):
             v = np.copy(v_tmp)
             v[i] += V[j - 1, i] * (1 - r[j])
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[j, i],
                 query_allele=s[0, j],
                 site=j,
@@ -101,15 +133,23 @@ def forwards_viterbi_hap_naive_vec(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
+def forwards_viterbi_hap_naive_low_mem(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """A naive implementation of the forward pass with reduced memory."""
-    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
+    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)
 
     for j in range(1, m):
         for i in range(n):
             v = np.zeros(n)
             for k in range(n):
-                emission_prob = core.get_emission_probability_haploid(
+                emission_prob = emission_func(
                     ref_allele=H[j, i],
                     query_allele=s[0, j],
                     site=j,
@@ -130,9 +170,17 @@ def forwards_viterbi_hap_naive_low_mem(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
+def forwards_viterbi_hap_naive_low_mem_rescaling(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """A naive implementation of the forward pass with reduced memory and rescaling."""
-    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
+    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)
     c = np.ones(m)
 
     for j in range(1, m):
@@ -141,7 +189,7 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
         for i in range(n):
             v = np.zeros(n)
             for k in range(n):
-                emission_prob = core.get_emission_probability_haploid(
+                emission_prob = emission_func(
                     ref_allele=H[j, i],
                     query_allele=s[0, j],
                     site=j,
@@ -162,9 +210,17 @@ def forwards_viterbi_hap_naive_low_mem_rescaling(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
+def forwards_viterbi_hap_low_mem_rescaling(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """An implementation with reduced memory that exploits the Markov structure."""
-    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r)
+    V, V_prev, P, r_n = viterbi_init(n, m, H, s, e, r, emission_func=emission_func)
     c = np.ones(m)
 
     for j in range(1, m):
@@ -178,7 +234,7 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
             if V[i] < r_n[j]:
                 V[i] = r_n[j]
                 P[j, i] = argmax
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[j, i],
                 query_allele=s[0, j],
                 site=j,
@@ -193,7 +249,15 @@ def forwards_viterbi_hap_low_mem_rescaling(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
+def forwards_viterbi_hap_lower_mem_rescaling(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """
     An implementation with even smaller memory footprint
     that exploits the Markov structure.
@@ -202,7 +266,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
     """
     V = np.zeros(n)
     for i in range(n):
-        emission_prob = core.get_emission_probability_haploid(
+        emission_prob = emission_func(
             ref_allele=H[0, i],
             query_allele=s[0, 0],
             site=0,
@@ -224,7 +288,7 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
             if V[i] < r_n[j]:
                 V[i] = r_n[j]
                 P[j, i] = argmax
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[j, i],
                 query_allele=s[0, j],
                 site=j,
@@ -238,14 +302,22 @@ def forwards_viterbi_hap_lower_mem_rescaling(n, m, H, s, e, r):
 
 
 @jit.numba_njit
-def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
+def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
+    n,
+    m,
+    H,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """
     An implementation with even smaller memory footprint and rescaling
     that exploits the Markov structure.
     """
     V = np.zeros(n)
     for i in range(n):
-        emission_prob = core.get_emission_probability_haploid(
+        emission_prob = emission_func(
             ref_allele=H[0, i],
             query_allele=s[0, 0],
             site=0,
@@ -273,7 +345,7 @@ def forwards_viterbi_hap_lower_mem_rescaling_no_pointer(n, m, H, s, e, r):
                 recombs[j] = np.append(
                     recombs[j], i
                 )  # We add template i as a potential template to recombine to at site j.
-            emission_prob = core.get_emission_probability_haploid(
+            emission_prob = emission_func(
                 ref_allele=H[j, i],
                 query_allele=s[0, j],
                 site=j,
@@ -320,13 +392,22 @@ def backwards_viterbi_hap_no_pointer(m, V_argmaxes, recombs):
 
 
 @jit.numba_njit
-def path_ll_hap(n, m, H, path, s, e, r):
+def path_ll_hap(
+    n,
+    m,
+    H,
+    path,
+    s,
+    e,
+    r,
+    emission_func,
+):
     """
     Evaluate the log-likelihood of a path through a reference panel resulting in a query.
 
     This is exposed via the API.
     """
-    emission_prob = core.get_emission_probability_haploid(
+    emission_prob = emission_func(
         ref_allele=H[0, path[0]],
         query_allele=s[0, 0],
         site=0,
@@ -338,7 +419,7 @@ def path_ll_hap(n, m, H, path, s, e, r):
     r_n = r / num_copiable_entries
 
     for l in range(1, m):
-        emission_prob = core.get_emission_probability_haploid(
+        emission_prob = emission_func(
             ref_allele=H[l, path[l]],
             query_allele=s[0, l],
             site=l,
diff --git a/tests/test_api_fb_haploid.py b/tests/test_api_fb_haploid.py
index 7238283..1cf9656 100644
--- a/tests/test_api_fb_haploid.py
+++ b/tests/test_api_fb_haploid.py
@@ -16,7 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
-            num_alleles = core.get_num_alleles(H_vs, s)
+            emission_func = core.get_emission_probability_haploid
             F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
                 n=n,
                 m=m,
@@ -24,6 +24,8 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 s=s,
                 e=e_vs,
                 r=r,
+                emission_func=emission_func,
+                norm=True,
             )
             B_vs = fbh.backwards_ls_hap(
                 n=n,
@@ -33,6 +35,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 e=e_vs,
                 c=c_vs,
                 r=r,
+                emission_func=emission_func,
             )
             F, c, ll = ls.forwards(
                 reference_panel=H_vs,
diff --git a/tests/test_api_fb_haploid_multi.py b/tests/test_api_fb_haploid_multi.py
index a90f57a..77b1b6d 100644
--- a/tests/test_api_fb_haploid_multi.py
+++ b/tests/test_api_fb_haploid_multi.py
@@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
+            emission_func = core.get_emission_probability_haploid
             F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
                 n=n,
                 m=m,
@@ -23,6 +24,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 s=s,
                 e=e_vs,
                 r=r,
+                emission_func=emission_func,
             )
             B_vs = fbh.backwards_ls_hap(
                 n=n,
@@ -32,6 +34,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 e=e_vs,
                 c=c_vs,
                 r=r,
+                emission_func=emission_func,
             )
             F, c, ll = ls.forwards(
                 reference_panel=H_vs,
diff --git a/tests/test_api_vit_haploid.py b/tests/test_api_vit_haploid.py
index c45e27d..0cf8590 100644
--- a/tests/test_api_vit_haploid.py
+++ b/tests/test_api_vit_haploid.py
@@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
+            emission_func = core.get_emission_probability_haploid
             V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
                 n=n,
                 m=m,
@@ -23,6 +24,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 s=s,
                 e=e_vs,
                 r=r,
+                emission_func=emission_func,
             )
             path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs)
             path, ll = ls.viterbi(
diff --git a/tests/test_api_vit_haploid_multi.py b/tests/test_api_vit_haploid_multi.py
index 5020171..5f5c6dd 100644
--- a/tests/test_api_vit_haploid_multi.py
+++ b/tests/test_api_vit_haploid_multi.py
@@ -16,6 +16,7 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
+            emission_func = core.get_emission_probability_haploid
             V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_lower_mem_rescaling(
                 n=n,
                 m=m,
@@ -23,9 +24,19 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 s=s,
                 e=e_vs,
                 r=r,
+                emission_func=emission_func,
             )
             path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs, P=P_vs)
-            path_ll_hap = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
+            path_ll_hap = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
             path, ll = ls.viterbi(
                 reference_panel=H_vs,
                 query=s,
@@ -44,17 +55,11 @@ def test_ts_multiallelic_n10_no_recomb(
         self, scale_mutation_rate, include_ancestors
     ):
         ts = self.get_ts_multiallelic_n10_no_recomb()
-        self.verify(
-            ts,
-            scale_mutation_rate=scale_mutation_rate,
-            include_ancestors=include_ancestors,
-        )
+        self.verify(ts, scale_mutation_rate, include_ancestors)
 
     @pytest.mark.parametrize("num_samples", [6, 8, 16])
     @pytest.mark.parametrize("scale_mutation_rate", [True, False])
     @pytest.mark.parametrize("include_ancestors", [True, False])
-    def test_ts_multiallelic_n16(
-        self, num_samples, scale_mutation_rate, include_ancestors
-    ):
+    def test_ts_multiallelic(self, num_samples, scale_mutation_rate, include_ancestors):
         ts = self.get_ts_multiallelic(num_samples)
         self.verify(ts, scale_mutation_rate, include_ancestors)
diff --git a/tests/test_nontree_fb_haploid.py b/tests/test_nontree_fb_haploid.py
index 7f76a0f..331aabe 100644
--- a/tests/test_nontree_fb_haploid.py
+++ b/tests/test_nontree_fb_haploid.py
@@ -10,20 +10,56 @@
 
 class TestNonTreeForwardBackwardHaploid(lsbase.ForwardBackwardAlgorithmBase):
     def verify(self, ts, scale_mutation_rate, include_ancestors):
+        ploidy = 1
         for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
             ts,
-            ploidy=1,
+            ploidy=ploidy,
             scale_mutation_rate=scale_mutation_rate,
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
-            F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(n, m, H_vs, s, e_vs, r, norm=False)
-            B_vs = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_vs, r)
+            emission_func = core.get_emission_probability_haploid
+            F_vs, c_vs, ll_vs = fbh.forwards_ls_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+                norm=False,
+            )
+            B_vs = fbh.backwards_ls_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                c=c_vs,
+                r=r,
+                emission_func=emission_func,
+            )
             self.assertAllClose(np.log10(np.sum(F_vs * B_vs, 1)), ll_vs * np.ones(m))
             F_tmp, c_tmp, ll_tmp = fbh.forwards_ls_hap(
-                n, m, H_vs, s, e_vs, r, norm=True
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+                norm=True,
+            )
+            B_tmp = fbh.backwards_ls_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                c=c_tmp,
+                r=r,
+                emission_func=emission_func,
             )
-            B_tmp = fbh.backwards_ls_hap(n, m, H_vs, s, e_vs, c_tmp, r)
             self.assertAllClose(np.sum(F_tmp * B_tmp, 1), np.ones(m))
             self.assertAllClose(ll_vs, ll_tmp)
 
diff --git a/tests/test_nontree_vit_haploid.py b/tests/test_nontree_vit_haploid.py
index 93bd7a5..3560c78 100644
--- a/tests/test_nontree_vit_haploid.py
+++ b/tests/test_nontree_vit_haploid.py
@@ -10,55 +10,158 @@
 
 class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase):
     def verify(self, ts, scale_mutation_rate, include_ancestors):
+        ploidy = 1
         for n, m, H_vs, s, e_vs, r, _ in self.get_examples_pars(
             ts,
-            ploidy=1,
+            ploidy=ploidy,
             scale_mutation_rate=scale_mutation_rate,
             include_ancestors=include_ancestors,
             include_extreme_rates=True,
         ):
-            V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(n, m, H_vs, s, e_vs, r)
-            path_vs = vh.backwards_viterbi_hap(m, V_vs[m - 1, :], P_vs)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_vs, s, e_vs, r)
+            emission_func = core.get_emission_probability_haploid
+
+            V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_vs = vh.backwards_viterbi_hap(
+                m=m,
+                V_last=V_vs[m - 1, :],
+                P=P_vs,
+            )
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
             self.assertAllClose(ll_vs, ll_check)
 
             V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_tmp = vh.backwards_viterbi_hap(
+                m=m,
+                V_last=V_tmp[m - 1, :],
+                P=P_tmp,
+            )
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            path_tmp = vh.backwards_viterbi_hap(m, V_tmp[m - 1, :], P_tmp)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)
 
             V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)
 
             V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)
 
             V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)
 
             V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
+            )
+            path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)
 
@@ -68,14 +171,29 @@ def verify(self, ts, scale_mutation_rate, include_ancestors):
                 recombs,
                 ll_tmp,
             ) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
-                n, m, H_vs, s, e_vs, r
+                n=n,
+                m=m,
+                H=H_vs,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
             path_tmp = vh.backwards_viterbi_hap_no_pointer(
-                m,
-                V_argmaxes_tmp,
-                nb.typed.List(recombs),
+                m=m,
+                V_argmaxes=V_argmaxes_tmp,
+                recombs=nb.typed.List(recombs),
+            )
+            ll_check = vh.path_ll_hap(
+                n=n,
+                m=m,
+                H=H_vs,
+                path=path_tmp,
+                s=s,
+                e=e_vs,
+                r=r,
+                emission_func=emission_func,
             )
-            ll_check = vh.path_ll_hap(n, m, H_vs, path_tmp, s, e_vs, r)
             self.assertAllClose(ll_tmp, ll_check)
             self.assertAllClose(ll_vs, ll_tmp)