Skip to content

Commit

Permalink
hmm: Resolve some, but not all, breaking API changes
Browse files Browse the repository at this point in the history
  • Loading branch information
etal committed Sep 4, 2024
1 parent ee6beab commit e759acf
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions cnvlib/segmentation/hmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,25 +95,25 @@ def hmm_get_model(cnarr, method, diploid_parx_genome, processes):
if method == "hmm-germline":
state_names = ["loss", "neutral", "gain"]
distributions = [
Normal(-1.0, stdev, frozen=True),
Normal(0.0, stdev, frozen=True),
Normal(0.585, stdev, frozen=True),
Normal([-1.0], [stdev], covariance_type="diag", frozen=True),
Normal([0.0], [stdev], covariance_type="diag", frozen=True),
Normal([0.585], [stdev], covariance_type="diag", frozen=True),
]
elif method == "hmm-tumor":
state_names = ["del", "loss", "neutral", "gain", "amp"]
distributions = [
Normal(-2.0, stdev, frozen=False),
Normal(-0.5, stdev, frozen=False),
Normal(0.0, stdev, frozen=True),
Normal(0.3, stdev, frozen=False),
Normal(1.0, stdev, frozen=False),
Normal([-2.0], [stdev], covariance_type="diag", frozen=False),
Normal([-0.5], [stdev], covariance_type="diag", frozen=False),
Normal([0.0], [stdev], covariance_type="diag", frozen=True),
Normal([0.3], [stdev], covariance_type="diag", frozen=False),
Normal([1.0], [stdev], covariance_type="diag", frozen=False),
]
else:
state_names = ["loss", "neutral", "gain"]
distributions = [
Normal(-1.0, stdev, frozen=False),
Normal(0.0, stdev, frozen=False),
Normal(0.585, stdev, frozen=False),
Normal([-1.0], [stdev], covariance_type="diag", frozen=False),
Normal([0.0], [stdev], covariance_type="diag", frozen=False),
Normal([0.585], [stdev], covariance_type="diag", frozen=False),
]

n_states = len(distributions)
Expand All @@ -126,13 +126,16 @@ def hmm_get_model(cnarr, method, diploid_parx_genome, processes):
transition_matrix = (
np.identity(n_states) * 100 + np.ones((n_states, n_states)) / n_states
)

model = DenseHMM.from_matrix(
transition_matrix,
distributions,
start_probabilities,
state_names=state_names,
name=method,
# Rescale so max is 1.0
transition_matrix /= transition_matrix.max()

model = DenseHMM(
distributions=distributions,
edges=transition_matrix,
starts=start_probabilities,
ends=start_probabilities,
#state_names=state_names,
#name=method,
)

model.fit(
Expand Down Expand Up @@ -172,8 +175,8 @@ def variants_in_segment(varr, segment, min_variants=50):
observations = varr.mirrored_baf(above_half=True)
state_names = ["neutral", "alt"]
distributions = [
Normal(0.5, 0.1, frozen=True),
Normal(0.67, 0.1, frozen=True),
Normal([0.5], [0.1], covariance_type="diag", frozen=True),
Normal([0.67], [0.1], covariance_type="diag", frozen=True),
]
n_states = len(distributions)
# Starts -- prefer neutral
Expand All @@ -183,12 +186,15 @@ def variants_in_segment(varr, segment, min_variants=50):
transition_matrix = (
np.identity(n_states) * 100 + np.ones((n_states, n_states)) / n_states
)
# Rescale so max is 1.0
transition_matrix /= transition_matrix.max()
model = DenseHMM.from_matrix(
transition_matrix,
distributions,
start_probabilities,
state_names=state_names,
name="loh",
distributions=distributions,
edges=transition_matrix,
starts=start_probabilities,
ends=start_probabilities,
#state_names=state_names,
#name="loh",
)

model.fit(
Expand Down

0 comments on commit e759acf

Please sign in to comment.