diff --git a/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java b/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java index f07783ce..a6b5821e 100644 --- a/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java +++ b/src/beast/base/evolution/likelihood/BeagleTreeLikelihood.java @@ -99,6 +99,7 @@ public class BeagleTreeLikelihood extends TreeLikelihood { @Override public void initAndValidate() { + alignment = (Alignment) dataInput.get(); boolean forceJava = Boolean.valueOf(System.getProperty("java.only")); if (forceJava) { return; @@ -114,7 +115,7 @@ private boolean initialize() { throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base"); } m_siteModel = (SiteModel.Base) siteModelInput.get(); - m_siteModel.setDataType(dataInput.get().getDataType()); + m_siteModel.setDataType(alignment.getDataType()); substitutionModel = m_siteModel.substModelInput.get(); branchRateModel = branchRateModelInput.get(); if (branchRateModel == null) { @@ -123,8 +124,8 @@ private boolean initialize() { m_branchLengths = new double[m_nNodeCount]; storedBranchLengths = new double[m_nNodeCount]; - m_nStateCount = dataInput.get().getMaxStateCount(); - patternCount = dataInput.get().getPatternCount(); + m_nStateCount = alignment.getMaxStateCount(); + patternCount = alignment.getPatternCount(); //System.err.println("Attempt to load BEAGLE TreeLikelihood"); @@ -136,8 +137,8 @@ private boolean initialize() { for (int i = 0; i < categoryRates.length; i++) { if (categoryRates[i] == 0) { proportionInvariant = m_siteModel.getRateForCategory(i, null); - int stateCount = dataInput.get().getMaxStateCount(); - int patterns = dataInput.get().getPatternCount(); + int stateCount = alignment.getMaxStateCount(); + int patterns = alignment.getPatternCount(); calcConstantPatternIndices(patterns, stateCount); invariantCategory = i; @@ -152,7 +153,7 @@ private boolean initialize() { break; } } - if (constantPattern != null && constantPattern.size() > dataInput.get().getPatternCount()) { + if (constantPattern != null && constantPattern.size() > alignment.getPatternCount()) { // if there are many more constant patterns than patterns (each pattern can // have a number of constant patters, one for each state) it is less efficient // to just calculate the TreeLikelihood for constant sites than optimising @@ -385,7 +386,7 @@ private boolean initialize() { Node [] nodes = treeInput.get().getNodesAsArray(); for (int i = 0; i < tipCount; i++) { - int taxon = getTaxonIndex(nodes[i].getID(), dataInput.get()); + int taxon = getTaxonIndex(nodes[i].getID(), alignment); if (m_bUseAmbiguities || m_bUseTipLikelihoods) { setPartials(beagle, i, taxon); } else { @@ -393,13 +394,13 @@ private boolean initialize() { } } - if (dataInput.get().isAscertained) { + if (alignment.isAscertained) { ascertainedSitePatterns = true; } double[] patternWeights = new double[patternCount]; for (int i = 0; i < patternCount; i++) { - patternWeights[i] = dataInput.get().getPatternWeight(i); + patternWeights[i] = alignment.getPatternWeight(i); } beagle.setPatternWeights(patternWeights); @@ -482,22 +483,21 @@ protected int getScaleBufferCount() { */ protected final void setPartials(Beagle beagle, int nodeIndex, int taxon) { - Alignment data = dataInput.get(); double[] partials = new double[patternCount * m_nStateCount * categoryCount]; int v = 0; for (int i = 0; i < patternCount; i++) { - double[] tipProbabilities = data.getTipLikelihoods(taxon,i); + double[] tipProbabilities = alignment.getTipLikelihoods(taxon,i); if (tipProbabilities != null) { for (int state = 0; state < m_nStateCount; state++) { partials[v++] = tipProbabilities[state]; } } else { - int stateCount = data.getPattern(taxon, i); - boolean[] stateSet = data.getStateSet(stateCount); + int stateCount = alignment.getPattern(taxon, i); + boolean[] stateSet = alignment.getStateSet(stateCount); for (int state = 0; state < m_nStateCount; state++) { partials[v++] = (stateSet[state] ? 1.0 : 0.0); } @@ -547,14 +547,13 @@ void setUpSubstModel() { */ protected final void setStates(Beagle beagle, int nodeIndex, int taxon) { - Alignment data = dataInput.get(); int i; int[] states = new int[patternCount]; for (i = 0; i < patternCount; i++) { - int code = data.getPattern(taxon, i); - int[] statesForCode = data.getDataType().getStatesForCode(code); + int code = alignment.getPattern(taxon, i); + int[] statesForCode = alignment.getDataType().getStatesForCode(code); if (statesForCode.length==1) states[i] = statesForCode[0]; else @@ -866,11 +865,11 @@ public double calculateLogP() { if (ascertainedSitePatterns) { // Need to correct for ascertainedSitePatterns beagle.getSiteLogLikelihoods(patternLogLikelihoods); - logL = getAscertainmentCorrectedLogLikelihood(dataInput.get(), - patternLogLikelihoods, dataInput.get().getWeights(), frequencies); + logL = getAscertainmentCorrectedLogLikelihood(alignment, + patternLogLikelihoods, alignment.getWeights(), frequencies); } else if (invariantCategory >= 0) { beagle.getSiteLogLikelihoods(patternLogLikelihoods); - int [] patternWeights = dataInput.get().getWeights(); + int [] patternWeights = alignment.getWeights(); proportionInvariant = m_siteModel.getProportionInvariant(); diff --git a/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java b/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java index c2287d49..cd2125af 100644 --- a/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java +++ b/src/beast/base/evolution/likelihood/ThreadedTreeLikelihood.java @@ -104,10 +104,13 @@ public List> listInputs() { // specified a set ranges of patterns assigned to each thread // first patternPoints contains 0, then one point for each thread private int [] patternPoints; + + private Alignment alignment; @Override public void initAndValidate() { threadCount = ProgramStatus.m_nThreads; + alignment = (Alignment) dataInput.get(); if (maxNrOfThreadsInput.get() > 0) { threadCount = Math.min(maxNrOfThreadsInput.get(), ProgramStatus.m_nThreads); @@ -120,13 +123,13 @@ public void initAndValidate() { logPByThread = new double[threadCount]; // sanity check: alignment should have same #taxa as tree - if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) { + if (alignment.getTaxonCount() != treeInput.get().getLeafNodeCount()) { throw new IllegalArgumentException("The number of nodes in the tree does not match the number of sequences"); } treelikelihood = new TreeLikelihood[threadCount]; - if (dataInput.get().isAscertained) { + if (alignment.isAscertained) { Log.warning.println("Note, can only use single thread per alignment because the alignment is ascertained"); threadCount = 1; } @@ -147,12 +150,11 @@ public void initAndValidate() { } else { pool = Executors.newFixedThreadPool(threadCount); - calcPatternPoints(dataInput.get().getSiteCount()); + calcPatternPoints(alignment.getSiteCount()); for (int i = 0; i < threadCount; i++) { - Alignment data = dataInput.get(); String filterSpec = (patternPoints[i] +1) + "-" + (patternPoints[i + 1]); - if (data.isAscertained) { - filterSpec += data.excludefromInput.get() + "-" + data.excludetoInput.get() + "," + filterSpec; + if (alignment.isAscertained) { + filterSpec += alignment.excludefromInput.get() + "-" + alignment.excludetoInput.get() + "," + filterSpec; } treelikelihood[i] = new TreeLikelihood(); treelikelihood[i].setID(getID() + i); @@ -176,7 +178,7 @@ public void initAndValidate() { "branchRateModel", duplicate(branchRateModelInput.get(), i), "rootFrequencies", rootFrequenciesInput.get(), "useAmbiguities", useAmbiguitiesInput.get(), - "scaling" , scalingInput.get() + "" + "scaling", scalingInput.get() + "" ); likelihoodCallers.add(new TreeLikelihoodCaller(treelikelihood[i], i)); @@ -321,7 +323,7 @@ private double calculateLogPByBeagle() { /* return copy of pattern log likelihoods for each of the patterns in the alignment */ public double [] getPatternLogLikelihoods() { - double [] patternLogLikelihoods = new double[dataInput.get().getPatternCount()]; + double [] patternLogLikelihoods = new double[alignment.getPatternCount()]; int i = 0; for (TreeLikelihood b : treelikelihood) { double [] d = b.getPatternLogLikelihoods(); diff --git a/src/beast/base/evolution/likelihood/TreeLikelihood.java b/src/beast/base/evolution/likelihood/TreeLikelihood.java index c32b3271..853ca7e2 100644 --- a/src/beast/base/evolution/likelihood/TreeLikelihood.java +++ b/src/beast/base/evolution/likelihood/TreeLikelihood.java @@ -26,6 +26,7 @@ package beast.base.evolution.likelihood; +import java.lang.reflect.InvocationTargetException; import java.util.*; import beast.base.core.Description; @@ -41,6 +42,7 @@ import beast.base.evolution.tree.Tree; import beast.base.evolution.tree.TreeInterface; import beast.base.inference.State; +import beast.pkgmgmt.BEASTClassLoader; @Description("Calculates the probability of sequence data on a beast.tree given a site and substitution model using " + "a variant of the 'peeling algorithm'. For details, see" + @@ -51,7 +53,7 @@ public class TreeLikelihood extends GenericTreeLikelihood { final public Input m_useTipLikelihoods = new Input<>("useTipLikelihoods", "flag to indicate that partial likelihoods are provided at the tips", false); final public Input implementationInput = new Input<>("implementation", "name of class that implements this treelikelihood potentially more efficiently. " + "This class will be tried first, with the TreeLikelihood as fallback implementation. " - + "When multi-threading, multiple objects can be created.", "beast.evolution.likelihood.BeagleTreeLikelihood"); + + "When multi-threading, multiple objects can be created.", BeagleTreeLikelihood.class.getName()); public static enum Scaling {none, always, _default}; final public Input scaling = new Input<>("scaling", "type of scaling to use, one of " + Arrays.toString(Scaling.values()) + ". If not specified, the -beagle_scaling flag is used.", Scaling._default, Scaling.values()); @@ -115,6 +117,11 @@ public static enum Scaling {none, always, _default}; */ protected boolean useAscertainedSitePatterns = false; + /** + * alias for the data + */ + protected Alignment alignment; + /** * dealing with proportion of site being invariant * */ @@ -128,9 +135,14 @@ public static enum Scaling {none, always, _default}; @Override public void initAndValidate() { + // sanity check: make sure data is an Alignment + if (!(dataInput.get() instanceof Alignment)) { + throw new RuntimeException("Expected Alignment as data, not " + dataInput.get().getClass().getName()); + } + alignment = (Alignment) dataInput.get(); + // sanity check: alignment should have same #taxa as tree - - if (dataInput.get().getTaxonCount() != treeInput.get().getLeafNodeCount()) { + if (alignment.getTaxonCount() != treeInput.get().getLeafNodeCount()) { String leaves = "?"; if (treeInput.get() instanceof Tree) { leaves = String.join(", ", ((Tree) treeInput.get()).getTaxaNames()); @@ -138,8 +150,8 @@ public void initAndValidate() { throw new IllegalArgumentException(String.format( "The number of leaves in the tree (%d) does not match the number of sequences (%d). " + "The tree has leaves [%s], while the data refers to taxa [%s].", - treeInput.get().getLeafNodeCount(), dataInput.get().getTaxonCount(), - leaves, String.join(", ", dataInput.get().getTaxaNames()))); + treeInput.get().getLeafNodeCount(), alignment.getTaxonCount(), + leaves, String.join(", ", alignment.getTaxaNames()))); } beagle = null; @@ -152,8 +164,11 @@ public void initAndValidate() { // } // // if (!hasImaginaryEigenvectors) { - beagle = new BeagleTreeLikelihood(); try { + Object o = newTreeLikelihood(); + if (o instanceof BeagleTreeLikelihood) { + beagle = (BeagleTreeLikelihood) o; + } beagle.initByName( "data", dataInput.get(), "tree", treeInput.get(), "siteModel", siteModelInput.get(), "branchRateModel", branchRateModelInput.get(), "useAmbiguities", m_useAmbiguities.get(), @@ -177,7 +192,7 @@ public void initAndValidate() { throw new IllegalArgumentException("siteModel input should be of type SiteModel.Base"); } m_siteModel = (SiteModel.Base) siteModelInput.get(); - m_siteModel.setDataType(dataInput.get().getDataType()); + m_siteModel.setDataType(alignment.getDataType()); substitutionModel = m_siteModel.substModelInput.get(); if (branchRateModelInput.get() != null) { @@ -188,13 +203,12 @@ public void initAndValidate() { m_branchLengths = new double[nodeCount]; storedBranchLengths = new double[nodeCount]; - int stateCount = dataInput.get().getMaxStateCount(); - int patterns = dataInput.get().getPatternCount(); + int stateCount = alignment.getMaxStateCount(); + int patterns = alignment.getPatternCount(); likelihoodCore = createLikelihoodCore(stateCount); String className = getClass().getSimpleName(); - Alignment alignment = dataInput.get(); Log.info.println(className + "(" + getID() + ") uses " + likelihoodCore.getClass().getSimpleName()); Log.info.println(" " + alignment.toString(true)); @@ -214,12 +228,18 @@ public void initAndValidate() { probabilities = new double[(stateCount + 1) * (stateCount + 1)]; Arrays.fill(probabilities, 1.0); - if (dataInput.get().isAscertained) { + if (alignment.isAscertained) { useAscertainedSitePatterns = true; } } - protected LikelihoodCore createLikelihoodCore(int stateCount) { + private TreeLikelihood newTreeLikelihood() throws InstantiationException, IllegalAccessException, IllegalArgumentException, InvocationTargetException, NoSuchMethodException, SecurityException, ClassNotFoundException { + String class_ = implementationInput.get(); + TreeLikelihood tl = (TreeLikelihood) BEASTClassLoader.forName(class_).getConstructor().newInstance(); + return tl; + } + + protected LikelihoodCore createLikelihoodCore(int stateCount) { if (stateCount == 4) { return new BeerLikelihoodCore4(); } else { @@ -239,12 +259,12 @@ protected LikelihoodCore createLikelihoodCore(int stateCount) { protected void calcConstantPatternIndices(final int patterns, final int stateCount) { constantPattern = new ArrayList<>(); for (int i = 0; i < patterns; i++) { - final int[] pattern = dataInput.get().getPattern(i); + final int[] pattern = alignment.getPattern(i); final boolean[] isInvariant = new boolean[stateCount]; Arrays.fill(isInvariant, true); for (final int state : pattern) { - final boolean[] isStateSet = dataInput.get().getStateSet(state); - if (m_useAmbiguities.get() || !dataInput.get().getDataType().isAmbiguousCode(state)) { + final boolean[] isStateSet = alignment.getStateSet(state); + if (m_useAmbiguities.get() || !alignment.getDataType().isAmbiguousCode(state)) { for (int k = 0; k < stateCount; k++) { isInvariant[k] &= isStateSet[k]; } @@ -262,7 +282,7 @@ protected void initCore() { final int nodeCount = treeInput.get().getNodeCount(); likelihoodCore.initialize( nodeCount, - dataInput.get().getPatternCount(), + alignment.getPatternCount(), m_siteModel.getCategoryCount(), true, m_useAmbiguities.get() ); @@ -271,9 +291,9 @@ protected void initCore() { final int intNodeCount = nodeCount / 2; if (m_useAmbiguities.get() || m_useTipLikelihoods.get()) { - setPartials(treeInput.get().getRoot(), dataInput.get().getPatternCount()); + setPartials(treeInput.get().getRoot(), alignment.getPatternCount()); } else { - setStates(treeInput.get().getRoot(), dataInput.get().getPatternCount()); + setStates(treeInput.get().getRoot(), alignment.getPatternCount()); } hasDirt = Tree.IS_FILTHY; for (int i = 0; i < intNodeCount; i++) { @@ -294,13 +314,12 @@ public void sample(State state, Random random) { */ protected void setStates(Node node, int patternCount) { if (node.isLeaf()) { - Alignment data = dataInput.get(); int i; int[] states = new int[patternCount]; - int taxonIndex = getTaxonIndex(node.getID(), data); + int taxonIndex = getTaxonIndex(node.getID(), alignment); for (i = 0; i < patternCount; i++) { - int code = data.getPattern(taxonIndex, i); - int[] statesForCode = data.getDataType().getStatesForCode(code); + int code = alignment.getPattern(taxonIndex, i); + int[] statesForCode = alignment.getDataType().getStatesForCode(code); if (statesForCode.length==1) states[i] = statesForCode[0]; else @@ -339,21 +358,20 @@ private int getTaxonIndex(String taxon, Alignment data) { */ protected void setPartials(Node node, int patternCount) { if (node.isLeaf()) { - Alignment data = dataInput.get(); - int states = data.getDataType().getStateCount(); + int states = alignment.getDataType().getStateCount(); double[] partials = new double[patternCount * states]; int k = 0; - int taxonIndex = getTaxonIndex(node.getID(), data); + int taxonIndex = getTaxonIndex(node.getID(), alignment); for (int patternIndex_ = 0; patternIndex_ < patternCount; patternIndex_++) { - double[] tipLikelihoods = data.getTipLikelihoods(taxonIndex,patternIndex_); + double[] tipLikelihoods = alignment.getTipLikelihoods(taxonIndex,patternIndex_); if (tipLikelihoods != null) { for (int state = 0; state < states; state++) { partials[k++] = tipLikelihoods[state]; } } else { - int stateCount = data.getPattern(taxonIndex, patternIndex_); - boolean[] stateSet = data.getStateSet(stateCount); + int stateCount = alignment.getPattern(taxonIndex, patternIndex_); + boolean[] stateSet = alignment.getStateSet(stateCount); for (int state = 0; state < states; state++) { partials[k++] = (stateSet[state] ? 1.0 : 0.0); } @@ -369,6 +387,9 @@ protected void setPartials(Node node, int patternCount) { // for testing public double[] getRootPartials() { + if (beagle != null) { + return beagle.getRootPartials(); + } return m_fRootPartials.clone(); } @@ -424,13 +445,13 @@ public double calculateLogP() { protected void calcLogP() { logP = 0.0; if (useAscertainedSitePatterns) { - final double ascertainmentCorrection = dataInput.get().getAscertainmentCorrection(patternLogLikelihoods); - for (int i = 0; i < dataInput.get().getPatternCount(); i++) { - logP += (patternLogLikelihoods[i] - ascertainmentCorrection) * dataInput.get().getPatternWeight(i); + final double ascertainmentCorrection = alignment.getAscertainmentCorrection(patternLogLikelihoods); + for (int i = 0; i < alignment.getPatternCount(); i++) { + logP += (patternLogLikelihoods[i] - ascertainmentCorrection) * alignment.getPatternWeight(i); } } else { - for (int i = 0; i < dataInput.get().getPatternCount(); i++) { - logP += patternLogLikelihoods[i] * dataInput.get().getPatternWeight(i); + for (int i = 0; i < alignment.getPatternCount(); i++) { + logP += patternLogLikelihoods[i] * alignment.getPatternWeight(i); } } }