diff --git a/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java b/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java
index f07783ce..a6b5821e 100644
--- a/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java
+++ b/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java
@@ -99,6 +99,7 @@ public class BeagleTreeLikelihood extends TreeLikelihood {
@Override
public void initAndValidate() {
+ alignment = (Alignment) dataInput.get();
boolean forceJava = Boolean.valueOf(System.getProperty("java.only"));
if (forceJava) {
return;
@@ -114,7 +115,7 @@ private boolean initialize() {
throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
}
m_siteModel = (SiteModel.Base) siteModelInput.get();
- m_siteModel.setDataType(dataInput.get().getDataType());
+ m_siteModel.setDataType(alignment.getDataType());
substitutionModel = m_siteModel.substModelInput.get();
branchRateModel = branchRateModelInput.get();
if (branchRateModel == null) {
@@ -123,8 +124,8 @@ private boolean initialize() {
m_branchLengths = new double[m_nNodeCount];
storedBranchLengths = new double[m_nNodeCount];
- m_nStateCount = dataInput.get().getMaxStateCount();
- patternCount = dataInput.get().getPatternCount();
+ m_nStateCount = alignment.getMaxStateCount();
+ patternCount = alignment.getPatternCount();
//System.err.println("Attempt to load BEAGLE TreeLikelihood");
@@ -136,8 +137,8 @@ private boolean initialize() {
for (int i = 0; i < categoryRates.length; i++) {
if (categoryRates[i] == 0) {
proportionInvariant = m_siteModel.getRateForCategory(i, null);
- int stateCount = dataInput.get().getMaxStateCount();
- int patterns = dataInput.get().getPatternCount();
+ int stateCount = alignment.getMaxStateCount();
+ int patterns = alignment.getPatternCount();
calcConstantPatternIndices(patterns, stateCount);
invariantCategory = i;
@@ -152,7 +153,7 @@ private boolean initialize() {
break;
}
}
- if (constantPattern != null && constantPattern.size() > dataInput.get().getPatternCount()) {
+ if (constantPattern != null && constantPattern.size() > alignment.getPatternCount()) {
// if there are many more constant patterns than patterns (each pattern can
// have a number of constant patters, one for each state) it is less efficient
// to just calculate the TreeLikelihood for constant sites than optimising
@@ -385,7 +386,7 @@ private boolean initialize() {
Node [] nodes = treeInput.get().getNodesAsArray();
for (int i = 0; i < tipCount; i++) {
- int taxon = getTaxonIndex(nodes[i].getID(), dataInput.get());
+ int taxon = getTaxonIndex(nodes[i].getID(), alignment);
if (m_bUseAmbiguities || m_bUseTipLikelihoods) {
setPartials(beagle, i, taxon);
} else {
@@ -393,13 +394,13 @@ private boolean initialize() {
}
}
- if (dataInput.get().isAscertained) {
+ if (alignment.isAscertained) {
ascertainedSitePatterns = true;
}
double[] patternWeights = new double[patternCount];
for (int i = 0; i < patternCount; i++) {
- patternWeights[i] = dataInput.get().getPatternWeight(i);
+ patternWeights[i] = alignment.getPatternWeight(i);
}
beagle.setPatternWeights(patternWeights);
@@ -482,22 +483,21 @@ protected int getScaleBufferCount() {
*/
protected final void setPartials(Beagle beagle,
int nodeIndex, int taxon) {
- Alignment data = dataInput.get();
double[] partials = new double[patternCount * m_nStateCount * categoryCount];
int v = 0;
for (int i = 0; i < patternCount; i++) {
- double[] tipProbabilities = data.getTipLikelihoods(taxon,i);
+ double[] tipProbabilities = alignment.getTipLikelihoods(taxon,i);
if (tipProbabilities != null) {
for (int state = 0; state < m_nStateCount; state++) {
partials[v++] = tipProbabilities[state];
}
}
else {
- int stateCount = data.getPattern(taxon, i);
- boolean[] stateSet = data.getStateSet(stateCount);
+ int stateCount = alignment.getPattern(taxon, i);
+ boolean[] stateSet = alignment.getStateSet(stateCount);
for (int state = 0; state < m_nStateCount; state++) {
partials[v++] = (stateSet[state] ? 1.0 : 0.0);
}
@@ -547,14 +547,13 @@ void setUpSubstModel() {
*/
protected final void setStates(Beagle beagle,
int nodeIndex, int taxon) {
- Alignment data = dataInput.get();
int i;
int[] states = new int[patternCount];
for (i = 0; i < patternCount; i++) {
- int code = data.getPattern(taxon, i);
- int[] statesForCode = data.getDataType().getStatesForCode(code);
+ int code = alignment.getPattern(taxon, i);
+ int[] statesForCode = alignment.getDataType().getStatesForCode(code);
if (statesForCode.length==1)
states[i] = statesForCode[0];
else
@@ -866,11 +865,11 @@ public double calculateLogP() {
if (ascertainedSitePatterns) {
// Need to correct for ascertainedSitePatterns
beagle.getSiteLogLikelihoods(patternLogLikelihoods);
- logL = getAscertainmentCorrectedLogLikelihood(dataInput.get(),
- patternLogLikelihoods, dataInput.get().getWeights(), frequencies);
+ logL = getAscertainmentCorrectedLogLikelihood(alignment,
+ patternLogLikelihoods, alignment.getWeights(), frequencies);
} else if (invariantCategory >= 0) {
beagle.getSiteLogLikelihoods(patternLogLikelihoods);
- int [] patternWeights = dataInput.get().getWeights();
+ int [] patternWeights = alignment.getWeights();
proportionInvariant = m_siteModel.getProportionInvariant();
diff --git a/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java b/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java
index c2287d49..cd2125af 100644
--- a/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java
+++ b/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java
@@ -104,10 +104,13 @@ public List> listInputs() {
// specified a set ranges of patterns assigned to each thread
// first patternPoints contains 0, then one point for each thread
private int [] patternPoints;
+
+ private Alignment alignment;
@Override
public void initAndValidate() {
threadCount = ProgramStatus.m_nThreads;
+ alignment = (Alignment) dataInput.get();
if (maxNrOfThreadsInput.get() > 0) {
threadCount = Math.min(maxNrOfThreadsInput.get(), ProgramStatus.m_nThreads);
@@ -120,13 +123,13 @@ public void initAndValidate() {
logPByThread = new double[threadCount];
// sanity check: alignment should have same #taxa as tree
- if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) {
+ if (alignment.getTaxonCount() != treeInput.get().getLeafNodeCount()) {
throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences");
}
treelikelihood = new TreeLikelihood[threadCount];
- if (dataInput.get().isAscertained) {
+ if (alignment.isAscertained) {
Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained");
threadCount = 1;
}
@@ -147,12 +150,11 @@ public void initAndValidate() {
} else {
pool = Executors.newFixedThreadPool(threadCount);
- calcPatternPoints(dataInput.get().getSiteCount());
+ calcPatternPoints(alignment.getSiteCount());
for (int i = 0; i < threadCount; i++) {
- Alignment data = dataInput.get();
String filterSpec = (patternPoints[i] +1) + "-" + (patternPoints[i + 1]);
- if (data.isAscertained) {
- filterSpec += data.excludefromInput.get() + "-" + data.excludetoInput.get() + "," + filterSpec;
+ if (alignment.isAscertained) {
+ filterSpec += alignment.excludefromInput.get() + "-" + alignment.excludetoInput.get() + "," + filterSpec;
}
treelikelihood[i] = new TreeLikelihood();
treelikelihood[i].setID(getID() + i);
@@ -176,7 +178,7 @@ public void initAndValidate() {
"branchRateModel", duplicate(branchRateModelInput.get(), i),
"rootFrequencies", rootFrequenciesInput.get(),
"useAmbiguities", useAmbiguitiesInput.get(),
- "scaling" , scalingInput.get() + ""
+ "scaling", scalingInput.get() + ""
);
likelihoodCallers.add(new TreeLikelihoodCaller(treelikelihood[i], i));
@@ -321,7 +323,7 @@ private double calculateLogPByBeagle() {
/* return copy of pattern log likelihoods for each of the patterns in the alignment */
public double [] getPatternLogLikelihoods() {
- double [] patternLogLikelihoods = new double[dataInput.get().getPatternCount()];
+ double [] patternLogLikelihoods = new double[alignment.getPatternCount()];
int i = 0;
for (TreeLikelihood b : treelikelihood) {
double [] d = b.getPatternLogLikelihoods();
diff --git a/src/beast/base/evolution/likelihood/TreeLikelihood.java b/src/beast/base/evolution/likelihood/TreeLikelihood.java
index c32b3271..853ca7e2 100644
--- a/src/beast/base/evolution/likelihood/TreeLikelihood.java
+++ b/src/beast/base/evolution/likelihood/TreeLikelihood.java
@@ -26,6 +26,7 @@
package beast.base.evolution.likelihood;
+import java.lang.reflect.InvocationTargetException;
import java.util.*;
import beast.base.core.Description;
@@ -41,6 +42,7 @@
import beast.base.evolution.tree.Tree;
import beast.base.evolution.tree.TreeInterface;
import beast.base.inference.State;
+import beast.pkgmgmt.BEASTClassLoader;
@Description("Calculates the probability of sequence data on a beast.tree given a site and substitution model using " +
"a variant of the 'peeling algorithm'. For details, see" +
@@ -51,7 +53,7 @@ public class TreeLikelihood extends GenericTreeLikelihood {
final public Input m_useTipLikelihoods = new Input<>("useTipLikelihoods", "flag to indicate that partial likelihoods are provided at the tips", false);
final public Input implementationInput = new Input<>("implementation", "name of class that implements this treelikelihood potentially more efficiently. "
+ "This class will be tried first, with the TreeLikelihood as fallback implementation. "
- + "When multi-threading, multiple objects can be created.", "beast.evolution.likelihood.BeagleTreeLikelihood");
+ + "When multi-threading, multiple objects can be created.", BeagleTreeLikelihood.class.getName());
public static enum Scaling {none, always, _default};
final public Input scaling = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values());
@@ -115,6 +117,11 @@ public static enum Scaling {none, always, _default};
*/
protected boolean useAscertainedSitePatterns = false;
+ /**
+ * alias for the data
+ */
+ protected Alignment alignment;
+
/**
* dealing with proportion of site being invariant *
*/
@@ -128,9 +135,14 @@ public static enum Scaling {none, always, _default};
@Override
public void initAndValidate() {
+ // sanity check: make sure data is an Alignment
+ if (!(dataInput.get() instanceof Alignment)) {
+ throw new RuntimeException("Expected Alignment as data, not " + dataInput.get().getClass().getName());
+ }
+ alignment = (Alignment) dataInput.get();
+
// sanity check: alignment should have same #taxa as tree
-
- if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) {
+ if (alignment.getTaxonCount() != treeInput.get().getLeafNodeCount()) {
String leaves = "?";
if (treeInput.get() instanceof Tree) {
leaves = String.join(", ", ((Tree) treeInput.get()).getTaxaNames());
@@ -138,8 +150,8 @@ public void initAndValidate() {
throw new IllegalArgumentException(String.format(
"The number of leaves in the tree (%d) does not match the number of sequences (%d). "
+ "The tree has leaves [%s], while the data refers to taxa [%s].",
- treeInput.get().getLeafNodeCount(), dataInput.get().getTaxonCount(),
- leaves, String.join(", ", dataInput.get().getTaxaNames())));
+ treeInput.get().getLeafNodeCount(), alignment.getTaxonCount(),
+ leaves, String.join(", ", alignment.getTaxaNames())));
}
beagle = null;
@@ -152,8 +164,11 @@ public void initAndValidate() {
// }
//
// if (!hasImaginaryEigenvectors) {
- beagle = new BeagleTreeLikelihood();
try {
+ Object o = newTreeLikelihood();
+ if (o instanceof BeagleTreeLikelihood) {
+ beagle = (BeagleTreeLikelihood) o;
+ }
beagle.initByName(
"data", dataInput.get(), "tree", treeInput.get(), "siteModel", siteModelInput.get(),
"branchRateModel", branchRateModelInput.get(), "useAmbiguities", m_useAmbiguities.get(),
@@ -177,7 +192,7 @@ public void initAndValidate() {
throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base");
}
m_siteModel = (SiteModel.Base) siteModelInput.get();
- m_siteModel.setDataType(dataInput.get().getDataType());
+ m_siteModel.setDataType(alignment.getDataType());
substitutionModel = m_siteModel.substModelInput.get();
if (branchRateModelInput.get() != null) {
@@ -188,13 +203,12 @@ public void initAndValidate() {
m_branchLengths = new double[nodeCount];
storedBranchLengths = new double[nodeCount];
- int stateCount = dataInput.get().getMaxStateCount();
- int patterns = dataInput.get().getPatternCount();
+ int stateCount = alignment.getMaxStateCount();
+ int patterns = alignment.getPatternCount();
likelihoodCore = createLikelihoodCore(stateCount);
String className = getClass().getSimpleName();
- Alignment alignment = dataInput.get();
Log.info.println(className + "(" + getID() + ") uses " + likelihoodCore.getClass().getSimpleName());
Log.info.println(" " + alignment.toString(true));
@@ -214,12 +228,18 @@ public void initAndValidate() {
probabilities = new double[(stateCount + 1) * (stateCount + 1)];
Arrays.fill(probabilities, 1.0);
- if (dataInput.get().isAscertained) {
+ if (alignment.isAscertained) {
useAscertainedSitePatterns = true;
}
}
- protected LikelihoodCore createLikelihoodCore(int stateCount) {
+ private TreeLikelihood newTreeLikelihood() throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, ClassNotFoundException {
+ String class_ = implementationInput.get();
+ TreeLikelihood tl = (TreeLikelihood) BEASTClassLoader.forName(class_).getConstructor().newInstance();
+ return tl;
+ }
+
+ protected LikelihoodCore createLikelihoodCore(int stateCount) {
if (stateCount == 4) {
return new BeerLikelihoodCore4();
} else {
@@ -239,12 +259,12 @@ protected LikelihoodCore createLikelihoodCore(int stateCount) {
protected void calcConstantPatternIndices(final int patterns, final int stateCount) {
constantPattern = new ArrayList<>();
for (int i = 0; i < patterns; i++) {
- final int[] pattern = dataInput.get().getPattern(i);
+ final int[] pattern = alignment.getPattern(i);
final boolean[] isInvariant = new boolean[stateCount];
Arrays.fill(isInvariant, true);
for (final int state : pattern) {
- final boolean[] isStateSet = dataInput.get().getStateSet(state);
- if (m_useAmbiguities.get() || !dataInput.get().getDataType().isAmbiguousCode(state)) {
+ final boolean[] isStateSet = alignment.getStateSet(state);
+ if (m_useAmbiguities.get() || !alignment.getDataType().isAmbiguousCode(state)) {
for (int k = 0; k < stateCount; k++) {
isInvariant[k] &= isStateSet[k];
}
@@ -262,7 +282,7 @@ protected void initCore() {
final int nodeCount = treeInput.get().getNodeCount();
likelihoodCore.initialize(
nodeCount,
- dataInput.get().getPatternCount(),
+ alignment.getPatternCount(),
m_siteModel.getCategoryCount(),
true, m_useAmbiguities.get()
);
@@ -271,9 +291,9 @@ protected void initCore() {
final int intNodeCount = nodeCount / 2;
if (m_useAmbiguities.get() || m_useTipLikelihoods.get()) {
- setPartials(treeInput.get().getRoot(), dataInput.get().getPatternCount());
+ setPartials(treeInput.get().getRoot(), alignment.getPatternCount());
} else {
- setStates(treeInput.get().getRoot(), dataInput.get().getPatternCount());
+ setStates(treeInput.get().getRoot(), alignment.getPatternCount());
}
hasDirt = Tree.IS_FILTHY;
for (int i = 0; i < intNodeCount; i++) {
@@ -294,13 +314,12 @@ public void sample(State state, Random random) {
*/
protected void setStates(Node node, int patternCount) {
if (node.isLeaf()) {
- Alignment data = dataInput.get();
int i;
int[] states = new int[patternCount];
- int taxonIndex = getTaxonIndex(node.getID(), data);
+ int taxonIndex = getTaxonIndex(node.getID(), alignment);
for (i = 0; i < patternCount; i++) {
- int code = data.getPattern(taxonIndex, i);
- int[] statesForCode = data.getDataType().getStatesForCode(code);
+ int code = alignment.getPattern(taxonIndex, i);
+ int[] statesForCode = alignment.getDataType().getStatesForCode(code);
if (statesForCode.length==1)
states[i] = statesForCode[0];
else
@@ -339,21 +358,20 @@ private int getTaxonIndex(String taxon, Alignment data) {
*/
protected void setPartials(Node node, int patternCount) {
if (node.isLeaf()) {
- Alignment data = dataInput.get();
- int states = data.getDataType().getStateCount();
+ int states = alignment.getDataType().getStateCount();
double[] partials = new double[patternCount * states];
int k = 0;
- int taxonIndex = getTaxonIndex(node.getID(), data);
+ int taxonIndex = getTaxonIndex(node.getID(), alignment);
for (int patternIndex_ = 0; patternIndex_ < patternCount; patternIndex_++) {
- double[] tipLikelihoods = data.getTipLikelihoods(taxonIndex,patternIndex_);
+ double[] tipLikelihoods = alignment.getTipLikelihoods(taxonIndex,patternIndex_);
if (tipLikelihoods != null) {
for (int state = 0; state < states; state++) {
partials[k++] = tipLikelihoods[state];
}
}
else {
- int stateCount = data.getPattern(taxonIndex, patternIndex_);
- boolean[] stateSet = data.getStateSet(stateCount);
+ int stateCount = alignment.getPattern(taxonIndex, patternIndex_);
+ boolean[] stateSet = alignment.getStateSet(stateCount);
for (int state = 0; state < states; state++) {
partials[k++] = (stateSet[state] ? 1.0 : 0.0);
}
@@ -369,6 +387,9 @@ protected void setPartials(Node node, int patternCount) {
// for testing
public double[] getRootPartials() {
+ if (beagle != null) {
+ return beagle.getRootPartials();
+ }
return m_fRootPartials.clone();
}
@@ -424,13 +445,13 @@ public double calculateLogP() {
protected void calcLogP() {
logP = 0.0;
if (useAscertainedSitePatterns) {
- final double ascertainmentCorrection = dataInput.get().getAscertainmentCorrection(patternLogLikelihoods);
- for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
- logP += (patternLogLikelihoods[i] - ascertainmentCorrection) * dataInput.get().getPatternWeight(i);
+ final double ascertainmentCorrection = alignment.getAscertainmentCorrection(patternLogLikelihoods);
+ for (int i = 0; i < alignment.getPatternCount(); i++) {
+ logP += (patternLogLikelihoods[i] - ascertainmentCorrection) * alignment.getPatternWeight(i);
}
} else {
- for (int i = 0; i < dataInput.get().getPatternCount(); i++) {
- logP += patternLogLikelihoods[i] * dataInput.get().getPatternWeight(i);
+ for (int i = 0; i < alignment.getPatternCount(); i++) {
+ logP += patternLogLikelihoods[i] * alignment.getPatternWeight(i);
}
}
}