Skip to content

Commit

Permalink
robustify tree likelihoods for the possibility that data is not an Alignment #1174
Browse files Browse the repository at this point in the history
  • Loading branch information
rbouckaert committed Jan 15, 2025
1 parent 3d6f460 commit 10aaac6
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 60 deletions.
37 changes: 18 additions & 19 deletions src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class BeagleTreeLikelihood extends TreeLikelihood {

@Override
public void initAndValidate() {
alignment = (Alignment) dataInput.get();
boolean forceJava = Boolean.valueOf(System.getProperty("java.only"));
if (forceJava) {
return;
Expand All @@ -114,7 +115,7 @@ private boolean initialize() {
throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
}
m_siteModel = (SiteModel.Base) siteModelInput.get();
m_siteModel.setDataType(dataInput.get().getDataType());
m_siteModel.setDataType(alignment.getDataType());
substitutionModel = m_siteModel.substModelInput.get();
branchRateModel = branchRateModelInput.get();
if (branchRateModel == null) {
Expand All @@ -123,8 +124,8 @@ private boolean initialize() {
m_branchLengths = new double[m_nNodeCount];
storedBranchLengths = new double[m_nNodeCount];

m_nStateCount = dataInput.get().getMaxStateCount();
patternCount = dataInput.get().getPatternCount();
m_nStateCount = alignment.getMaxStateCount();
patternCount = alignment.getPatternCount();

//System.err.println("Attempt to load BEAGLE TreeLikelihood");

Expand All @@ -136,8 +137,8 @@ private boolean initialize() {
for (int i = 0; i < categoryRates.length; i++) {
if (categoryRates[i] == 0) {
proportionInvariant = m_siteModel.getRateForCategory(i, null);
int stateCount = dataInput.get().getMaxStateCount();
int patterns = dataInput.get().getPatternCount();
int stateCount = alignment.getMaxStateCount();
int patterns = alignment.getPatternCount();
calcConstantPatternIndices(patterns, stateCount);
invariantCategory = i;

Expand All @@ -152,7 +153,7 @@ private boolean initialize() {
break;
}
}
if (constantPattern != null && constantPattern.size() > dataInput.get().getPatternCount()) {
if (constantPattern != null && constantPattern.size() > alignment.getPatternCount()) {
// if there are many more constant patterns than patterns (each pattern can
// have a number of constant patterns, one for each state) it is less efficient
// to just calculate the TreeLikelihood for constant sites than optimising
Expand Down Expand Up @@ -385,21 +386,21 @@ private boolean initialize() {

Node [] nodes = treeInput.get().getNodesAsArray();
for (int i = 0; i < tipCount; i++) {
int taxon = getTaxonIndex(nodes[i].getID(), dataInput.get());
int taxon = getTaxonIndex(nodes[i].getID(), alignment);
if (m_bUseAmbiguities || m_bUseTipLikelihoods) {
setPartials(beagle, i, taxon);
} else {
setStates(beagle, i, taxon);
}
}

if (dataInput.get().isAscertained) {
if (alignment.isAscertained) {
ascertainedSitePatterns = true;
}

double[] patternWeights = new double[patternCount];
for (int i = 0; i < patternCount; i++) {
patternWeights[i] = dataInput.get().getPatternWeight(i);
patternWeights[i] = alignment.getPatternWeight(i);
}
beagle.setPatternWeights(patternWeights);

Expand Down Expand Up @@ -482,22 +483,21 @@ protected int getScaleBufferCount() {
*/
protected final void setPartials(Beagle beagle,
int nodeIndex, int taxon) {
Alignment data = dataInput.get();

double[] partials = new double[patternCount * m_nStateCount * categoryCount];

int v = 0;
for (int i = 0; i < patternCount; i++) {

double[] tipProbabilities = data.getTipLikelihoods(taxon,i);
double[] tipProbabilities = alignment.getTipLikelihoods(taxon,i);
if (tipProbabilities != null) {
for (int state = 0; state < m_nStateCount; state++) {
partials[v++] = tipProbabilities[state];
}
}
else {
int stateCount = data.getPattern(taxon, i);
boolean[] stateSet = data.getStateSet(stateCount);
int stateCount = alignment.getPattern(taxon, i);
boolean[] stateSet = alignment.getStateSet(stateCount);
for (int state = 0; state < m_nStateCount; state++) {
partials[v++] = (stateSet[state] ? 1.0 : 0.0);
}
Expand Down Expand Up @@ -547,14 +547,13 @@ void setUpSubstModel() {
*/
protected final void setStates(Beagle beagle,
int nodeIndex, int taxon) {
Alignment data = dataInput.get();
int i;

int[] states = new int[patternCount];

for (i = 0; i < patternCount; i++) {
int code = data.getPattern(taxon, i);
int[] statesForCode = data.getDataType().getStatesForCode(code);
int code = alignment.getPattern(taxon, i);
int[] statesForCode = alignment.getDataType().getStatesForCode(code);
if (statesForCode.length==1)
states[i] = statesForCode[0];
else
Expand Down Expand Up @@ -866,11 +865,11 @@ public double calculateLogP() {
if (ascertainedSitePatterns) {
// Need to correct for ascertainedSitePatterns
beagle.getSiteLogLikelihoods(patternLogLikelihoods);
logL = getAscertainmentCorrectedLogLikelihood(dataInput.get(),
patternLogLikelihoods, dataInput.get().getWeights(), frequencies);
logL = getAscertainmentCorrectedLogLikelihood(alignment,
patternLogLikelihoods, alignment.getWeights(), frequencies);
} else if (invariantCategory >= 0) {
beagle.getSiteLogLikelihoods(patternLogLikelihoods);
int [] patternWeights = dataInput.get().getWeights();
int [] patternWeights = alignment.getWeights();
proportionInvariant = m_siteModel.getProportionInvariant();


Expand Down
18 changes: 10 additions & 8 deletions src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,13 @@ public List<Input<?>> listInputs() {
// specifies the ranges of patterns assigned to each thread:
// patternPoints first contains 0, then one point for each thread
private int [] patternPoints;

private Alignment alignment;

@Override
public void initAndValidate() {
threadCount = ProgramStatus.m_nThreads;
alignment = (Alignment) dataInput.get();

if (maxNrOfThreadsInput.get() > 0) {
threadCount = Math.min(maxNrOfThreadsInput.get(), ProgramStatus.m_nThreads);
Expand All @@ -120,13 +123,13 @@ public void initAndValidate() {
logPByThread = new double[threadCount];

// sanity check: alignment should have same #taxa as tree
if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) {
if (alignment.getTaxonCount() != treeInput.get().getLeafNodeCount()) {
throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
}

treelikelihood = new TreeLikelihood[threadCount];

if (dataInput.get().isAscertained) {
if (alignment.isAscertained) {
Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained");
threadCount = 1;
}
Expand All @@ -147,12 +150,11 @@ public void initAndValidate() {
} else {
pool = Executors.newFixedThreadPool(threadCount);

calcPatternPoints(dataInput.get().getSiteCount());
calcPatternPoints(alignment.getSiteCount());
for (int i = 0; i < threadCount; i++) {
Alignment data = dataInput.get();
String filterSpec = (patternPoints[i] +1) + "-" + (patternPoints[i + 1]);
if (data.isAscertained) {
filterSpec += data.excludefromInput.get() + "-" + data.excludetoInput.get() + "," + filterSpec;
if (alignment.isAscertained) {
filterSpec += alignment.excludefromInput.get() + "-" + alignment.excludetoInput.get() + "," + filterSpec;
}
treelikelihood[i] = new TreeLikelihood();
treelikelihood[i].setID(getID() + i);
Expand All @@ -176,7 +178,7 @@ public void initAndValidate() {
"branchRateModel", duplicate(branchRateModelInput.get(), i),
"rootFrequencies", rootFrequenciesInput.get(),
"useAmbiguities", useAmbiguitiesInput.get(),
"scaling" , scalingInput.get() + ""
"scaling", scalingInput.get() + ""
);

likelihoodCallers.add(new TreeLikelihoodCaller(treelikelihood[i], i));
Expand Down Expand Up @@ -321,7 +323,7 @@ private double calculateLogPByBeagle() {

/* return copy of pattern log likelihoods for each of the patterns in the alignment */
public double [] getPatternLogLikelihoods() {
double [] patternLogLikelihoods = new double[dataInput.get().getPatternCount()];
double [] patternLogLikelihoods = new double[alignment.getPatternCount()];
int i = 0;
for (TreeLikelihood b : treelikelihood) {
double [] d = b.getPatternLogLikelihoods();
Expand Down
Loading

0 comments on commit 10aaac6

Please sign in to comment.