Skip to content

Commit

Permalink
hmm: update pomegranate API usage, per @rollf in #789
Browse files Browse the repository at this point in the history
  • Loading branch information
etal committed Sep 4, 2024
1 parent 8ba7cf3 commit ee6beab
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions cnvlib/segmentation/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np
import pandas as pd
import pomegranate as pom
import pomegranate.distributions
from pomegranate.distributions import Normal
from pomegranate.hmm import DenseHMM
import scipy.special

from ..cnary import CopyNumArray as CNA
Expand Down Expand Up @@ -93,25 +95,25 @@ def hmm_get_model(cnarr, method, diploid_parx_genome, processes):
if method == "hmm-germline":
state_names = ["loss", "neutral", "gain"]
distributions = [
pom.NormalDistribution(-1.0, stdev, frozen=True),
pom.NormalDistribution(0.0, stdev, frozen=True),
pom.NormalDistribution(0.585, stdev, frozen=True),
Normal(-1.0, stdev, frozen=True),
Normal(0.0, stdev, frozen=True),
Normal(0.585, stdev, frozen=True),
]
elif method == "hmm-tumor":
state_names = ["del", "loss", "neutral", "gain", "amp"]
distributions = [
pom.NormalDistribution(-2.0, stdev, frozen=False),
pom.NormalDistribution(-0.5, stdev, frozen=False),
pom.NormalDistribution(0.0, stdev, frozen=True),
pom.NormalDistribution(0.3, stdev, frozen=False),
pom.NormalDistribution(1.0, stdev, frozen=False),
Normal(-2.0, stdev, frozen=False),
Normal(-0.5, stdev, frozen=False),
Normal(0.0, stdev, frozen=True),
Normal(0.3, stdev, frozen=False),
Normal(1.0, stdev, frozen=False),
]
else:
state_names = ["loss", "neutral", "gain"]
distributions = [
pom.NormalDistribution(-1.0, stdev, frozen=False),
pom.NormalDistribution(0.0, stdev, frozen=False),
pom.NormalDistribution(0.585, stdev, frozen=False),
Normal(-1.0, stdev, frozen=False),
Normal(0.0, stdev, frozen=False),
Normal(0.585, stdev, frozen=False),
]

n_states = len(distributions)
Expand All @@ -125,7 +127,7 @@ def hmm_get_model(cnarr, method, diploid_parx_genome, processes):
np.identity(n_states) * 100 + np.ones((n_states, n_states)) / n_states
)

model = pom.HiddenMarkovModel.from_matrix(
model = DenseHMM.from_matrix(
transition_matrix,
distributions,
start_probabilities,
Expand Down Expand Up @@ -170,8 +172,8 @@ def variants_in_segment(varr, segment, min_variants=50):
observations = varr.mirrored_baf(above_half=True)
state_names = ["neutral", "alt"]
distributions = [
pom.NormalDistribution(0.5, 0.1, frozen=True),
pom.NormalDistribution(0.67, 0.1, frozen=True),
Normal(0.5, 0.1, frozen=True),
Normal(0.67, 0.1, frozen=True),
]
n_states = len(distributions)
# Starts -- prefer neutral
Expand All @@ -181,7 +183,7 @@ def variants_in_segment(varr, segment, min_variants=50):
transition_matrix = (
np.identity(n_states) * 100 + np.ones((n_states, n_states)) / n_states
)
model = pom.HiddenMarkovModel.from_matrix(
model = DenseHMM.from_matrix(
transition_matrix,
distributions,
start_probabilities,
Expand Down

0 comments on commit ee6beab

Please sign in to comment.