From 6e734c2c38ea7e819c30d38d53a317962037755f Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 14 Apr 2024 21:43:19 -0400 Subject: [PATCH 01/15] Starting work on GMM. --- Clustering/GMM/pom.xml | 98 ++++ .../org/tribuo/clustering/gmm/GMMOptions.java | 77 +++ .../org/tribuo/clustering/gmm/GMMTrainer.java | 511 ++++++++++++++++++ .../clustering/gmm/GaussianMixtureModel.java | 327 +++++++++++ .../org/tribuo/clustering/gmm/TrainTest.java | 134 +++++ .../tribuo/clustering/gmm/package-info.java | 20 + .../protos/tribuo-clustering-gmm.proto | 41 ++ .../clustering/kmeans/KMeansOptions.java | 5 +- Clustering/pom.xml | 1 + Core/pom.xml | 4 - Core/src/main/java/org/tribuo/util/Util.java | 20 + Data/pom.xml | 7 +- distribution/pom.xml | 5 + 13 files changed, 1239 insertions(+), 11 deletions(-) create mode 100644 Clustering/GMM/pom.xml create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/package-info.java create mode 100644 Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto diff --git a/Clustering/GMM/pom.xml b/Clustering/GMM/pom.xml new file mode 100644 index 000000000..be68c43f9 --- /dev/null +++ b/Clustering/GMM/pom.xml @@ -0,0 +1,98 @@ + + + + + 4.0.0 + + org.tribuo + tribuo-clustering + 5.0.0-SNAPSHOT + ../pom.xml + + Clustering-GMM + tribuo-clustering-gmm + jar + + 17 + + + + + ${project.groupId} + tribuo-core + + + org.tribuo + tribuo-data + + + ${project.groupId} + tribuo-math + + + ${project.groupId} + tribuo-clustering-core + ${project.version} + + + com.oracle.labs.olcut + olcut-core + + + + ${project.groupId} + tribuo-core + ${project.version} + test-jar + test 
+ + + org.junit.jupiter + junit-jupiter + ${junit.version} + test + + + + + + + org.apache.maven.plugins + maven-assembly-plugin + + + jar-with-dependencies + + + + + make-assembly + package + + single + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + + + diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java new file mode 100644 index 000000000..5eb27cd89 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.clustering.gmm; + +import com.oracle.labs.mlrg.olcut.config.Option; +import com.oracle.labs.mlrg.olcut.config.Options; +import org.tribuo.Trainer; +import org.tribuo.clustering.gmm.GMMTrainer.Initialisation; +import org.tribuo.clustering.gmm.GMMTrainer.CovarianceType; + +import java.util.logging.Logger; + +/** + * OLCUT {@link Options} for the GMM implementation. + */ +public class GMMOptions implements Options { + private static final Logger logger = Logger.getLogger(GMMOptions.class.getName()); + + /** + * Iterations of the GMM algorithm. Defaults to 10. + */ + @Option(longName = "gmm-interations", usage = "Iterations of the GMM algorithm. Defaults to 10.") + public int iterations = 10; + /** + * Number of centroids/Gaussians in GMM. 
Defaults to 10. + */ + @Option(longName = "gmm-num-centroids", usage = "Number of centroids in GMM. Defaults to 10.") + public int centroids = 10; + /** + * The covariance type of the Gaussians. + */ + @Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.") + public CovarianceType covarianceType = CovarianceType.DIAGONAL; + /** + * Initialisation function in GMM. Defaults to RANDOM. + */ + @Option(longName = "gmm-initialisation", usage = "Initialisation function in GMM. Defaults to RANDOM.") + public Initialisation initialisation = GMMTrainer.Initialisation.RANDOM; + /** + * Convergence tolerance to terminate EM early. + */ + @Option(longName = "gmm-tolerance", usage = "The convergence threshold.") + public double tolerance = 1e-3f; + /** + * Number of computation threads in GMM. Defaults to 4. + */ + @Option(longName = "gmm-num-threads", usage = "Number of computation threads in GMM. Defaults to 4.") + public int numThreads = 4; + /** + * The RNG seed. + */ + @Option(longName = "gmm-seed", usage = "Sets the random seed for GMM.") + public long seed = Trainer.DEFAULT_SEED; + + /** + * Gets the configured GMMTrainer using the options in this object. + * @return A GMMTrainer. + */ + public GMMTrainer getTrainer() { + logger.info("Configuring GMM Trainer"); + return new GMMTrainer(centroids, iterations, covarianceType, initialisation, tolerance, numThreads, seed); + } +} diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java new file mode 100644 index 000000000..c01dd01dd --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -0,0 +1,511 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.clustering.gmm; + +import com.oracle.labs.mlrg.olcut.config.Config; +import com.oracle.labs.mlrg.olcut.config.PropertyException; +import com.oracle.labs.mlrg.olcut.provenance.Provenance; +import com.oracle.labs.mlrg.olcut.util.MutableLong; +import com.oracle.labs.mlrg.olcut.util.StreamUtil; +import org.tribuo.Dataset; +import org.tribuo.Example; +import org.tribuo.ImmutableFeatureMap; +import org.tribuo.ImmutableOutputInfo; +import org.tribuo.Trainer; +import org.tribuo.clustering.ClusterID; +import org.tribuo.clustering.ImmutableClusteringInfo; +import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.SGDVector; +import org.tribuo.math.la.SparseVector; +import org.tribuo.provenance.ModelProvenance; +import org.tribuo.provenance.TrainerProvenance; +import org.tribuo.provenance.impl.TrainerProvenanceImpl; +import org.tribuo.util.Util; + +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.time.OffsetDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.SplittableRandom; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.ForkJoinWorkerThread; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.logging.Level; +import java.util.logging.Logger; +import 
java.util.stream.IntStream; +import java.util.stream.Stream; + +/** + * A Gaussian Mixture Model trainer, which generates a GMM clustering of the supplied + * data. The model finds the centres, and then predict needs to be + * called to infer the centre assignments for the input data. + *

+ * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments + * can only be retrieved from the model after training, and require re-evaluating each example. + *

+ * The Trainer has a selectable covariance type and initialisation method, and a selectable number + * of threads used in the training step. The thread pool is local to an invocation of train, + * so there can be multiple concurrent trainings. + *

+ * The train method will instantiate dense examples as dense vectors, speeding up the computation. + *

+ * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase + * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a + * {@link SecurityManager}. + *

+ * See: + *

+ * J. Friedman, T. Hastie, &amp; R. Tibshirani.
+ * "The Elements of Statistical Learning"
+ * Springer 2001. PDF
+ * 
+ *

+ * For more on optional kmeans++ initialisation, see: + *

+ * D. Arthur, S. Vassilvitskii.
+ * "K-Means++: The Advantages of Careful Seeding"
+ * PDF
+ * 
+ */ +public class GMMTrainer implements Trainer { + private static final Logger logger = Logger.getLogger(GMMTrainer.class.getName()); + + // Thread factory for the FJP, to allow use with OpenSearch's SecureSM + private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory(); + + public enum CovarianceType { + /** + * Full covariance. + */ + FULL, + /** + * Diagonal covariance. + */ + DIAGONAL, + /** + * Spherical covariance. + */ + SPHERICAL + } + + /** + * Possible initialization functions. + */ + public enum Initialisation { + /** + * Initialize Gaussians by choosing uniformly at random from the data + * points. + */ + RANDOM, + /** + * KMeans++ initialisation. + */ + PLUSPLUS + } + + @Config(mandatory = true, description = "Number of centroids.") + private int centroids; + + @Config(mandatory = true, description = "The number of iterations to run.") + private int iterations; + + @Config(description = "The convergence threshold.") + private double convergenceTolerance = 1e-3f; + + @Config(description = "The type of covariance matrix to fit.") + private CovarianceType covarianceType = CovarianceType.DIAGONAL; + + @Config(description = "The centroid initialisation method to use.") + private Initialisation initialisationType = Initialisation.RANDOM; + + @Config(description = "The number of threads to use for training.") + private int numThreads = 1; + + @Config(mandatory = true, description = "The seed to use for the RNG.") + private long seed; + + private SplittableRandom rng; + + private int trainInvocationCounter; + + /** + * for olcut. + */ + private GMMTrainer() { } + + /** + * Constructs a K-Means trainer using the supplied parameters and the default random initialisation. + * + * @param centroids The number of centroids to use. + * @param iterations The maximum number of iterations. + * @param numThreads The number of threads. + * @param seed The random seed. 
+ */ + public GMMTrainer(int centroids, int iterations, int numThreads, long seed) { + this(centroids,iterations,CovarianceType.DIAGONAL,Initialisation.RANDOM,1e-3,numThreads,seed); + } + + /** + * Constructs a GMM trainer using the supplied parameters. + * + * @param centroids The number of centroids to use. + * @param iterations The maximum number of iterations. + * @param initialisationType The centroid initialization method. + * @param numThreads The number of threads. + * @param seed The random seed. + */ + public GMMTrainer(int centroids, int iterations, CovarianceType covarianceType, Initialisation initialisationType, double tolerance, int numThreads, long seed) { + this.centroids = centroids; + this.iterations = iterations; + this.covarianceType = covarianceType; + this.initialisationType = initialisationType; + this.convergenceTolerance = tolerance; + this.numThreads = numThreads; + this.seed = seed; + postConfig(); + } + + /** + * Used by the OLCUT configuration system, and should not be called by external code. + */ + @Override + public synchronized void postConfig() { + this.rng = new SplittableRandom(seed); + + if (centroids < 1) { + throw new PropertyException("centroids", "Centroids must be positive, found " + centroids); + } + } + + @Override + public GaussianMixtureModel train(Dataset examples, Map runProvenance) { + return train(examples, runProvenance, INCREMENT_INVOCATION_COUNT); + } + + @Override + public GaussianMixtureModel train(Dataset examples, Map runProvenance, int invocationCount) { + // Creates a new local RNG and adds one to the invocation count. 
+ TrainerProvenance trainerProvenance; + SplittableRandom localRNG; + synchronized (this) { + if(invocationCount != INCREMENT_INVOCATION_COUNT) { + setInvocationCount(invocationCount); + } + localRNG = rng.split(); + trainerProvenance = getProvenance(); + trainInvocationCounter++; + } + ImmutableFeatureMap featureMap = examples.getFeatureIDMap(); + + SGDVector[] data = new SGDVector[examples.size()]; + double[] weights = new double[examples.size()]; + int n = 0; + for (Example example : examples) { + weights[n] = example.getWeight(); + if (example.size() == featureMap.size()) { + data[n] = DenseVector.createDenseVector(example, featureMap, false); + } else { + data[n] = SparseVector.createSparseVector(example, featureMap, false); + } + n++; + } + + DenseMatrix responsibilities = new DenseMatrix(examples.size(), centroids); + DenseVector[] centroidVectors = switch (initialisationType) { + case RANDOM -> initialiseRandomCentroids(centroids, featureMap, localRNG); + case PLUSPLUS -> initialisePlusPlusCentroids(centroids, data, localRNG, dist); + }; + + Map> clusterAssignments = new HashMap<>(); + boolean parallel = numThreads > 1; + for (int i = 0; i < centroids; i++) { + clusterAssignments.put(i, parallel ? Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>()); + } + + AtomicInteger changeCounter = new AtomicInteger(0); + Consumer eStepFunc = (IntAndVector e) -> { + double minDist = Double.POSITIVE_INFINITY; + int clusterID = -1; + int id = e.idx; + SGDVector vector = e.vector; + for (int j = 0; j < centroids; j++) { + DenseVector cluster = centroidVectors[j]; + double distance = dist.computeDistance(cluster, vector); + if (distance < minDist) { + minDist = distance; + clusterID = j; + } + } + + clusterAssignments.get(clusterID).add(id); + if (oldCentre[id] != clusterID) { + // Changed the centroid of this vector. 
+ oldCentre[id] = clusterID; + changeCounter.incrementAndGet(); + } + }; + + boolean converged = false; + ForkJoinPool fjp = null; + try { + if (parallel) { + if (System.getSecurityManager() == null) { + fjp = new ForkJoinPool(numThreads); + } else { + fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false); + } + } + for (int i = 0; (i < iterations) && !converged; i++) { + logger.log(Level.FINE,"Beginning iteration " + i); + changeCounter.set(0); + + for (Entry> e : clusterAssignments.entrySet()) { + e.getValue().clear(); + } + + // E step + Stream vecStream = Arrays.stream(data); + Stream intStream = IntStream.range(0, data.length).boxed(); + Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); + if (parallel) { + Stream parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel()); + try { + fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Parallel execution failed", e); + } + } else { + zipStream.forEach(eStepFunc); + } + logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated."); + + mStep(fjp, centroidVectors, clusterAssignments, data, weights); + + logger.log(Level.INFO, "Iteration " + i + " completed. 
" + changeCounter.get() + " examples updated."); + + if (changeCounter.get() == 0) { + converged = true; + logger.log(Level.INFO, "K-Means converged at iteration " + i); + } + } + } finally { + if (fjp != null) { + fjp.shutdown(); + } + } + + Map counts = new HashMap<>(); + for (Entry> e : clusterAssignments.entrySet()) { + counts.put(e.getKey(), new MutableLong(e.getValue().size())); + } + + ImmutableOutputInfo outputMap = new ImmutableClusteringInfo(counts); + + ModelProvenance provenance = new ModelProvenance(GaussianMixtureModel.class.getName(), OffsetDateTime.now(), + examples.getProvenance(), trainerProvenance, runProvenance); + + return new GaussianMixtureModel("gaussian-mixture-model", provenance, featureMap, outputMap, centroidVectors, dist); + } + + @Override + public GaussianMixtureModel train(Dataset dataset) { + return train(dataset, Collections.emptyMap()); + } + + @Override + public int getInvocationCount() { + return trainInvocationCounter; + } + + @Override + public synchronized void setInvocationCount(int invocationCount){ + if(invocationCount < 0){ + throw new IllegalArgumentException("The supplied invocationCount is less than zero."); + } + + rng = new SplittableRandom(seed); + + for (trainInvocationCounter = 0; trainInvocationCounter < invocationCount; trainInvocationCounter++){ + SplittableRandom localRNG = rng.split(); + } + + } + + /** + * Initialisation method called at the start of each train call when using the default centroid initialisation. + * Centroids are initialised using a uniform random sample from the feature domain. + * + * @param centroids The number of centroids to create. + * @param featureMap The feature map to use for centroid sampling. + * @param rng The RNG to use. + * @return A {@link DenseVector} array of centroids. 
+ */ + private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableFeatureMap featureMap, + SplittableRandom rng) { + DenseVector[] centroidVectors = new DenseVector[centroids]; + int numFeatures = featureMap.size(); + for (int i = 0; i < centroids; i++) { + double[] newCentroid = new double[numFeatures]; + + for (int j = 0; j < numFeatures; j++) { + newCentroid[j] = featureMap.get(j).uniformSample(rng); + } + + centroidVectors[i] = DenseVector.createDenseVector(newCentroid); + } + return centroidVectors; + } + + /** + * Initialisation method called at the start of each train call when using kmeans++ centroid initialisation. + * + * @param centroids The number of centroids to create. + * @param data The dataset of {@link SGDVector} to use. + * @param rng The RNG to use. + * @param dist The distance function. + * @return A {@link DenseVector} array of centroids. + */ + private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVector[] data, SplittableRandom rng, + org.tribuo.math.distance.Distance dist) { + if (centroids > data.length) { + throw new IllegalArgumentException("The number of centroids may not exceed the number of samples."); + } + + double[] minDistancePerVector = new double[data.length]; + Arrays.fill(minDistancePerVector, Double.POSITIVE_INFINITY); + + double[] squaredMinDistance = new double[data.length]; + double[] probabilities = new double[data.length]; + DenseVector[] centroidVectors = new DenseVector[centroids]; + + // set first centroid randomly from the data + centroidVectors[0] = getRandomCentroidFromData(data, rng); + + // Set each uninitialised centroid remaining + for (int i = 1; i < centroids; i++) { + DenseVector prevCentroid = centroidVectors[i - 1]; + + // go through every vector and see if the min distance to the + // newest centroid is smaller than previous min distance for vec + for (int j = 0; j < data.length; j++) { + double tempDistance = dist.computeDistance(prevCentroid, data[j]); + 
minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); + } + + // square the distances and get total for normalization + double total = 0.0; + for (int j = 0; j < data.length; j++) { + squaredMinDistance[j] = minDistancePerVector[j] * minDistancePerVector[j]; + total += squaredMinDistance[j]; + } + + // compute probabilities as p[i] = D(xi)^2 / sum(D(x)^2) + for (int j = 0; j < probabilities.length; j++) { + probabilities[j] = squaredMinDistance[j] / total; + } + + // sample from probabilities to get the new centroid from data + double[] cdf = Util.generateCDF(probabilities); + int idx = Util.sampleFromCDF(cdf, rng); + centroidVectors[i] = DenseVector.createDenseVector(data[idx].toArray()); + } + return centroidVectors; + } + + /** + * Randomly select a piece of data as the starting centroid. + * + * @param data The dataset of {@link SparseVector} to use. + * @param rng The RNG to use. + * @return A {@link DenseVector} representing a centroid. + */ + private static DenseVector getRandomCentroidFromData(SGDVector[] data, SplittableRandom rng) { + int randIdx = rng.nextInt(data.length); + return DenseVector.createDenseVector(data[randIdx].toArray()); + } + + /** + * Runs the mStep, writing to the {@code centroidVectors} array. + * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel. + * If the fjp is null then the computation is executed sequentially on the main thread. + * @param centroidVectors The centroid vectors to write out. + * @param clusterAssignments The current cluster assignments. + * @param data The data points. + * @param weights The example weights. 
+ */ + protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SGDVector[] data, double[] weights) { + // M step + Consumer>> mStepFunc = (e) -> { + DenseVector newCentroid = centroidVectors[e.getKey()]; + newCentroid.fill(0.0); + + double weightSum = 0.0; + for (Integer idx : e.getValue()) { + newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]); + weightSum += weights[idx]; + } + if (weightSum != 0.0) { + newCentroid.scaleInPlace(1.0 / weightSum); + } + }; + + Stream>> mStream = clusterAssignments.entrySet().stream(); + if (fjp != null) { + Stream>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel()); + try { + fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Parallel execution failed", e); + } + } else { + mStream.forEach(mStepFunc); + } + } + + @Override + public String toString() { + return "GMMTrainer(centroids=" + centroids + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")"; + } + + @Override + public TrainerProvenance getProvenance() { + return new TrainerProvenanceImpl(this); + } + + /** + * Tuple of index and position. + */ + record IntAndVector(int idx, SGDVector vector) { } + + /** + * Used to allow FJPs to work with OpenSearch's SecureSM. 
+ */ + private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory { + public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { + return AccessController.doPrivileged((PrivilegedAction) () -> new ForkJoinWorkerThread(pool) {}); + } + } +} \ No newline at end of file diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java new file mode 100644 index 000000000..1996675f9 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java @@ -0,0 +1,327 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.clustering.gmm; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.oracle.labs.mlrg.olcut.util.Pair; +import org.tribuo.Example; +import org.tribuo.Excuse; +import org.tribuo.Feature; +import org.tribuo.ImmutableFeatureMap; +import org.tribuo.ImmutableOutputInfo; +import org.tribuo.Model; +import org.tribuo.Prediction; +import org.tribuo.clustering.ClusterID; +import org.tribuo.impl.ArrayExample; +import org.tribuo.impl.ModelDataCarrier; +import org.tribuo.math.la.DenseMatrix; +import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.SGDVector; +import org.tribuo.math.la.SparseVector; +import org.tribuo.math.la.Tensor; +import org.tribuo.math.la.VectorTuple; +import org.tribuo.math.protos.TensorProto; +import org.tribuo.protos.core.ModelProto; +import org.tribuo.provenance.ModelProvenance; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.random.RandomGenerator; + +/** + * A Gaussian Mixture Model. + *

+ * The predict method of this model assigns the provided input to a cluster, + * but it does not update the model's centroids. + *

+ * The predict method is single threaded. + *

+ * See: + *

+ * J. Friedman, T. Hastie, &amp; R. Tibshirani.
+ * "The Elements of Statistical Learning"
+ * Springer 2001. PDF
+ * 
+ */ +public class GaussianMixtureModel extends Model { + private static final long serialVersionUID = 1L; + + /** + * Protobuf serialization version. + */ + public static final int CURRENT_VERSION = 0; + + private final DenseVector[] meanVectors; + + private final DenseMatrix[] covarianceMatrices; + + private final DenseVector mixingDistribution; + + GaussianMixtureModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, + ImmutableOutputInfo outputIDInfo, DenseVector[] meanVectors, + DenseMatrix[] covarianceMatrices, DenseVector mixingDistribution) { + super(name,description,featureIDMap,outputIDInfo,false); + this.meanVectors = meanVectors; + this.covarianceMatrices = covarianceMatrices; + this.mixingDistribution = mixingDistribution; + } + + /** + * Deserialization factory. + * @param version The serialized object version. + * @param className The class name. + * @param message The serialized data. + * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}. + * @return The deserialized object. 
+ */ + public static GaussianMixtureModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException { + if (version < 0 || version > CURRENT_VERSION) { + throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION); + } + GaussianMixtureModelProto proto = message.unpack(GaussianMixtureModelProto.class); + + ModelDataCarrier carrier = ModelDataCarrier.deserialize(proto.getMetadata()); + if (!carrier.outputDomain().getOutput(0).getClass().equals(ClusterID.class)) { + throw new IllegalStateException("Invalid protobuf, output domain is not a clustering domain, found " + carrier.outputDomain().getClass()); + } + @SuppressWarnings("unchecked") // guarded by getClass + ImmutableOutputInfo outputDomain = (ImmutableOutputInfo) carrier.outputDomain(); + + ImmutableFeatureMap featureDomain = carrier.featureDomain(); + + final int means = proto.getMeanVectorsCount(); + + if (means == 0) { + throw new IllegalStateException("Invalid protobuf, no centroids were found"); + } else if (proto.getCovarianceMatricesCount() != means) { + throw new IllegalStateException("Invalid protobuf, found " + means + " means, but " + proto.getCovarianceMatricesCount() + " covariances."); + } + DenseVector[] centroids = new DenseVector[means]; + List centroidProtos = proto.getMeanVectorsList(); + for (int i = 0; i < centroids.length; i++) { + Tensor centroidTensor = Tensor.deserialize(centroidProtos.get(i)); + if (centroidTensor instanceof DenseVector centroid) { + if (centroid.size() != featureDomain.size()) { + throw new IllegalStateException("Invalid protobuf, centroid did not contain all the features, found " + centroid.size() + " expected " + featureDomain.size()); + } + centroids[i] = centroid; + } else { + throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + centroidTensor.getClass()); + } + } + DenseMatrix[] covariances = new 
DenseMatrix[means]; + List covarianceProtos = proto.getCovarianceMatricesList(); + for (int i = 0; i < covariances.length; i++) { + Tensor covarianceTensor = Tensor.deserialize(covarianceProtos.get(i)); + if (covarianceTensor instanceof DenseMatrix covariance) { + if (covariance.getDimension1Size() != featureDomain.size() || covariance.getDimension2Size() != featureDomain.size()) { + throw new IllegalStateException("Invalid protobuf, covariance was not square or did not contain all " + + "the features, found " + Arrays.toString(covariance.getShape()) + " expected [" + featureDomain.size() + ", " + featureDomain.size() +"]"); + } + covariances[i] = covariance; + } else { + throw new IllegalStateException("Invalid protobuf, expected covariance to be a dense matrix, found " + covarianceTensor.getClass()); + } + } + Tensor mixing = Tensor.deserialize(proto.getMixingDistribution()); + DenseVector mixingVec; + if (mixing instanceof DenseVector mixingDist) { + if (mixingDist.size() != means) { + throw new IllegalStateException("Invalid protobuf, found " + means + " but a " + mixingDist.size() + " element mixing distribution"); + } else { + mixingVec = mixingDist; + } + } else { + throw new IllegalStateException("Invalid protobuf, expected mixing distribution to be a dense vector, found " + mixing.getClass()); + } + + return new GaussianMixtureModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, covariances, mixingVec); + } + + /** + * Returns a copy of the centroids. + *

+ * In most cases you should prefer {@link #getMeans} as + * it performs the mapping from Tribuo's internal feature ids + * to the externally visible feature names for you. + * This method provides direct access to the centroid vectors + * for use in downstream processing if the ids are not relevant + * (or are known to match). + * @return The centroids. + */ + public DenseVector[] getMeanVectors() { + DenseVector[] copies = new DenseVector[meanVectors.length]; + + for (int i = 0; i < copies.length; i++) { + copies[i] = meanVectors[i].copy(); + } + + return copies; + } + + /** + * Returns a copy of the covariances. + *

+ * This method provides direct access to the covariance matrices + * for use in downstream processing, users need to map the indices using + * Tribuo's internal ids themselves. + * @return The covariances. + */ + public DenseMatrix[] getCovariances() { + DenseMatrix[] copies = new DenseMatrix[covarianceMatrices.length]; + + for (int i = 0; i < copies.length; i++) { + copies[i] = covarianceMatrices[i].copy(); + } + + return copies; + } + + /** + * Returns a list of features, one per centroid. + *

+ * This should be used in preference to {@link #getMeanVectors()} + * as it performs the mapping from Tribuo's internal feature ids to + * the externally visible feature names. + *

+ * @return A list containing all the centroids. + */ + public List> getMeans() { + List> output = new ArrayList<>(meanVectors.length); + + for (int i = 0; i < meanVectors.length; i++) { + List features = new ArrayList<>(featureIDMap.size()); + + for (VectorTuple v : meanVectors[i]) { + Feature f = new Feature(featureIDMap.get(v.index).getName(),v.value); + features.add(f); + } + + output.add(features); + } + + return output; + } + + @Override + public Prediction predict(Example example) { + SGDVector vector; + if (example.size() == featureIDMap.size()) { + vector = DenseVector.createDenseVector(example, featureIDMap, false); + } else { + vector = SparseVector.createSparseVector(example, featureIDMap, false); + } + if (vector.numActiveElements() == 0) { + throw new IllegalArgumentException("No features found in Example " + example.toString()); + } + + // generate cluster responsibilities and normalize into a distribution + + return new Prediction<>(new ClusterID(id),vector.size(),example); + } + + @Override + public Map>> getTopFeatures(int n) { + return Collections.emptyMap(); + } + + @Override + public Optional> getExcuse(Example example) { + return Optional.empty(); + } + + /** + * Sample from this Gaussian Mixture Model. + * @param numSamples The number of samples to draw. + * @param rng The RNG to use. + * @return A list of samples from this GMM. + */ + public List> sample(int numSamples, RandomGenerator rng) { + // Convert mixing distribution into CDF + + // Sample from mixing distribution + + // Sample from appropriate MultivariateNormalDistribution + + } + + /** + * Sample from this Gaussian Mixture Model. + * @param numSamples The number of samples to draw. + * @param rng The RNG to use. + * @return A list of examples sampled from this GMM. 
+ */ + public List> sampleExamples(int numSamples, RandomGenerator rng) { + var samples = sample(numSamples, rng); + + List> output = new ArrayList<>(); + + for (Pair e : samples) { + ClusterID id = outputIDInfo.getOutput(e.getA()); + String[] names = new String[e.getB().size()]; + double[] values = new double[e.getB().size()]; + for (VectorTuple v : e.getB()) { + names[v.index] = featureIDMap.get(v.index).getName(); + values[v.index] = v.value; + } + Example ex = new ArrayExample<>(id, names, values); + output.add(ex); + } + + return output; + } + + @Override + public ModelProto serialize() { + ModelDataCarrier carrier = createDataCarrier(); + + GaussianMixtureModelProto.Builder modelBuilder = GaussianMixtureModelProto.newBuilder(); + modelBuilder.setMetadata(carrier.serialize()); + for (DenseVector e : meanVectors) { + modelBuilder.addMeanVectors(e.serialize()); + } + for (DenseMatrix e : covarianceMatrices) { + modelBuilder.addCovarianceMatrices(e.serialize()); + } + modelBuilder.setMixingDistribution(mixingDistribution.serialize()); + + ModelProto.Builder builder = ModelProto.newBuilder(); + builder.setSerializedData(Any.pack(modelBuilder.build())); + builder.setClassName(GaussianMixtureModel.class.getName()); + builder.setVersion(CURRENT_VERSION); + + return builder.build(); + } + + @Override + protected GaussianMixtureModel copy(String newName, ModelProvenance newProvenance) { + DenseVector[] newCentroids = new DenseVector[meanVectors.length]; + DenseMatrix[] newCovariance = new DenseMatrix[meanVectors.length]; + for (int i = 0; i < meanVectors.length; i++) { + newCentroids[i] = meanVectors[i].copy(); + newCovariance[i] = covarianceMatrices[i].copy(); + } + + return new GaussianMixtureModel(newName,newProvenance,featureIDMap,outputIDInfo, + newCentroids,newCovariance,mixingDistribution.copy()); + } +} diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java new 
file mode 100644 index 000000000..3b4d81fab --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.clustering.gmm; + +import com.oracle.labs.mlrg.olcut.config.ConfigurationManager; +import com.oracle.labs.mlrg.olcut.config.Option; +import com.oracle.labs.mlrg.olcut.config.Options; +import com.oracle.labs.mlrg.olcut.config.UsageException; +import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter; +import com.oracle.labs.mlrg.olcut.util.Pair; +import org.tribuo.Dataset; +import org.tribuo.Model; +import org.tribuo.clustering.ClusterID; +import org.tribuo.clustering.ClusteringFactory; +import org.tribuo.clustering.evaluation.ClusteringEvaluation; +import org.tribuo.clustering.gmm.GMMTrainer.CovarianceType; +import org.tribuo.clustering.gmm.GMMTrainer.Initialisation; +import org.tribuo.data.DataOptions; + +import java.io.IOException; +import java.util.logging.Logger; + + +/** + * Build and run a Gaussian Mixture model for a standard dataset. + */ +public class TrainTest { + + private static final Logger logger = Logger.getLogger(TrainTest.class.getName()); + + /** + * Options for the GMM CLI. 
+ */ + public static class GMMOptions implements Options { + @Override + public String getOptionsDescription() { + return "Trains and evaluates a Gaussian Mixture model on the specified dataset."; + } + + /** + * The data loading options. + */ + public DataOptions general; + + /** + * Number of clusters to infer. + */ + @Option(charName = 'n', longName = "num-clusters", usage = "Number of clusters to infer.") + public int centroids = 5; + /** + * Maximum number of iterations. + */ + @Option(charName = 'i', longName = "iterations", usage = "Maximum number of iterations.") + public int iterations = 10; + /** + * The covariance type of the gaussians. + */ + @Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.") + public CovarianceType covarianceType = CovarianceType.DIAGONAL; + /** + * Type of initialisation to use for centroids. + */ + @Option(charName = 's', longName = "initialisation", usage = "Type of initialisation to use for centroids.") + public Initialisation initialisation = Initialisation.RANDOM; + /** + * Convergence tolerance to terminate EM early. + */ + @Option(longName = "tolerance", usage = "The convergence threshold.") + public double tolerance = 1e-3; + /** + * Number of threads to use (range (1, num hw threads)). + */ + @Option(charName = 't', longName = "num-threads", usage = "Number of threads to use (range (1, num hw threads)).") + public int numThreads = 4; + } + + /** + * Runs a TrainTest CLI. + * @param args the command line arguments + * @throws IOException if there is any error reading the examples. + */ + public static void main(String[] args) throws IOException { + // + // Use the labs format logging. 
+ LabsLogFormatter.setAllLogFormatters(); + + GMMOptions o = new GMMOptions(); + ConfigurationManager cm; + try { + cm = new ConfigurationManager(args,o); + } catch (UsageException e) { + logger.info(e.getMessage()); + return; + } + + if (o.general.trainingPath == null) { + logger.info(cm.usage()); + return; + } + + ClusteringFactory factory = new ClusteringFactory(); + + Pair,Dataset> data = o.general.load(factory); + Dataset train = data.getA(); + + //public GMMTrainer(int centroids, int iterations, DistanceType distType, int numThreads, int seed) + GMMTrainer trainer = new GMMTrainer(o.centroids,o.iterations, + o.covarianceType,o.initialisation,o.tolerance,o.numThreads,o.general.seed); + Model model = trainer.train(train); + logger.info("Finished training model"); + ClusteringEvaluation evaluation = factory.getEvaluator().evaluate(model,train); + logger.info("Finished evaluating model"); + System.out.println("Normalized MI = " + evaluation.normalizedMI()); + System.out.println("Adjusted MI = " + evaluation.adjustedMI()); + + if (o.general.outputPath != null) { + o.general.saveModel(model); + } + } +} diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/package-info.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/package-info.java new file mode 100644 index 000000000..9a150c083 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Gaussian Mixture Model training and inference. + */ +package org.tribuo.clustering.gmm; \ No newline at end of file diff --git a/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto b/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto new file mode 100644 index 000000000..a41e874a7 --- /dev/null +++ b/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +/* + * Protos for serializing Tribuo's GMM. 
+ */ +package tribuo.clustering.gmm; + +option java_multiple_files = true; +option java_package = "org.tribuo.clustering.gmm.protos"; + +// Import Tribuo's core protos +import "tribuo-core.proto"; + +// Import Tribuo's math protos +import "tribuo-math.proto"; + +/* +GaussianMixtureModel proto + */ +message GaussianMixtureModelProto { + tribuo.core.ModelDataProto metadata = 1; + tribuo.math.TensorProto mixing_distribution = 2; + repeated tribuo.math.TensorProto mean_vectors = 3; + repeated tribuo.math.TensorProto covariance_matrices = 4; +} diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java index 792e4a7df..e2b239de2 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansOptions.java @@ -55,8 +55,11 @@ public class KMeansOptions implements Options { */ @Option(longName = "kmeans-num-threads", usage = "Number of computation threads in K-Means. Defaults to 4.") public int numThreads = 4; + /** + * The RNG seed. + */ @Option(longName = "kmeans-seed", usage = "Sets the random seed for K-Means.") - private long seed = Trainer.DEFAULT_SEED; + public long seed = Trainer.DEFAULT_SEED; /** * Gets the configured KMeansTrainer using the options in this object. 
diff --git a/Clustering/pom.xml b/Clustering/pom.xml index 125318075..0d406b658 100644 --- a/Clustering/pom.xml +++ b/Clustering/pom.xml @@ -28,6 +28,7 @@ pom Core + GMM KMeans Hdbscan diff --git a/Core/pom.xml b/Core/pom.xml index 29aa2dd31..57b1b8b69 100644 --- a/Core/pom.xml +++ b/Core/pom.xml @@ -27,10 +27,6 @@ Core tribuo-core jar - - 1.8 - 1.8 - diff --git a/Core/src/main/java/org/tribuo/util/Util.java b/Core/src/main/java/org/tribuo/util/Util.java index 060ef4722..1a3a3ed7a 100644 --- a/Core/src/main/java/org/tribuo/util/Util.java +++ b/Core/src/main/java/org/tribuo/util/Util.java @@ -29,6 +29,7 @@ import java.util.function.ToIntFunction; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.random.RandomGenerator; /** * Ye olde util class. @@ -517,6 +518,25 @@ public static int sampleFromCDF(double[] cdf, Random rng) { } } + /** + * Samples an index from the supplied cdf. + * @param cdf The cdf to sample from. + * @param rng The rng to use. + * @return A sample. + */ + public static int sampleFromCDF(double[] cdf, RandomGenerator rng) { + if (Math.abs(cdf[cdf.length-1] - 1.0) > 1e-6) { + throw new IllegalStateException("Weights do not sum to 1, cdf[cdf.length-1] = " + cdf[cdf.length-1]); + } + double uniform = rng.nextDouble(); + int searchVal = Arrays.binarySearch(cdf, uniform); + if (searchVal < 0) { + return - 1 - searchVal; + } else { + return searchVal; + } + } + /** * Samples an index from the supplied cdf. * @param cdf The cdf to sample from. 
diff --git a/Data/pom.xml b/Data/pom.xml index 04e7b2b3c..2fe4d645d 100644 --- a/Data/pom.xml +++ b/Data/pom.xml @@ -64,7 +64,7 @@ org.apache.commons commons-lang3 - 3.12.0 + 3.14.0 org.junit.jupiter @@ -72,16 +72,11 @@ test - - 1.8 - 1.8 - org.apache.maven.plugins maven-jar-plugin - 3.2.0 diff --git a/distribution/pom.xml b/distribution/pom.xml index b3243323e..b498a7b0c 100644 --- a/distribution/pom.xml +++ b/distribution/pom.xml @@ -217,6 +217,11 @@ tribuo-clustering-core ${project.version} + + org.tribuo + tribuo-clustering-gmm + ${project.version} + org.tribuo tribuo-clustering-kmeans From 556be1623b4084d52a7a24b0b138508653f52516 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 15 Apr 2024 21:37:56 -0400 Subject: [PATCH 02/15] More work on GMM training method. --- .../org/tribuo/clustering/gmm/GMMTrainer.java | 122 ++++++------------ .../java/org/tribuo/math/la/DenseVector.java | 16 +++ 2 files changed, 58 insertions(+), 80 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index c01dd01dd..8df1bfb92 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -26,8 +26,10 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Trainer; +import org.tribuo.WeightedExamples; import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ImmutableClusteringInfo; +import org.tribuo.math.distance.L2Distance; import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; @@ -37,10 +39,7 @@ import org.tribuo.provenance.impl.TrainerProvenanceImpl; import org.tribuo.util.Util; -import java.security.AccessController; -import java.security.PrivilegedAction; import java.time.OffsetDateTime; -import java.util.ArrayList; import java.util.Arrays; import 
java.util.Collections; import java.util.HashMap; @@ -50,8 +49,6 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; -import java.util.concurrent.ForkJoinWorkerThread; -import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.logging.Level; import java.util.logging.Logger; @@ -72,10 +69,6 @@ *

* The train method will instantiate dense examples as dense vectors, speeding up the computation. *

- * Note parallel training uses a {@link ForkJoinPool} which requires that the Tribuo codebase - * is given the "modifyThread" and "modifyThreadGroup" privileges when running under a - * {@link SecurityManager}. - *

* See: *

  * J. Friedman, T. Hastie, & R. Tibshirani.
@@ -90,12 +83,9 @@
  * PDF
  * 
*/ -public class GMMTrainer implements Trainer { +public class GMMTrainer implements Trainer, WeightedExamples { private static final Logger logger = Logger.getLogger(GMMTrainer.class.getName()); - // Thread factory for the FJP, to allow use with OpenSearch's SecureSM - private static final CustomForkJoinWorkerThreadFactory THREAD_FACTORY = new CustomForkJoinWorkerThreadFactory(); - public enum CovarianceType { /** * Full covariance. @@ -151,6 +141,8 @@ public enum Initialisation { private int trainInvocationCounter; + private static final L2Distance plusPlusDistance = new L2Distance(); + /** * for olcut. */ @@ -234,82 +226,66 @@ public GaussianMixtureModel train(Dataset examples, Map initialiseRandomCentroids(centroids, featureMap, localRNG); - case PLUSPLUS -> initialisePlusPlusCentroids(centroids, data, localRNG, dist); + case PLUSPLUS -> initialisePlusPlusCentroids(centroids, data, localRNG); }; + DenseMatrix[] covarianceMatrices = new DenseMatrix[centroids]; + DenseMatrix.CholeskyFactorization[] precisionFactorizations = new DenseMatrix.CholeskyFactorization[centroids]; + DenseVector mixingDistribution = new DenseVector(centroids); - Map> clusterAssignments = new HashMap<>(); boolean parallel = numThreads > 1; - for (int i = 0; i < centroids; i++) { - clusterAssignments.put(i, parallel ? 
Collections.synchronizedList(new ArrayList<>()) : new ArrayList<>()); - } - AtomicInteger changeCounter = new AtomicInteger(0); - Consumer eStepFunc = (IntAndVector e) -> { + Consumer eStepFunc = (SGDVector e) -> { double minDist = Double.POSITIVE_INFINITY; - int clusterID = -1; - int id = e.idx; - SGDVector vector = e.vector; for (int j = 0; j < centroids; j++) { - DenseVector cluster = centroidVectors[j]; - double distance = dist.computeDistance(cluster, vector); + DenseVector cluster = meanVectors[j]; + double distance = dist.computeDistance(cluster, e); if (distance < minDist) { minDist = distance; - clusterID = j; } } - - clusterAssignments.get(clusterID).add(id); - if (oldCentre[id] != clusterID) { - // Changed the centroid of this vector. - oldCentre[id] = clusterID; - changeCounter.incrementAndGet(); - } }; + double oldLowerBound = Double.NEGATIVE_INFINITY; + double newLowerBound; boolean converged = false; ForkJoinPool fjp = null; try { if (parallel) { - if (System.getSecurityManager() == null) { - fjp = new ForkJoinPool(numThreads); - } else { - fjp = new ForkJoinPool(numThreads, THREAD_FACTORY, null, false); - } + fjp = new ForkJoinPool(numThreads); } for (int i = 0; (i < iterations) && !converged; i++) { logger.log(Level.FINE,"Beginning iteration " + i); - changeCounter.set(0); - - for (Entry> e : clusterAssignments.entrySet()) { - e.getValue().clear(); - } // E step Stream vecStream = Arrays.stream(data); - Stream intStream = IntStream.range(0, data.length).boxed(); - Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); if (parallel) { - Stream parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel()); try { - fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get(); + fjp.submit(() -> vecStream.parallel().forEach(eStepFunc)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } } else { - zipStream.forEach(eStepFunc); + 
vecStream.forEach(eStepFunc); } - logger.log(Level.FINE, "E step completed. " + changeCounter.get() + " words updated."); + logger.log(Level.FINE, i + "th e step completed."); - mStep(fjp, centroidVectors, clusterAssignments, data, weights); + // M step + mStep(fjp, responsibilities, meanVectors, covarianceMatrices, mixingDistribution, precisionFactorizations, data, weights); + logger.log(Level.FINE, i + "th m step completed."); - logger.log(Level.INFO, "Iteration " + i + " completed. " + changeCounter.get() + " examples updated."); + // Compute log likelihood bound + newLowerBound = computeLowerBound(); - if (changeCounter.get() == 0) { + logger.log(Level.INFO, "Iteration " + i + " completed."); + + if (newLowerBound - oldLowerBound < convergenceTolerance) { converged = true; - logger.log(Level.INFO, "K-Means converged at iteration " + i); + logger.log(Level.INFO, "GMM converged at iteration " + i); } + + oldLowerBound = newLowerBound; } } finally { if (fjp != null) { @@ -318,8 +294,10 @@ public GaussianMixtureModel train(Dataset examples, Map counts = new HashMap<>(); - for (Entry> e : clusterAssignments.entrySet()) { - counts.put(e.getKey(), new MutableLong(e.getValue().size())); + for (int i = 0; i < examples.size(); i++) { + int idx = responsibilities.getRow(i).argmax(); + var count = counts.computeIfAbsent(idx, k -> new MutableLong()); + count.increment(); } ImmutableOutputInfo outputMap = new ImmutableClusteringInfo(counts); @@ -327,7 +305,8 @@ public GaussianMixtureModel train(Dataset examples, Map data.length) { throw new IllegalArgumentException("The number of centroids may not exceed the number of samples."); } @@ -411,7 +387,7 @@ private static DenseVector[] initialisePlusPlusCentroids(int centroids, SGDVecto // go through every vector and see if the min distance to the // newest centroid is smaller than previous min distance for vec for (int j = 0; j < data.length; j++) { - double tempDistance = dist.computeDistance(prevCentroid, data[j]); + double 
tempDistance = plusPlusDistance.computeDistance(prevCentroid, data[j]); minDistancePerVector[j] = Math.min(minDistancePerVector[j], tempDistance); } @@ -450,13 +426,13 @@ private static DenseVector getRandomCentroidFromData(SGDVector[] data, Splittabl /** * Runs the mStep, writing to the {@code centroidVectors} array. * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel. - * If the fjp is null then the computation is executed sequentially on the main thread. + * If the fjp is null then the computation is executed sequentially. * @param centroidVectors The centroid vectors to write out. - * @param clusterAssignments The current cluster assignments. * @param data The data points. * @param weights The example weights. */ - protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map> clusterAssignments, SGDVector[] data, double[] weights) { + protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, DenseMatrix[] covarianceMatrices, + DenseMatrix.CholeskyFactorization[] precisionFactorizations, DenseVector mixingDistribution, SGDVector[] data, double[] weights) { // M step Consumer>> mStepFunc = (e) -> { DenseVector newCentroid = centroidVectors[e.getKey()]; @@ -474,9 +450,8 @@ protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map>> mStream = clusterAssignments.entrySet().stream(); if (fjp != null) { - Stream>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel()); try { - fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get(); + fjp.submit(() -> mStream.parallel().forEach(mStepFunc)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } @@ -495,17 +470,4 @@ public TrainerProvenance getProvenance() { return new TrainerProvenanceImpl(this); } - /** - * Tuple of index and position. 
- */ - record IntAndVector(int idx, SGDVector vector) { } - - /** - * Used to allow FJPs to work with OpenSearch's SecureSM. - */ - private static final class CustomForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory { - public final ForkJoinWorkerThread newThread(ForkJoinPool pool) { - return AccessController.doPrivileged((PrivilegedAction) () -> new ForkJoinWorkerThread(pool) {}); - } - } } \ No newline at end of file diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index 34ea6108c..4f31388a7 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -259,6 +259,22 @@ public int numActiveElements() { return elements.length; } + /** + * Gets the index of the maximum element. + * @return The index of the maximum element. + */ + public int argmax() { + int idx = -1; + double value = Double.NEGATIVE_INFINITY; + for (int i = 0; i < elements.length; i++) { + if (value < get(i)) { + idx = i; + value = get(i); + } + } + return idx; + } + /** * Performs a reduction from left to right of this vector. *

From 4c68a5235893697b62f230ef10dce92eac3ec030 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 15 Apr 2024 21:38:19 -0400 Subject: [PATCH 03/15] Modernizing KMeansTrainer. --- .../clustering/kmeans/KMeansTrainer.java | 41 +++++-------------- 1 file changed, 10 insertions(+), 31 deletions(-) diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index 1ea08c7a0..6051121bb 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ import org.tribuo.ImmutableFeatureMap; import org.tribuo.ImmutableOutputInfo; import org.tribuo.Trainer; +import org.tribuo.WeightedExamples; import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ImmutableClusteringInfo; import org.tribuo.math.distance.DistanceType; @@ -90,7 +91,7 @@ * PDF * */ -public class KMeansTrainer implements Trainer { +public class KMeansTrainer implements Trainer, WeightedExamples { private static final Logger logger = Logger.getLogger(KMeansTrainer.class.getName()); // Thread factory for the FJP, to allow use with OpenSearch's SecureSM @@ -300,17 +301,10 @@ public KMeansModel train(Dataset examples, Map ru n++; } - DenseVector[] centroidVectors; - switch (initialisationType) { - case RANDOM: - centroidVectors = initialiseRandomCentroids(centroids, featureMap, localRNG); - break; - case PLUSPLUS: - centroidVectors = initialisePlusPlusCentroids(centroids, data, localRNG, dist); - break; - default: - throw new 
IllegalStateException("Unknown initialisation" + initialisationType); - } + DenseVector[] centroidVectors = switch (initialisationType) { + case RANDOM -> initialiseRandomCentroids(centroids, featureMap, localRNG); + case PLUSPLUS -> initialisePlusPlusCentroids(centroids, data, localRNG, dist); + }; Map> clusterAssignments = new HashMap<>(); boolean parallel = numThreads > 1; @@ -364,9 +358,8 @@ public KMeansModel train(Dataset examples, Map ru Stream intStream = IntStream.range(0, data.length).boxed(); Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); if (parallel) { - Stream parallelZipStream = StreamUtil.boundParallelism(zipStream.parallel()); try { - fjp.submit(() -> parallelZipStream.forEach(eStepFunc)).get(); + fjp.submit(() -> zipStream.parallel().forEach(eStepFunc)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } @@ -550,9 +543,8 @@ protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, Map>> mStream = clusterAssignments.entrySet().stream(); if (fjp != null) { - Stream>> parallelMStream = StreamUtil.boundParallelism(mStream.parallel()); try { - fjp.submit(() -> parallelMStream.forEach(mStepFunc)).get(); + fjp.submit(() -> mStream.parallel().forEach(mStepFunc)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } @@ -574,20 +566,7 @@ public TrainerProvenance getProvenance() { /** * Tuple of index and position. One day it'll be a record, but not today. */ - static class IntAndVector { - final int idx; - final SGDVector vector; - - /** - * Constructs an index and vector tuple. - * @param idx The index. - * @param vector The vector. - */ - public IntAndVector(int idx, SGDVector vector) { - this.idx = idx; - this.vector = vector; - } - } + record IntAndVector(int idx, SGDVector vector) { } /** * Used to allow FJPs to work with OpenSearch's SecureSM. 
From f07b5245b6835acbb7f3cf9264125b151da6dc5b Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 16 Apr 2024 21:10:26 -0400 Subject: [PATCH 04/15] Moving logsumexp from ChainHelper to a method on DenseVector. --- .../classification/sgd/crf/ChainHelper.java | 33 +--------------- .../java/org/tribuo/math/la/DenseVector.java | 39 ++++++++++++++++++- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/ChainHelper.java b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/ChainHelper.java index 33ff7fa74..935b29df5 100644 --- a/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/ChainHelper.java +++ b/Classification/SGD/src/main/java/org/tribuo/classification/sgd/crf/ChainHelper.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -178,36 +178,7 @@ public static ChainViterbiResults viterbi(ChainCliqueValues scores) { * @return log sum exp input[i]. 
*/ public static double sumLogProbs(DenseVector input) { - double LOG_TOLERANCE = 30.0; - - double maxValue = input.get(0); - int maxIdx = 0; - for (int i = 1; i < input.size(); i++) { - double value = input.get(i); - if (value > maxValue) { - maxValue = value; - maxIdx = i; - } - } - if (maxValue == Double.NEGATIVE_INFINITY) { - return maxValue; - } - - boolean anyAdded = false; - double intermediate = 0.0; - double cutoff = maxValue - LOG_TOLERANCE; - for (int i = 0; i < input.size(); i++) { - double value = input.get(i); - if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) { - anyAdded = true; - intermediate += Math.exp(value - maxValue); - } - } - if (anyAdded) { - return maxValue + Math.log1p(intermediate); - } else { - return maxValue; - } + return input.logSumExp(); } /** diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index 4f31388a7..992bcfae0 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -626,6 +626,43 @@ public void expNormalize(double total) { } } + /** + * Sums log probabilities. + * @return log sum exp input[i]. 
+ */ + public double logSumExp() { + final double LOG_TOLERANCE = 30.0; + + double maxValue = get(0); + int maxIdx = 0; + for (int i = 1; i < elements.length; i++) { + double value = get(i); + if (value > maxValue) { + maxValue = value; + maxIdx = i; + } + } + if (maxValue == Double.NEGATIVE_INFINITY) { + return maxValue; + } + + boolean anyAdded = false; + double intermediate = 0.0; + double cutoff = maxValue - LOG_TOLERANCE; + for (int i = 0; i < elements.length; i++) { + double value = get(i); + if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) { + anyAdded = true; + intermediate += Math.exp(value - maxValue); + } + } + if (anyAdded) { + return maxValue + Math.log1p(intermediate); + } else { + return maxValue; + } + } + @Override public String toString() { StringBuilder buffer = new StringBuilder(); From a3d21406b0d47c4056fcc8660b15595de52bb4c8 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 16 Apr 2024 21:11:32 -0400 Subject: [PATCH 05/15] Filling out GaussianMixtureModel class. 
--- .../org/tribuo/clustering/gmm/GMMOptions.java | 5 +- .../org/tribuo/clustering/gmm/GMMTrainer.java | 26 +-- .../clustering/gmm/GaussianMixtureModel.java | 96 +++++++-- .../org/tribuo/clustering/gmm/TrainTest.java | 4 +- .../protos/tribuo-clustering-gmm.proto | 6 + Core/src/main/java/org/tribuo/util/Util.java | 2 +- .../MultivariateNormalDistribution.java | 189 +++++++++++++++--- 7 files changed, 256 insertions(+), 72 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java index 5eb27cd89..bf1f3129f 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMOptions.java @@ -20,7 +20,8 @@ import com.oracle.labs.mlrg.olcut.config.Options; import org.tribuo.Trainer; import org.tribuo.clustering.gmm.GMMTrainer.Initialisation; -import org.tribuo.clustering.gmm.GMMTrainer.CovarianceType; +import org.tribuo.math.distributions.MultivariateNormalDistribution.CovarianceType; +import org.tribuo.math.distributions.MultivariateNormalDistribution; import java.util.logging.Logger; @@ -44,7 +45,7 @@ public class GMMOptions implements Options { * The covariance type of the Gaussians. */ @Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.") - public CovarianceType covarianceType = CovarianceType.DIAGONAL; + public CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL; /** * Initialisation function in GMM. Defaults to RANDOM. 
*/ diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index 8df1bfb92..caf7c4e54 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -20,7 +20,6 @@ import com.oracle.labs.mlrg.olcut.config.PropertyException; import com.oracle.labs.mlrg.olcut.provenance.Provenance; import com.oracle.labs.mlrg.olcut.util.MutableLong; -import com.oracle.labs.mlrg.olcut.util.StreamUtil; import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; @@ -30,6 +29,7 @@ import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ImmutableClusteringInfo; import org.tribuo.math.distance.L2Distance; +import org.tribuo.math.distributions.MultivariateNormalDistribution; import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; @@ -52,7 +52,6 @@ import java.util.function.Consumer; import java.util.logging.Level; import java.util.logging.Logger; -import java.util.stream.IntStream; import java.util.stream.Stream; /** @@ -86,21 +85,6 @@ public class GMMTrainer implements Trainer, WeightedExamples { private static final Logger logger = Logger.getLogger(GMMTrainer.class.getName()); - public enum CovarianceType { - /** - * Full covariance. - */ - FULL, - /** - * Diagonal covariance. - */ - DIAGONAL, - /** - * Spherical covariance. - */ - SPHERICAL - } - /** * Possible initialization functions. 
*/ @@ -126,7 +110,7 @@ public enum Initialisation { private double convergenceTolerance = 1e-3f; @Config(description = "The type of covariance matrix to fit.") - private CovarianceType covarianceType = CovarianceType.DIAGONAL; + private MultivariateNormalDistribution.CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL; @Config(description = "The centroid initialisation method to use.") private Initialisation initialisationType = Initialisation.RANDOM; @@ -157,7 +141,7 @@ private GMMTrainer() { } * @param seed The random seed. */ public GMMTrainer(int centroids, int iterations, int numThreads, long seed) { - this(centroids,iterations,CovarianceType.DIAGONAL,Initialisation.RANDOM,1e-3,numThreads,seed); + this(centroids,iterations, MultivariateNormalDistribution.CovarianceType.DIAGONAL,Initialisation.RANDOM,1e-3,numThreads,seed); } /** @@ -169,7 +153,7 @@ public GMMTrainer(int centroids, int iterations, int numThreads, long seed) { * @param numThreads The number of threads. * @param seed The random seed. 
*/ - public GMMTrainer(int centroids, int iterations, CovarianceType covarianceType, Initialisation initialisationType, double tolerance, int numThreads, long seed) { + public GMMTrainer(int centroids, int iterations, MultivariateNormalDistribution.CovarianceType covarianceType, Initialisation initialisationType, double tolerance, int numThreads, long seed) { this.centroids = centroids; this.iterations = iterations; this.covarianceType = covarianceType; @@ -306,7 +290,7 @@ public GaussianMixtureModel train(Dataset examples, Map { private final DenseVector[] meanVectors; - private final DenseMatrix[] covarianceMatrices; + private final Tensor[] covarianceMatrices; private final DenseVector mixingDistribution; + private final MultivariateNormalDistribution.CovarianceType covarianceType; + + private final MultivariateNormalDistribution[] distributions; + GaussianMixtureModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo outputIDInfo, DenseVector[] meanVectors, - DenseMatrix[] covarianceMatrices, DenseVector mixingDistribution) { - super(name,description,featureIDMap,outputIDInfo,false); + Tensor[] covarianceMatrices, DenseVector mixingDistribution, + MultivariateNormalDistribution.CovarianceType covarianceType) { + super(name,description,featureIDMap,outputIDInfo,true); this.meanVectors = meanVectors; this.covarianceMatrices = covarianceMatrices; this.mixingDistribution = mixingDistribution; + this.covarianceType = covarianceType; + this.distributions = new MultivariateNormalDistribution[meanVectors.length]; + for (int i = 0; i < meanVectors.length; i++) { + // seed is 1 as we call the sample method which uses the supplied RNG, and no eigen decomposition + // because we use the cholesky to fit the GMM. 
+ distributions[i] = new MultivariateNormalDistribution(meanVectors[i], covarianceMatrices[i], + covarianceType, 1L, false); + } } /** @@ -108,6 +124,7 @@ public static GaussianMixtureModel deserializeFromProto(int version, String clas ImmutableFeatureMap featureDomain = carrier.featureDomain(); + MultivariateNormalDistribution.CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.fromValue(proto.getCovarianceTypeValue()); final int means = proto.getMeanVectorsCount(); if (means == 0) { @@ -128,18 +145,31 @@ public static GaussianMixtureModel deserializeFromProto(int version, String clas throw new IllegalStateException("Invalid protobuf, expected centroid to be a dense vector, found " + centroidTensor.getClass()); } } - DenseMatrix[] covariances = new DenseMatrix[means]; + Tensor[] covariances = new Tensor[means]; List covarianceProtos = proto.getCovarianceMatricesList(); for (int i = 0; i < covariances.length; i++) { Tensor covarianceTensor = Tensor.deserialize(covarianceProtos.get(i)); - if (covarianceTensor instanceof DenseMatrix covariance) { + if (covarianceType == MultivariateNormalDistribution.CovarianceType.FULL + && covarianceTensor instanceof DenseMatrix covariance) { if (covariance.getDimension1Size() != featureDomain.size() || covariance.getDimension2Size() != featureDomain.size()) { throw new IllegalStateException("Invalid protobuf, covariance was not square or did not contain all " + - "the features, found " + Arrays.toString(covariance.getShape()) + " expected [" + featureDomain.size() + ", " + featureDomain.size() +"]"); + "the features, found " + Arrays.toString(covariance.getShape()) + " expected [" + featureDomain.size() + ", " + featureDomain.size() + "]"); + } + covariances[i] = covariance; + } else if (covarianceType == MultivariateNormalDistribution.CovarianceType.DIAGONAL + && covarianceTensor instanceof DenseVector covariance) { + if (covariance.size() != featureDomain.size()) { + throw new 
IllegalStateException("Invalid protobuf, covariance was not diagonal, found " + covariance.size() + " elements not " + featureDomain.size() + "."); + } + covariances[i] = covariance; + } else if (covarianceType == MultivariateNormalDistribution.CovarianceType.SPHERICAL + && covarianceTensor instanceof DenseVector covariance) { + if (covariance.size() != 1) { + throw new IllegalStateException("Invalid protobuf, covariance was not spherical, found " + covariance.size() + " elements not 1."); } covariances[i] = covariance; } else { - throw new IllegalStateException("Invalid protobuf, expected covariance to be a dense matrix, found " + covarianceTensor.getClass()); + throw new IllegalStateException("Invalid protobuf, expected covariance to match covarianceType, found " + covarianceTensor.getClass() + " for type " + covarianceType); } } Tensor mixing = Tensor.deserialize(proto.getMixingDistribution()); @@ -154,7 +184,7 @@ public static GaussianMixtureModel deserializeFromProto(int version, String clas throw new IllegalStateException("Invalid protobuf, expected mixing distribution to be a dense vector, found " + mixing.getClass()); } - return new GaussianMixtureModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, covariances, mixingVec); + return new GaussianMixtureModel(carrier.name(), carrier.provenance(), featureDomain, outputDomain, centroids, covariances, mixingVec, covarianceType); } /** @@ -181,13 +211,13 @@ public DenseVector[] getMeanVectors() { /** * Returns a copy of the covariances. *

- * This method provides direct access to the covariance matrices + * This method provides direct access to the covariances * for use in downstream processing, users need to map the indices using * Tribuo's internal ids themselves. * @return The covariances. */ - public DenseMatrix[] getCovariances() { - DenseMatrix[] copies = new DenseMatrix[covarianceMatrices.length]; + public Tensor[] getCovariances() { + Tensor[] copies = new Tensor[covarianceMatrices.length]; for (int i = 0; i < copies.length; i++) { copies[i] = covarianceMatrices[i].copy(); @@ -235,8 +265,32 @@ public Prediction predict(Example example) { } // generate cluster responsibilities and normalize into a distribution + DenseVector responsibilities = new DenseVector(meanVectors[0].size()); + + // compute log probs + for (int i = 0; i < distributions.length; i++) { + responsibilities.set(i, distributions[i].logProbability(vector)); + } + + // add mixing distribution + responsibilities.intersectAndAddInPlace(mixingDistribution, Math::log); + + // convert from log space into probabilities + double sum = responsibilities.logSumExp(); + responsibilities.scalarAddInPlace(-sum); + responsibilities.foreachInPlace(Math::exp); + + // compute output prediction + ClusterID max = null; + Map scores = new HashMap<>(); + for (int i = 0; i < distributions.length; i++) { + ClusterID tmp = new ClusterID(i, responsibilities.get(i)); + if (max == null || tmp.getScore() > max.getScore()) { + max = tmp; + } + } - return new Prediction<>(new ClusterID(id),vector.size(),example); + return new Prediction<>(max,scores,vector.size(),example,generatesProbabilities); } @Override @@ -257,11 +311,18 @@ public Optional> getExcuse(Example example) { */ public List> sample(int numSamples, RandomGenerator rng) { // Convert mixing distribution into CDF + double[] cdf = Util.generateCDF(mixingDistribution.toArray()); - // Sample from mixing distribution + List> output = new ArrayList<>(); + for (int i = 0; i < numSamples; i++) { + // 
Sample from mixing distribution + int dist = Util.sampleFromCDF(cdf, rng); - // Sample from appropriate MultivariateNormalDistribution + // Sample from appropriate MultivariateNormalDistribution + } + + return output; } /** @@ -299,10 +360,11 @@ public ModelProto serialize() { for (DenseVector e : meanVectors) { modelBuilder.addMeanVectors(e.serialize()); } - for (DenseMatrix e : covarianceMatrices) { + for (Tensor e : covarianceMatrices) { modelBuilder.addCovarianceMatrices(e.serialize()); } modelBuilder.setMixingDistribution(mixingDistribution.serialize()); + modelBuilder.setCovarianceTypeValue(covarianceType.value()); ModelProto.Builder builder = ModelProto.newBuilder(); builder.setSerializedData(Any.pack(modelBuilder.build())); @@ -315,13 +377,13 @@ public ModelProto serialize() { @Override protected GaussianMixtureModel copy(String newName, ModelProvenance newProvenance) { DenseVector[] newCentroids = new DenseVector[meanVectors.length]; - DenseMatrix[] newCovariance = new DenseMatrix[meanVectors.length]; + Tensor[] newCovariance = new Tensor[meanVectors.length]; for (int i = 0; i < meanVectors.length; i++) { newCentroids[i] = meanVectors[i].copy(); newCovariance[i] = covarianceMatrices[i].copy(); } return new GaussianMixtureModel(newName,newProvenance,featureIDMap,outputIDInfo, - newCentroids,newCovariance,mixingDistribution.copy()); + newCentroids,newCovariance,mixingDistribution.copy(),covarianceType); } } diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java index 3b4d81fab..7fabd162e 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/TrainTest.java @@ -27,9 +27,9 @@ import org.tribuo.clustering.ClusterID; import org.tribuo.clustering.ClusteringFactory; import org.tribuo.clustering.evaluation.ClusteringEvaluation; -import 
org.tribuo.clustering.gmm.GMMTrainer.CovarianceType; import org.tribuo.clustering.gmm.GMMTrainer.Initialisation; import org.tribuo.data.DataOptions; +import org.tribuo.math.distributions.MultivariateNormalDistribution; import java.io.IOException; import java.util.logging.Logger; @@ -70,7 +70,7 @@ public String getOptionsDescription() { * The covariance type of the gaussians. */ @Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.") - public CovarianceType covarianceType = CovarianceType.DIAGONAL; + public MultivariateNormalDistribution.CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL; /** * Type of initialisation to use for centroids. */ diff --git a/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto b/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto index a41e874a7..342ff16c0 100644 --- a/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto +++ b/Clustering/GMM/src/main/resources/protos/tribuo-clustering-gmm.proto @@ -34,8 +34,14 @@ import "tribuo-math.proto"; GaussianMixtureModel proto */ message GaussianMixtureModelProto { + enum CovarianceTypeProto { + FULL = 0; + DIAGONAL = 1; + SPHERICAL = 2; + } tribuo.core.ModelDataProto metadata = 1; tribuo.math.TensorProto mixing_distribution = 2; repeated tribuo.math.TensorProto mean_vectors = 3; repeated tribuo.math.TensorProto covariance_matrices = 4; + CovarianceTypeProto covariance_type = 5; } diff --git a/Core/src/main/java/org/tribuo/util/Util.java b/Core/src/main/java/org/tribuo/util/Util.java index 1a3a3ed7a..2db850193 100644 --- a/Core/src/main/java/org/tribuo/util/Util.java +++ b/Core/src/main/java/org/tribuo/util/Util.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index 91fb0fe97..0e56ceb8c 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,10 +19,13 @@ import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseSparseMatrix; import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.SGDVector; +import org.tribuo.math.la.Tensor; import java.util.Arrays; import java.util.Optional; import java.util.Random; +import java.util.random.RandomGenerator; /** * A class for sampling from multivariate normal distributions. @@ -32,16 +35,19 @@ public final class MultivariateNormalDistribution { private final long seed; private final Random rng; private final DenseVector means; - private final DenseMatrix covariance; + private final double variance; + private final DenseVector covarianceVector; + private final DenseMatrix covarianceMatrix; private final DenseMatrix samplingCovariance; private final boolean eigenDecomposition; + private final CovarianceType type; /** * Constructs a multivariate normal distribution that can be sampled from. *

* Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. *

- * Uses a {@link org.tribuo.math.la.DenseMatrix.CholeskyFactorization} to compute the sampling + * Uses a {@link DenseMatrix.CholeskyFactorization} to compute the sampling * covariance matrix. * @param means The mean vector. * @param covariance The covariance matrix. @@ -70,7 +76,7 @@ public MultivariateNormalDistribution(double[] means, double[][] covariance, lon *

* Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. *

- * Uses a {@link org.tribuo.math.la.DenseMatrix.CholeskyFactorization} to compute the sampling + * Uses a {@link DenseMatrix.CholeskyFactorization} to compute the sampling * covariance matrix. * @param means The mean vector. * @param covariance The covariance matrix. @@ -91,34 +97,92 @@ public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, * rather than a cholesky factorization. */ public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, long seed, boolean eigenDecomposition) { + this(means, covariance, CovarianceType.FULL, seed, eigenDecomposition); + } + + /** + * Constructs a multivariate normal distribution that can be sampled from. + *

+ * Throws {@link IllegalArgumentException} if the covariance matrix is not positive definite. + * @param means The mean vector. + * @param covariance The covariance matrix. If type is {@link CovarianceType#FULL} must be a {@link DenseMatrix}, + * if {@link CovarianceType#DIAGONAL} or {@link CovarianceType#SPHERICAL} must be a + * {@link DenseVector}. Spherical covariances should have a single element dense vector. + * @param type The covariance type. + * @param seed The RNG seed. + * @param eigenDecomposition If true use an eigen decomposition to compute the sampling covariance matrix + * rather than a cholesky factorization, if it's a full covariance. + */ + public MultivariateNormalDistribution(DenseVector means, Tensor covariance, CovarianceType type, long seed, boolean eigenDecomposition) { this.seed = seed; this.rng = new Random(seed); this.means = means.copy(); - this.covariance = covariance.copy(); - if (this.covariance.getDimension1Size() != this.means.size() || this.covariance.getDimension2Size() != this.means.size()) { - throw new IllegalArgumentException("Covariance matrix must be square and the same dimension as the mean vector. 
Mean vector size = " + means.size() + ", covariance size = " + Arrays.toString(this.covariance.getShape())); - } this.eigenDecomposition = eigenDecomposition; - if (eigenDecomposition) { - Optional factorization = this.covariance.eigenDecomposition(); - if (factorization.isPresent() && factorization.get().positiveEigenvalues()) { - DenseVector eigenvalues = factorization.get().eigenvalues(); - // rows are eigenvectors - DenseMatrix eigenvectors = new DenseMatrix(factorization.get().eigenvectors()); - // scale eigenvectors by sqrt of eigenvalues - eigenvalues.foreachInPlace(Math::sqrt); - DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(eigenvalues);; - this.samplingCovariance = eigenvectors.matrixMultiply(diagonal).matrixMultiply(eigenvectors,false,true); - } else { - throw new IllegalArgumentException("Covariance matrix is not positive definite."); + this.type = type; + switch (type) { + case FULL -> { + if (!(covariance instanceof DenseMatrix)) { + throw new IllegalArgumentException("Covariance matrix must be a square matrix for full covariance, found " + covariance.getClass()); + } + this.covarianceMatrix = (DenseMatrix) covariance.copy(); + if (this.covarianceMatrix.getDimension1Size() != this.means.size() || this.covarianceMatrix.getDimension2Size() != this.means.size()) { + throw new IllegalArgumentException("Covariance matrix must be square and the same dimension as the mean vector. 
Mean vector size = " + means.size() + ", covariance size = " + Arrays.toString(this.covarianceMatrix.getShape())); + } + if (eigenDecomposition) { + Optional factorization = this.covarianceMatrix.eigenDecomposition(); + if (factorization.isPresent() && factorization.get().positiveEigenvalues()) { + DenseVector eigenvalues = factorization.get().eigenvalues(); + // rows are eigenvectors + DenseMatrix eigenvectors = new DenseMatrix(factorization.get().eigenvectors()); + // scale eigenvectors by sqrt of eigenvalues + eigenvalues.foreachInPlace(Math::sqrt); + DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(eigenvalues); + this.samplingCovariance = eigenvectors.matrixMultiply(diagonal).matrixMultiply(eigenvectors,false,true); + } else { + throw new IllegalArgumentException("Covariance matrix is not positive definite."); + } + } else { + Optional factorization = this.covarianceMatrix.choleskyFactorization(); + if (factorization.isPresent()) { + this.samplingCovariance = factorization.get().lMatrix(); + } else { + throw new IllegalArgumentException("Covariance matrix is not positive definite."); + } + } + // set unused variables. + this.covarianceVector = null; + this.variance = Double.NaN; } - } else { - Optional factorization = this.covariance.choleskyFactorization(); - if (factorization.isPresent()) { - this.samplingCovariance = factorization.get().lMatrix(); - } else { - throw new IllegalArgumentException("Covariance matrix is not positive definite."); + case DIAGONAL -> { + if (!(covariance instanceof DenseVector)) { + throw new IllegalArgumentException("Covariance must be a vector for diagonal covariance, found " + covariance.getClass()); + } + this.covarianceVector = (DenseVector) covariance.copy(); + if (this.covarianceVector.size() != this.means.size()) { + throw new IllegalArgumentException("Covariance must be a vector and the same dimension as the mean vector. 
Mean vector size = " + means.size() + ", covariance size = " + this.covarianceVector.size()); + } + + // set unused variables. + this.covarianceMatrix = null; + this.samplingCovariance = null; + this.variance = Double.NaN; } + case SPHERICAL -> { + if (covariance instanceof DenseVector vec) { + if (vec.size() != 1) { + throw new IllegalArgumentException("Covariance must be a single element vector for spherical covariance. Found " + vec.size()); + } + } else { + throw new IllegalArgumentException("Covariance must be a single element vector for spherical covariance, found " + covariance.getClass()); + } + this.variance = Double.NaN; + + // set unused variables. + this.covarianceVector = null; + this.covarianceMatrix = null; + this.samplingCovariance = null; + } + default -> throw new IllegalArgumentException("Unknown covariance type " + type); } } @@ -127,12 +191,24 @@ public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, * @return A sample from this distribution. */ public DenseVector sampleVector() { + return sampleVector(rng); + } + + /** + * Sample a vector from this multivariate normal distribution. + * @return A sample from this distribution. + */ + public DenseVector sampleVector(RandomGenerator otherRNG) { DenseVector sampled = new DenseVector(means.size()); for (int i = 0; i < means.size(); i++) { - sampled.set(i,rng.nextGaussian()); + sampled.set(i, otherRNG.nextGaussian()); } - sampled = samplingCovariance.leftMultiply(sampled); + switch (type) { + case FULL -> sampled = samplingCovariance.leftMultiply(sampled); + case DIAGONAL -> sampled.hadamardProductInPlace(covarianceVector); + case SPHERICAL -> sampled.scaleInPlace(variance); + } return means.add(sampled); } @@ -145,8 +221,63 @@ public double[] sampleArray() { return sampleVector().toArray(); } + /** + * Compute the log probability of the input under this multivariate normal distribution. + * @param input The input to compute. + * @return The log probability. 
+ */ + public double logProbability(SGDVector input) { + + } + @Override public String toString() { - return "MultivariateNormal(mean="+means+",covariance="+covariance+",seed="+seed+",useEigenDecomposition="+eigenDecomposition+")"; + return switch (type) { + case FULL -> "MultivariateNormal(mean="+means+",covariance="+covarianceMatrix+",seed="+seed+",useEigenDecomposition="+eigenDecomposition+",type="+type+")"; + case DIAGONAL -> "MultivariateNormal(mean="+means+",covariance="+covarianceVector+",seed="+seed+",useEigenDecomposition="+eigenDecomposition+",type="+type+")"; + case SPHERICAL -> "MultivariateNormal(mean="+means+",covariance="+variance+",seed="+seed+",useEigenDecomposition="+eigenDecomposition+",type="+type+")"; + }; + } + + /** + * Type of the covariance in a multivariate normal distribution. + */ + public enum CovarianceType { + /** + * Full covariance. + */ + FULL(0), + /** + * Diagonal covariance. + */ + DIAGONAL(1), + /** + * Spherical covariance. + */ + SPHERICAL(2); + + private final int value; + private CovarianceType(int value) { + this.value = value; + } + + /** + * The enum value used for serialization. + * @return The enum value. + */ + public int value() { + return value; + } + + public static CovarianceType fromValue(int value) { + CovarianceType[] values = CovarianceType.values(); + for (CovarianceType t : values) { + if (t.value == value) { + return t; + } + } + // Failed to find the enum. + throw new IllegalStateException("Invalid CovarianceType enum value, found " + value); + } } } From 3f0ec5453f76de8592176bd13052f8a90cffa7df Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 19 Apr 2024 22:18:17 -0400 Subject: [PATCH 06/15] Filling out GMMTrainer.train. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 191 ++++++++------- .../org/tribuo/clustering/gmm/TestGMM.java | 219 ++++++++++++++++++ .../clustering/kmeans/KMeansTrainer.java | 2 +- .../MultivariateNormalDistribution.java | 25 ++ .../java/org/tribuo/math/la/DenseVector.java | 19 ++ 5 files changed, 372 insertions(+), 84 deletions(-) create mode 100644 Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index caf7c4e54..e6e778599 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -20,6 +20,7 @@ import com.oracle.labs.mlrg.olcut.config.PropertyException; import com.oracle.labs.mlrg.olcut.provenance.Provenance; import com.oracle.labs.mlrg.olcut.util.MutableLong; +import com.oracle.labs.mlrg.olcut.util.StreamUtil; import org.tribuo.Dataset; import org.tribuo.Example; import org.tribuo.ImmutableFeatureMap; @@ -34,6 +35,8 @@ import org.tribuo.math.la.DenseVector; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.SparseVector; +import org.tribuo.math.la.Tensor; +import org.tribuo.math.la.VectorTuple; import org.tribuo.provenance.ModelProvenance; import org.tribuo.provenance.TrainerProvenance; import org.tribuo.provenance.impl.TrainerProvenanceImpl; @@ -43,15 +46,14 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; -import java.util.function.Consumer; +import java.util.function.ToDoubleFunction; import java.util.logging.Level; import java.util.logging.Logger; +import java.util.stream.IntStream; import java.util.stream.Stream; /** @@ -62,9 +64,8 
@@ * It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments * can only be retrieved from the model after training, and require re-evaluating each example. *

- * The Trainer has a parameterised distance function, and a selectable number - * of threads used in the training step. The thread pool is local to an invocation of train, - * so there can be multiple concurrent trainings. + * The Trainer has a selectable number of threads used in the training step. + * The thread pool is local to an invocation of train, so there can be multiple concurrent trainings. *

* The train method will instantiate dense examples as dense vectors, speeding up the computation. *

@@ -100,8 +101,8 @@ public enum Initialisation { PLUSPLUS } - @Config(mandatory = true, description = "Number of centroids.") - private int centroids; + @Config(mandatory = true, description = "Number of Gaussians to fit.") + private int numGaussians; @Config(mandatory = true, description = "The number of iterations to run.") private int iterations; @@ -112,7 +113,7 @@ public enum Initialisation { @Config(description = "The type of covariance matrix to fit.") private MultivariateNormalDistribution.CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL; - @Config(description = "The centroid initialisation method to use.") + @Config(description = "The cluster initialisation method to use.") private Initialisation initialisationType = Initialisation.RANDOM; @Config(description = "The number of threads to use for training.") @@ -133,28 +134,28 @@ public enum Initialisation { private GMMTrainer() { } /** - * Constructs a K-Means trainer using the supplied parameters and the default random initialisation. + * Constructs a Gaussian Mixture Model trainer using the supplied parameters and the default random initialisation. * - * @param centroids The number of centroids to use. + * @param numGaussians The number of centroids to use. * @param iterations The maximum number of iterations. * @param numThreads The number of threads. * @param seed The random seed. */ - public GMMTrainer(int centroids, int iterations, int numThreads, long seed) { - this(centroids,iterations, MultivariateNormalDistribution.CovarianceType.DIAGONAL,Initialisation.RANDOM,1e-3,numThreads,seed); + public GMMTrainer(int numGaussians, int iterations, int numThreads, long seed) { + this(numGaussians,iterations, MultivariateNormalDistribution.CovarianceType.DIAGONAL,Initialisation.RANDOM,1e-3,numThreads,seed); } /** - * Constructs a K-Means trainer using the supplied parameters. + * Constructs a Gaussian Mixture Model trainer using the supplied parameters. 
* - * @param centroids The number of centroids to use. + * @param numGaussians The number of centroids to use. * @param iterations The maximum number of iterations. * @param initialisationType The centroid initialization method. * @param numThreads The number of threads. * @param seed The random seed. */ - public GMMTrainer(int centroids, int iterations, MultivariateNormalDistribution.CovarianceType covarianceType, Initialisation initialisationType, double tolerance, int numThreads, long seed) { - this.centroids = centroids; + public GMMTrainer(int numGaussians, int iterations, MultivariateNormalDistribution.CovarianceType covarianceType, Initialisation initialisationType, double tolerance, int numThreads, long seed) { + this.numGaussians = numGaussians; this.iterations = iterations; this.covarianceType = covarianceType; this.initialisationType = initialisationType; @@ -171,8 +172,8 @@ public GMMTrainer(int centroids, int iterations, MultivariateNormalDistribution. public synchronized void postConfig() { this.rng = new SplittableRandom(seed); - if (centroids < 1) { - throw new PropertyException("centroids", "Centroids must be positive, found " + centroids); + if (numGaussians < 1) { + throw new PropertyException("centroids", "Centroids must be positive, found " + numGaussians); } } @@ -196,6 +197,7 @@ public GaussianMixtureModel train(Dataset examples, Map examples, Map initialiseRandomCentroids(centroids, featureMap, localRNG); - case PLUSPLUS -> initialisePlusPlusCentroids(centroids, data, localRNG); + final DenseVector[] meanVectors = switch (initialisationType) { + case RANDOM -> initialiseRandomCentroids(numGaussians, featureMap, localRNG); + case PLUSPLUS -> initialisePlusPlusCentroids(numGaussians, data, localRNG); }; - DenseMatrix[] covarianceMatrices = new DenseMatrix[centroids]; - DenseMatrix.CholeskyFactorization[] precisionFactorizations = new DenseMatrix.CholeskyFactorization[centroids]; - DenseVector mixingDistribution = new DenseVector(centroids); + 
final Tensor[] covarianceMatrices = new Tensor[numGaussians]; + DenseMatrix.CholeskyFactorization[] precisionFactorizations = new DenseMatrix.CholeskyFactorization[numGaussians]; + final DenseVector mixingDistribution = new DenseVector(numGaussians); boolean parallel = numThreads > 1; - Consumer eStepFunc = (SGDVector e) -> { - double minDist = Double.POSITIVE_INFINITY; - for (int j = 0; j < centroids; j++) { - DenseVector cluster = meanVectors[j]; - double distance = dist.computeDistance(cluster, e); - if (distance < minDist) { - minDist = distance; - } + ToDoubleFunction eStepFunc = (IntAndVector e) -> { + DenseVector curResponsibilities = responsibilities[e.idx]; + // compute log probs + for (int i = 0; i < meanVectors.length; i++) { + curResponsibilities.set(i, MultivariateNormalDistribution.logProbability(e.vector, meanVectors[i], covarianceMatrices[i], precisionFactorizations[i], covarianceType)); } + + // add mixing distribution + curResponsibilities.intersectAndAddInPlace(mixingDistribution, Math::log); + + // normalize log probabilities + double sum = curResponsibilities.logSumExp(); + curResponsibilities.scalarAddInPlace(-sum); + + // exponentiate them + curResponsibilities.foreachInPlace(Math::exp); + + return sum; }; double oldLowerBound = Double.NEGATIVE_INFINITY; @@ -243,28 +254,75 @@ public GaussianMixtureModel train(Dataset examples, Map vecStream = Arrays.stream(data); + Stream intStream = IntStream.range(0, data.length).boxed(); + Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); if (parallel) { try { - fjp.submit(() -> vecStream.parallel().forEach(eStepFunc)).get(); + normSum = fjp.submit(() -> zipStream.parallel().mapToDouble(eStepFunc).sum()).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } } else { - vecStream.forEach(eStepFunc); + normSum = zipStream.mapToDouble(eStepFunc).sum(); } logger.log(Level.FINE, i + "th e step completed."); + // 
compute lower bound + newLowerBound = normSum / examples.size(); + // M step - mStep(fjp, responsibilities, meanVectors, covarianceMatrices, mixingDistribution, precisionFactorizations, data, weights); - logger.log(Level.FINE, i + "th m step completed."); + // compute new mixing distribution + DenseVector zeroVector = new DenseVector(numGaussians); + Stream resStream = Arrays.stream(responsibilities); + DenseVector newMixingDistribution; + if (parallel) { + try { + newMixingDistribution = fjp.submit(() -> resStream.parallel().reduce(zeroVector, DenseVector::add)).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Parallel execution failed", e); + } + } else { + newMixingDistribution = resStream.parallel().reduce(zeroVector, DenseVector::add); + } + // add minimum value to ensure all values are positive + newMixingDistribution.scalarAddInPlace(2e-15); - // Compute log likelihood bound - newLowerBound = computeLowerBound(); + // compute new means based on mixing distribution & positions + for (int j = 0; j < numGaussians; j++) { + meanVectors[j].set(0); + } + // Manual matrix multiply here as things are stored as arrays of vectors + // responsibilities[examples, gaussians], data[examples, features], means[gaussians, features] + for (int j = 0; j < examples.size(); j++) { + DenseVector curResp = responsibilities[j]; + SGDVector curExample = data[j]; + for (VectorTuple v : curExample) { + for (int k = 0; k < numGaussians; k++) { + DenseVector curMean = meanVectors[k]; + curMean.set(v.index, curMean.get(v.index) + v.value * curResp.get(k)); + } + } + } + for (int j = 0; j < numGaussians; j++) { + meanVectors[j].scaleInPlace(newMixingDistribution.get(j)); + } + + // compute new covariances + + // renormalize mixing distribution + double mixingSum = newMixingDistribution.sum(); + newMixingDistribution.scaleInPlace(1/mixingSum); + mixingDistribution.setElements(newMixingDistribution); + + // factorize covariances + + 
logger.log(Level.FINE, i + "th m step completed."); logger.log(Level.INFO, "Iteration " + i + " completed."); - if (newLowerBound - oldLowerBound < convergenceTolerance) { + if ((newLowerBound - oldLowerBound) < convergenceTolerance) { converged = true; logger.log(Level.INFO, "GMM converged at iteration " + i); } @@ -279,7 +337,7 @@ public GaussianMixtureModel train(Dataset examples, Map counts = new HashMap<>(); for (int i = 0; i < examples.size(); i++) { - int idx = responsibilities.getRow(i).argmax(); + int idx = responsibilities[i].argmax(); var count = counts.computeIfAbsent(idx, k -> new MutableLong()); count.increment(); } @@ -330,13 +388,13 @@ private static DenseVector[] initialiseRandomCentroids(int centroids, ImmutableF DenseVector[] centroidVectors = new DenseVector[centroids]; int numFeatures = featureMap.size(); for (int i = 0; i < centroids; i++) { - double[] newCentroid = new double[numFeatures]; + DenseVector newCentroid = new DenseVector(numFeatures); for (int j = 0; j < numFeatures; j++) { - newCentroid[j] = featureMap.get(j).uniformSample(rng); + newCentroid.set(j, featureMap.get(j).uniformSample(rng)); } - centroidVectors[i] = DenseVector.createDenseVector(newCentroid); + centroidVectors[i] = newCentroid; } return centroidVectors; } @@ -407,46 +465,9 @@ private static DenseVector getRandomCentroidFromData(SGDVector[] data, Splittabl return DenseVector.createDenseVector(data[randIdx].toArray()); } - /** - * Runs the mStep, writing to the {@code centroidVectors} array. - * @param fjp The ForkJoinPool to run the computation in if it should be executed in parallel. - * If the fjp is null then the computation is executed sequentially. - * @param centroidVectors The centroid vectors to write out. - * @param data The data points. - * @param weights The example weights. 
- */ - protected void mStep(ForkJoinPool fjp, DenseVector[] centroidVectors, DenseMatrix[] covarianceMatrices, - DenseMatrix.CholeskyFactorization[] precisionFactorizations, DenseVector mixingDistribution, SGDVector[] data, double[] weights) { - // M step - Consumer>> mStepFunc = (e) -> { - DenseVector newCentroid = centroidVectors[e.getKey()]; - newCentroid.fill(0.0); - - double weightSum = 0.0; - for (Integer idx : e.getValue()) { - newCentroid.intersectAndAddInPlace(data[idx], (double f) -> f * weights[idx]); - weightSum += weights[idx]; - } - if (weightSum != 0.0) { - newCentroid.scaleInPlace(1.0 / weightSum); - } - }; - - Stream>> mStream = clusterAssignments.entrySet().stream(); - if (fjp != null) { - try { - fjp.submit(() -> mStream.parallel().forEach(mStepFunc)).get(); - } catch (InterruptedException | ExecutionException e) { - throw new RuntimeException("Parallel execution failed", e); - } - } else { - mStream.forEach(mStepFunc); - } - } - @Override public String toString() { - return "GMMTrainer(centroids=" + centroids + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")"; + return "GMMTrainer(numGaussians=" + numGaussians + ",seed=" + seed + ",numThreads=" + numThreads + ", initialisationType=" + initialisationType + ")"; } @Override @@ -454,4 +475,8 @@ public TrainerProvenance getProvenance() { return new TrainerProvenanceImpl(this); } + /** + * Tuple of index and position. + */ + record IntAndVector(int idx, SGDVector vector) { } } \ No newline at end of file diff --git a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java new file mode 100644 index 000000000..5527525b4 --- /dev/null +++ b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java @@ -0,0 +1,219 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.clustering.gmm; + +import com.oracle.labs.mlrg.olcut.util.Pair; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.tribuo.Dataset; +import org.tribuo.Model; +import org.tribuo.MutableDataset; +import org.tribuo.clustering.ClusterID; +import org.tribuo.clustering.evaluation.ClusteringEvaluation; +import org.tribuo.clustering.evaluation.ClusteringEvaluator; +import org.tribuo.clustering.example.ClusteringDataGenerator; +import org.tribuo.clustering.example.GaussianClusterDataSource; +import org.tribuo.math.distributions.MultivariateNormalDistribution; +import org.tribuo.math.la.DenseVector; +import org.tribuo.test.Helpers; + +import java.util.logging.Level; +import java.util.logging.Logger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Smoke tests for the Gaussian Mixture Model trainer. 
+ */ +public class TestGMM { + + private static final GMMTrainer t = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + + private static final GMMTrainer plusPlus = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.FULL, + GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); + + @BeforeAll + public static void setup() { + Logger logger = Logger.getLogger(GMMTrainer.class.getName()); + logger.setLevel(Level.WARNING); + logger = Logger.getLogger(org.tribuo.util.infotheory.InformationTheory.class.getName()); + logger.setLevel(Level.WARNING); + } + + @Test + public void testEvaluation() { + runEvaluation(t); + } + + @Test + public void testPlusPlusEvaluation() { + runEvaluation(plusPlus); + } + + public static void runEvaluation(GMMTrainer trainer) { + Dataset data = new MutableDataset<>(new GaussianClusterDataSource(500, 1L)); + Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); + ClusteringEvaluator eval = new ClusteringEvaluator(); + + GaussianMixtureModel model = trainer.train(data); + + Helpers.testModelSerialization(model, ClusterID.class); + Helpers.testModelProtoSerialization(model, ClusterID.class, test); + + ClusteringEvaluation trainEvaluation = eval.evaluate(model,data); + assertFalse(Double.isNaN(trainEvaluation.adjustedMI())); + assertFalse(Double.isNaN(trainEvaluation.normalizedMI())); + + ClusteringEvaluation testEvaluation = eval.evaluate(model,test); + assertFalse(Double.isNaN(testEvaluation.adjustedMI())); + assertFalse(Double.isNaN(testEvaluation.normalizedMI())); + } + + public static Model testTrainer(Pair, Dataset> p, GMMTrainer trainer) { + Model m = trainer.train(p.getA()); + ClusteringEvaluator e = new ClusteringEvaluator(); + e.evaluate(m,p.getB()); + return m; + } + + public static Model runDenseData(GMMTrainer trainer) { + Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); + return testTrainer(p, trainer); + } + + @Test + 
public void testDenseData() { + Model model = runDenseData(t); + Helpers.testModelSerialization(model,ClusterID.class); + } + + @Test + public void testPlusPlusDenseData() { + runDenseData(plusPlus); + } + + public void runSparseData(GMMTrainer trainer) { + Pair,Dataset> p = ClusteringDataGenerator.sparseTrainTest(); + testTrainer(p, trainer); + } + + @Test + public void testSparseData() { + runSparseData(t); + } + + @Test + public void testPlusPlusSparseData() { + runSparseData(plusPlus); + } + + public void runInvalidExample(GMMTrainer trainer) { + assertThrows(IllegalArgumentException.class, () -> { + Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); + Model m = trainer.train(p.getA()); + m.predict(ClusteringDataGenerator.invalidSparseExample()); + }); + } + + @Test + public void testInvalidExample() { + runInvalidExample(t); + } + + @Test + public void testPlusPlusInvalidExample() { + runInvalidExample(plusPlus); + } + + + public void runEmptyExample(GMMTrainer trainer) { + assertThrows(IllegalArgumentException.class, () -> { + Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); + Model m = trainer.train(p.getA()); + m.predict(ClusteringDataGenerator.emptyExample()); + }); + } + + @Test + public void testEmptyExample() { + runEmptyExample(t); + } + + @Test + public void testPlusPlusEmptyExample() { + runEmptyExample(plusPlus); + } + + @Test + public void testPlusPlusTooManyCentroids() { + assertThrows(IllegalArgumentException.class, () -> { + Dataset data = ClusteringDataGenerator.gaussianClusters(3, 1L); + plusPlus.train(data); + }); + } + + @Test + public void testSetInvocationCount() { + // Create new trainer and dataset so as not to mess with the other tests + GMMTrainer originalTrainer = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + Pair,Dataset> p = ClusteringDataGenerator.denseTrainTest(); + + // The number of times to call train before final training. 
+ // Original trainer will be trained numOfInvocations + 1 times + // New trainer will have its invocation count set to numOfInvocations then trained once + int numOfInvocations = 2; + + // Create the first model and train it numOfInvocations + 1 times + GaussianMixtureModel originalModel = null; + for(int invocationCounter = 0; invocationCounter < numOfInvocations + 1; invocationCounter++){ + originalModel = originalTrainer.train(p.getA()); + } + + // Create a new model with same configuration, but set the invocation count to numOfInvocations + // Assert that this succeeded, this means RNG will be at state where originalTrainer was before + // it performed its last train. + GMMTrainer newTrainer = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + newTrainer.setInvocationCount(numOfInvocations); + assertEquals(numOfInvocations,newTrainer.getInvocationCount()); + + // Training newTrainer should now have the same result as if it + // had trained numOfInvocations times previously even though it hasn't + GaussianMixtureModel newModel = newTrainer.train(p.getA()); + assertEquals(originalTrainer.getInvocationCount(),newTrainer.getInvocationCount()); + + DenseVector[] newWeights = newModel.getMeanVectors(); + DenseVector[] oldWeights = originalModel.getMeanVectors(); + + for (int centroidIndex = 0; centroidIndex < newWeights.length; centroidIndex++){ + assertEquals(oldWeights[centroidIndex],newWeights[centroidIndex]); + } + } + + @Test + public void testNegativeInvocationCount(){ + assertThrows(IllegalArgumentException.class, () -> { + GMMTrainer t = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + t.setInvocationCount(-1); + }); + } +} diff --git a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java index 
6051121bb..82e24615f 100644 --- a/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java +++ b/Clustering/KMeans/src/main/java/org/tribuo/clustering/kmeans/KMeansTrainer.java @@ -564,7 +564,7 @@ public TrainerProvenance getProvenance() { } /** - * Tuple of index and position. One day it'll be a record, but not today. + * Tuple of index and position. */ record IntAndVector(int idx, SGDVector vector) { } diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index 0e56ceb8c..652c32de8 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -19,6 +19,7 @@ import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseSparseMatrix; import org.tribuo.math.la.DenseVector; +import org.tribuo.math.la.Matrix; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.Tensor; @@ -100,6 +101,26 @@ public MultivariateNormalDistribution(DenseVector means, DenseMatrix covariance, this(means, covariance, CovarianceType.FULL, seed, eigenDecomposition); } + /** + * Constructs a multivariate normal distribution that can be sampled from using a spherical covariance. + * @param means The mean vector. + * @param sphericalCovariance The spherical covariance matrix, stored as a single double. + * @param seed The RNG seed. + */ + public MultivariateNormalDistribution(DenseVector means, double sphericalCovariance, long seed) { + this(means,new DenseVector(1,sphericalCovariance),CovarianceType.SPHERICAL,seed,false); + } + + /** + * Constructs a multivariate normal distribution that can be sampled from using a diagonal covariance matrix. + * @param means The mean vector. + * @param diagonalCovariance The diagonal covariance matrix, stored as a vector. + * @param seed The RNG seed. 
+ */ + public MultivariateNormalDistribution(DenseVector means, DenseVector diagonalCovariance, long seed) { + this(means,diagonalCovariance,CovarianceType.DIAGONAL,seed,false); + } + /** * Constructs a multivariate normal distribution that can be sampled from. *

@@ -227,6 +248,10 @@ public double[] sampleArray() { * @return The log probability. */ public double logProbability(SGDVector input) { + return logProbability(input, means, covariance(), samplingCovariance, type); + } + + public static double logProbability(SGDVector input, DenseVector mean, Tensor covariance, Matrix.Factorization factorization, CovarianceType type) { } diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index 992bcfae0..6902dfcde 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -432,6 +432,17 @@ public void hadamardProductInPlace(Tensor other, DoubleUnaryOperator f) { } } + /** + * Applies the function {@code f} to each element of this vector returning a new vector. + * @param f The function to apply. + * @return A copy of this vector with {@code f} applied to each element. + */ + public DenseVector foreach(DoubleUnaryOperator f) { + DenseVector output = new DenseVector(this); + output.foreachInPlace(f); + return output; + } + @Override public void foreachInPlace(DoubleUnaryOperator f) { for (int i = 0; i < elements.length; i++) { @@ -552,6 +563,14 @@ public void set(int index, double value) { elements[index] = value; } + /** + * Sets all elements of this vector to {@code value}. + * @param value The value to set things to. + */ + public void set(double value) { + Arrays.fill(elements, 0); + } + /** * Sets all the elements of this vector to be the same as {@code other}. * @param other The {@link DenseVector} to copy. From 3f8fc38782fc23397e9d9ea1191ede7c82140444 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 21 Apr 2024 21:45:32 -0400 Subject: [PATCH 07/15] Implementing MultivariateNormalDistribution.logProbability. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 36 ++++++++- .../clustering/gmm/GaussianMixtureModel.java | 3 +- .../MultivariateNormalDistribution.java | 81 +++++++++++++++++-- 3 files changed, 109 insertions(+), 11 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index e6e778599..5dbf3107f 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -47,6 +47,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; @@ -217,7 +218,8 @@ public GaussianMixtureModel train(Dataset examples, Map initialisePlusPlusCentroids(numGaussians, data, localRNG); }; final Tensor[] covarianceMatrices = new Tensor[numGaussians]; - DenseMatrix.CholeskyFactorization[] precisionFactorizations = new DenseMatrix.CholeskyFactorization[numGaussians]; + final Tensor[] precision = new Tensor[numGaussians]; + final double[] determinant = new double[numGaussians]; final DenseVector mixingDistribution = new DenseVector(numGaussians); boolean parallel = numThreads > 1; @@ -226,7 +228,7 @@ public GaussianMixtureModel train(Dataset examples, Map examples, Map { + for (int j = 0; j < covarianceMatrices.length; j++) { + DenseMatrix covMax = (DenseMatrix) covarianceMatrices[j]; + Optional optFact = covMax.choleskyFactorization(); + if (optFact.isPresent()) { + DenseMatrix.CholeskyFactorization fact = optFact.get(); + precision[j] = fact.inverse(); + determinant[j] = fact.determinant(); + } else { + throw new IllegalStateException("Failed to invert covariance matrix, cholesky didn't complete."); + } + } + } + case DIAGONAL, SPHERICAL -> { + for (int j = 0; j < covarianceMatrices.length; j++) { + 
DenseVector covVec = (DenseVector) covarianceMatrices[j]; + DenseVector preVec = (DenseVector) precision[j]; + double tmp = 1; + for (int k = 0; k < preVec.size(); k++) { + double curVal = 1/Math.sqrt(covVec.get(k)); + preVec.set(k, curVal); + tmp *= curVal; + } + determinant[j] = tmp; + } + } + } logger.log(Level.FINE, i + "th m step completed."); diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java index 7e53e6c44..52936cabd 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java @@ -319,7 +319,8 @@ public List> sample(int numSamples, RandomGenerator r int dist = Util.sampleFromCDF(cdf, rng); // Sample from appropriate MultivariateNormalDistribution - + DenseVector sample = distributions[dist].sampleVector(rng); + output.add(new Pair<>(dist, sample)); } return output; diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index 652c32de8..c36c9e16d 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -19,7 +19,6 @@ import org.tribuo.math.la.DenseMatrix; import org.tribuo.math.la.DenseSparseMatrix; import org.tribuo.math.la.DenseVector; -import org.tribuo.math.la.Matrix; import org.tribuo.math.la.SGDVector; import org.tribuo.math.la.Tensor; @@ -40,6 +39,8 @@ public final class MultivariateNormalDistribution { private final DenseVector covarianceVector; private final DenseMatrix covarianceMatrix; private final DenseMatrix samplingCovariance; + private final Tensor precision; + private final double determinant; private final boolean eigenDecomposition; 
private final CovarianceType type; @@ -159,6 +160,8 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova eigenvalues.foreachInPlace(Math::sqrt); DenseSparseMatrix diagonal = DenseSparseMatrix.createDiagonal(eigenvalues); this.samplingCovariance = eigenvectors.matrixMultiply(diagonal).matrixMultiply(eigenvectors,false,true); + this.determinant = factorization.get().determinant(); + this.precision = factorization.get().inverse(); } else { throw new IllegalArgumentException("Covariance matrix is not positive definite."); } @@ -166,6 +169,8 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova Optional factorization = this.covarianceMatrix.choleskyFactorization(); if (factorization.isPresent()) { this.samplingCovariance = factorization.get().lMatrix(); + this.determinant = factorization.get().determinant(); + this.precision = factorization.get().inverse(); } else { throw new IllegalArgumentException("Covariance matrix is not positive definite."); } @@ -183,6 +188,14 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova throw new IllegalArgumentException("Covariance must be a vector and the same dimension as the mean vector. Mean vector size = " + means.size() + ", covariance size = " + this.covarianceVector.size()); } + double tmp = 1; + for (int i = 0; i < this.covarianceVector.size(); i++) { + tmp *= this.covarianceVector.get(i); + } + this.determinant = tmp; + this.precision = this.covarianceVector.copy(); + this.precision.foreachInPlace(a -> 1.0/a); + // set unused variables. 
this.covarianceMatrix = null; this.samplingCovariance = null; @@ -190,13 +203,20 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova } case SPHERICAL -> { if (covariance instanceof DenseVector vec) { - if (vec.size() != 1) { + if ((vec.size() != 1) && (vec.size() != means.size())) { throw new IllegalArgumentException("Covariance must be a single element vector for spherical covariance. Found " + vec.size()); } } else { throw new IllegalArgumentException("Covariance must be a single element vector for spherical covariance, found " + covariance.getClass()); } - this.variance = Double.NaN; + this.variance = vec.get(0); + + double tmp = 1; + for (int i = 0; i < means.size(); i++) { + tmp *= this.variance; + } + this.determinant = tmp; + this.precision = new DenseVector(1, 1.0 / variance); // set unused variables. this.covarianceVector = null; @@ -231,7 +251,9 @@ public DenseVector sampleVector(RandomGenerator otherRNG) { case SPHERICAL -> sampled.scaleInPlace(variance); } - return means.add(sampled); + sampled.intersectAndAddInPlace(means); + + return sampled; } /** @@ -242,17 +264,62 @@ public double[] sampleArray() { return sampleVector().toArray(); } + /** + * Gets a copy of the mean vector. + * @return A copy of the mean vector. + */ + public DenseVector means() { + return means.copy(); + } + + /** + * Gets a copy of the covariance, either a {@link DenseMatrix} if it's full rank, + * or a {@link DenseVector} if it's diagonal or spherical. + * @return The covariance. + */ + public Tensor covariance() { + return switch (type) { + case FULL -> covarianceMatrix.copy(); + case DIAGONAL -> covarianceVector.copy(); + case SPHERICAL -> new DenseVector(means.size(), variance); + }; + } + /** * Compute the log probability of the input under this multivariate normal distribution. * @param input The input to compute. * @return The log probability. 
*/ public double logProbability(SGDVector input) { - return logProbability(input, means, covariance(), samplingCovariance, type); + return logProbability(input, means, precision, determinant, type); } - public static double logProbability(SGDVector input, DenseVector mean, Tensor covariance, Matrix.Factorization factorization, CovarianceType type) { - + public static double logProbability(SGDVector input, DenseVector mean, Tensor precision, double determinant, CovarianceType type) { + // p(input|mean, variance) = \frac{1}{(2\pi)^{d/2} determinant^{1/2}} e^{-1/2 * (input - mean)^T * precision * (input - mean)} + // log p(i|mu,sigma) = - log ({2\pi}^{d/2}) - log (determinant^{1/2}) + (-1/2 * (i - mu)^T * precision * (i - mu)) + double scalar = (- (mean.size() / 2.0) * Math.log(2 * Math.PI)) - (Math.log(determinant) / 2.0); + DenseVector diff = (DenseVector) input.subtract(mean); + double distance = switch (type) { + case FULL -> { + DenseMatrix precMat = (DenseMatrix) precision; + yield precMat.leftMultiply(diff).dot(diff); + } + case DIAGONAL -> { + // diff^T * diagonal precision * diff + // = diff.hadamard(precision).dot(diff) + // = diff.square().dot(precision) + DenseVector precVec = (DenseVector) precision; + diff.foreachInPlace(a -> a * a); + yield diff.dot(precVec); + } + case SPHERICAL -> { + double precVal = ((DenseVector) precision).get(0); + diff.foreachInPlace(a -> a * a); + diff.scaleInPlace(precVal); + yield diff.sum(); + } + }; + return scalar + (0.5 * distance); } @Override From 80dba6422e40de67022fcfdd8bc67729fdba69b5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 22 Apr 2024 21:17:49 -0400 Subject: [PATCH 08/15] Working on covariance calculation. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 78 +++++++++++++++---- .../java/org/tribuo/math/la/DenseMatrix.java | 18 +++++ 2 files changed, 80 insertions(+), 16 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index 5dbf3107f..51091848d 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -51,6 +51,8 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.function.BinaryOperator; +import java.util.function.Function; import java.util.function.ToDoubleFunction; import java.util.logging.Level; import java.util.logging.Logger; @@ -217,18 +219,18 @@ public GaussianMixtureModel train(Dataset examples, Map initialiseRandomCentroids(numGaussians, featureMap, localRNG); case PLUSPLUS -> initialisePlusPlusCentroids(numGaussians, data, localRNG); }; - final Tensor[] covarianceMatrices = new Tensor[numGaussians]; + Tensor[] covariances = new Tensor[numGaussians]; final Tensor[] precision = new Tensor[numGaussians]; final double[] determinant = new double[numGaussians]; final DenseVector mixingDistribution = new DenseVector(numGaussians); boolean parallel = numThreads > 1; - ToDoubleFunction eStepFunc = (IntAndVector e) -> { - DenseVector curResponsibilities = responsibilities[e.idx]; + ToDoubleFunction eStepFunc = (Vectors e) -> { + DenseVector curResponsibilities = e.responsibility; // compute log probs for (int i = 0; i < meanVectors.length; i++) { - curResponsibilities.set(i, MultivariateNormalDistribution.logProbability(e.vector, meanVectors[i], precision[i], determinant[i], covarianceType)); + curResponsibilities.set(i, MultivariateNormalDistribution.logProbability(e.data, meanVectors[i], precision[i], determinant[i], covarianceType)); } // add 
mixing distribution @@ -257,17 +259,17 @@ public GaussianMixtureModel train(Dataset examples, Map vecStream = Arrays.stream(data); - Stream intStream = IntStream.range(0, data.length).boxed(); - Stream zipStream = StreamUtil.zip(intStream, vecStream, IntAndVector::new); + Stream dataEStream = Arrays.stream(data); + Stream resEStream = Arrays.stream(responsibilities); + Stream zipEStream = StreamUtil.zip(dataEStream, resEStream, Vectors::new); if (parallel) { try { - normSum = fjp.submit(() -> zipStream.parallel().mapToDouble(eStepFunc).sum()).get(); + normSum = fjp.submit(() -> zipEStream.parallel().mapToDouble(eStepFunc).sum()).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } } else { - normSum = zipStream.mapToDouble(eStepFunc).sum(); + normSum = zipEStream.mapToDouble(eStepFunc).sum(); } logger.log(Level.FINE, i + "th e step completed."); @@ -312,6 +314,50 @@ public GaussianMixtureModel train(Dataset examples, Map dataMStream = Arrays.stream(data); + Stream resMStream = Arrays.stream(responsibilities); + Stream zipMStream = StreamUtil.zip(dataMStream, resMStream, Vectors::new); + Tensor[] zeroTensorArr = switch (covarianceType) { + case FULL -> { + Tensor[] output = new Tensor[numGaussians]; + for (int j = 0; j < numGaussians; j++) { + output[i] = new DenseMatrix(featureMap.size(), featureMap.size()); + } + yield output; + } + case DIAGONAL, SPHERICAL -> { + Tensor[] output = new Tensor[numGaussians]; + for (int j = 0; j < numGaussians; j++) { + output[i] = new DenseVector(featureMap.size()); + } + yield output; + } + }; + Function mStep = (Vectors v) -> { + + }; + BinaryOperator combineTensor = (Tensor[] a, Tensor[] b) -> { + Tensor[] output = new Tensor[a.length]; + for (int j = 0; j < a.length; j++) { + if (a[j] instanceof DenseMatrix aMat && b[j] instanceof DenseMatrix bMat) { + output[j] = aMat.add(bMat); + } else if (a[j] instanceof DenseVector aVec && b[j] instanceof DenseVector 
bVec) { + output[j] = aVec.add(bVec); + } else { + throw new IllegalStateException("Invalid types in reduce, expected both DenseMatrix or DenseVector, found " + a[j].getClass() + " and " + b[j].getClass()); + } + } + return output; + }; + if (parallel) { + try { + covariances = fjp.submit(() -> zipMStream.parallel().map(mStep).reduce(zeroTensorArr, combineTensor)).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException("Parallel execution failed", e); + } + } else { + covariances = zipMStream.parallel().map(mStep).reduce(zeroTensorArr, combineTensor); + } // renormalize mixing distribution double mixingSum = newMixingDistribution.sum(); @@ -321,8 +367,8 @@ public GaussianMixtureModel train(Dataset examples, Map { - for (int j = 0; j < covarianceMatrices.length; j++) { - DenseMatrix covMax = (DenseMatrix) covarianceMatrices[j]; + for (int j = 0; j < covariances.length; j++) { + DenseMatrix covMax = (DenseMatrix) covariances[j]; Optional optFact = covMax.choleskyFactorization(); if (optFact.isPresent()) { DenseMatrix.CholeskyFactorization fact = optFact.get(); @@ -334,8 +380,8 @@ public GaussianMixtureModel train(Dataset examples, Map { - for (int j = 0; j < covarianceMatrices.length; j++) { - DenseVector covVec = (DenseVector) covarianceMatrices[j]; + for (int j = 0; j < covariances.length; j++) { + DenseVector covVec = (DenseVector) covariances[j]; DenseVector preVec = (DenseVector) precision[j]; double tmp = 1; for (int k = 0; k < preVec.size(); k++) { @@ -378,7 +424,7 @@ public GaussianMixtureModel train(Dataset examples, Map Date: Sun, 28 Apr 2024 15:06:07 -0400 Subject: [PATCH 09/15] Code compiles for GMM. Inference still isn't quite right though. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 108 +- .../clustering/gmm/GaussianMixtureModel.java | 5 +- .../gmm/protos/GaussianMixtureModelProto.java | 1711 +++++++++++++++++ .../GaussianMixtureModelProtoOrBuilder.java | 99 + .../gmm/protos/TribuoClusteringGmm.java | 64 + .../org/tribuo/clustering/gmm/TestGMM.java | 6 +- .../java/org/tribuo/math/la/DenseMatrix.java | 13 + pom.xml | 52 +- 8 files changed, 2004 insertions(+), 54 deletions(-) create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProto.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProtoOrBuilder.java create mode 100644 Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/TribuoClusteringGmm.java diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index 51091848d..f4b940345 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -51,6 +51,7 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.function.BiFunction; import java.util.function.BinaryOperator; import java.util.function.Function; import java.util.function.ToDoubleFunction; @@ -61,8 +62,8 @@ /** * A Gaussian Mixture Model trainer, which generates a GMM clustering of the supplied - * data. The model finds the centres, and then predict needs to be - * called to infer the centre assignments for the input data. + * data. The model finds the Gaussians, and then predict needs to be + * called to infer the cluster assignments for the input data. *

* It's slightly contorted to fit the Tribuo Trainer and Model API, as the cluster assignments * can only be retrieved from the model after training, and require re-evaluating each example. @@ -86,7 +87,7 @@ * PDF * */ -public class GMMTrainer implements Trainer, WeightedExamples { +public class GMMTrainer implements Trainer { private static final Logger logger = Logger.getLogger(GMMTrainer.class.getName()); /** @@ -125,6 +126,9 @@ public enum Initialisation { @Config(mandatory = true, description = "The seed to use for the RNG.") private long seed; + @Config(description = "Jitter to add to the covariance diagonal.") + private double covJitter = 1e-6; + private SplittableRandom rng; private int trainInvocationCounter; @@ -199,14 +203,13 @@ public GaussianMixtureModel train(Dataset examples, Map example : examples) { - weights[n] = example.getWeight(); - if (example.size() == featureMap.size()) { + if (example.size() == numFeatures) { data[n] = DenseVector.createDenseVector(example, featureMap, false); } else { data[n] = SparseVector.createSparseVector(example, featureMap, false); @@ -222,7 +225,29 @@ public GaussianMixtureModel train(Dataset examples, Map { + covariances[i] = DenseMatrix.createIdentity(numFeatures); + precision[i] = DenseMatrix.createIdentity(numFeatures); + } + case DIAGONAL, SPHERICAL -> { + covariances[i] = new DenseVector(numFeatures, 1.0); + precision[i] = new DenseVector(numFeatures, 1.0); + } + } + } boolean parallel = numThreads > 1; @@ -310,7 +335,7 @@ public GaussianMixtureModel train(Dataset examples, Map examples, Map { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { - output[i] = new DenseMatrix(featureMap.size(), featureMap.size()); + output[j] = new DenseMatrix(numFeatures, numFeatures); } yield output; } case DIAGONAL, SPHERICAL -> { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { - output[i] = new DenseVector(featureMap.size()); + output[j] = new 
DenseVector(numFeatures); } yield output; } }; - Function mStep = (Vectors v) -> { + // Fix parallel behaviour + BiFunction mStep = switch (covarianceType) { + case FULL -> (Tensor[] input, Vectors v) -> { + for (int j = 0; j < numGaussians; j++) { + // Compute covariance contribution from current input + DenseMatrix curCov = (DenseMatrix) input[j]; + DenseVector diff = (DenseVector) v.data.subtract(meanVectors[j]); + diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j)); + curCov.intersectAndAddInPlace(diff.outer(diff)); + } + return input; + }; + case DIAGONAL -> (Tensor[] input, Vectors v) -> { + for (int j = 0; j < numGaussians; j++) { + // Compute covariance contribution from current input + DenseVector curCov = (DenseVector) input[j]; + double curResp = v.responsibility.get(j); + for (int k = 0; k < numFeatures; k++) { + double currentCovValue = curCov.get(k); + double curMean = meanVectors[j].get(k); + double curData = v.data.get(k); + double dataSq = curResp * curData * curData / newMixingDistribution.get(j); + double meanSq = curMean * curMean; + double dataMean = 2 * curResp * curData * curMean / newMixingDistribution.get(j); + double update = currentCovValue + dataSq - dataMean + meanSq; + curCov.set(k, update); + } + } + return input; + }; + case SPHERICAL -> (Tensor[] input, Vectors v) -> { + for (int j = 0; j < numGaussians; j++) { + // Compute covariance contribution from current input + DenseVector curCov = (DenseVector) input[j]; + double curResp = v.responsibility.get(j); + double update = 0; + for (int k = 0; k < numFeatures; k++) { + double curMean = meanVectors[j].get(k); + double curData = v.data.get(k); + double dataSq = curResp * curData * curData / newMixingDistribution.get(j); + double meanSq = curMean * curMean; + double dataMean = 2 * curResp * curData * curMean / newMixingDistribution.get(j); + update += dataSq - dataMean + meanSq; + } + update = update / numFeatures; + curCov.scalarAddInPlace(update); + } + return 
input; + }; }; BinaryOperator combineTensor = (Tensor[] a, Tensor[] b) -> { Tensor[] output = new Tensor[a.length]; @@ -350,13 +423,16 @@ public GaussianMixtureModel train(Dataset examples, Map zipMStream.parallel().map(mStep).reduce(zeroTensorArr, combineTensor)).get(); + covariances = fjp.submit(() -> zipMStream.parallel().reduce(zeroTensorArr, mStep, combineTensor)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } + */ } else { - covariances = zipMStream.parallel().map(mStep).reduce(zeroTensorArr, combineTensor); + covariances = zipMStream.reduce(zeroTensorArr, mStep, combineTensor); } // renormalize mixing distribution @@ -368,8 +444,9 @@ public GaussianMixtureModel train(Dataset examples, Map { for (int j = 0; j < covariances.length; j++) { - DenseMatrix covMax = (DenseMatrix) covariances[j]; - Optional optFact = covMax.choleskyFactorization(); + DenseMatrix covMat = (DenseMatrix) covariances[j]; + covMat.intersectAndAddInPlace(covarianceJitter); + Optional optFact = covMat.choleskyFactorization(); if (optFact.isPresent()) { DenseMatrix.CholeskyFactorization fact = optFact.get(); precision[j] = fact.inverse(); @@ -382,6 +459,7 @@ public GaussianMixtureModel train(Dataset examples, Map { for (int j = 0; j < covariances.length; j++) { DenseVector covVec = (DenseVector) covariances[j]; + covVec.intersectAndAddInPlace(covarianceJitter); DenseVector preVec = (DenseVector) precision[j]; double tmp = 1; for (int k = 0; k < preVec.size(); k++) { diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java index 52936cabd..42e36a4f0 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GaussianMixtureModel.java @@ -27,6 +27,7 @@ import org.tribuo.Model; import 
org.tribuo.Prediction; import org.tribuo.clustering.ClusterID; +import org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto; import org.tribuo.impl.ArrayExample; import org.tribuo.impl.ModelDataCarrier; import org.tribuo.math.distributions.MultivariateNormalDistribution; @@ -81,7 +82,7 @@ public class GaussianMixtureModel extends Model { private final MultivariateNormalDistribution.CovarianceType covarianceType; - private final MultivariateNormalDistribution[] distributions; + private transient MultivariateNormalDistribution[] distributions; GaussianMixtureModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo outputIDInfo, DenseVector[] meanVectors, @@ -265,7 +266,7 @@ public Prediction predict(Example example) { } // generate cluster responsibilities and normalize into a distribution - DenseVector responsibilities = new DenseVector(meanVectors[0].size()); + DenseVector responsibilities = new DenseVector(distributions.length); // compute log probs for (int i = 0; i < distributions.length; i++) { diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProto.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProto.java new file mode 100644 index 000000000..a177d0483 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProto.java @@ -0,0 +1,1711 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-clustering-gmm.proto + +// Protobuf Java Version: 3.25.3 +package org.tribuo.clustering.gmm.protos; + +/** + *

+ *
+ *GaussianMixtureModel proto
+ * 
+ * + * Protobuf type {@code tribuo.clustering.gmm.GaussianMixtureModelProto} + */ +public final class GaussianMixtureModelProto extends + com.google.protobuf.GeneratedMessageV3 implements + // @@protoc_insertion_point(message_implements:tribuo.clustering.gmm.GaussianMixtureModelProto) + GaussianMixtureModelProtoOrBuilder { +private static final long serialVersionUID = 0L; + // Use GaussianMixtureModelProto.newBuilder() to construct. + private GaussianMixtureModelProto(com.google.protobuf.GeneratedMessageV3.Builder builder) { + super(builder); + } + private GaussianMixtureModelProto() { + meanVectors_ = java.util.Collections.emptyList(); + covarianceMatrices_ = java.util.Collections.emptyList(); + covarianceType_ = 0; + } + + @java.lang.Override + @SuppressWarnings({"unused"}) + protected java.lang.Object newInstance( + UnusedPrivateParameter unused) { + return new GaussianMixtureModelProto(); + } + + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.tribuo.clustering.gmm.protos.TribuoClusteringGmm.internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.tribuo.clustering.gmm.protos.TribuoClusteringGmm.internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.class, org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.Builder.class); + } + + /** + * Protobuf enum {@code tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto} + */ + public enum CovarianceTypeProto + implements com.google.protobuf.ProtocolMessageEnum { + /** + * FULL = 0; + */ + FULL(0), + /** + * DIAGONAL = 1; + */ + DIAGONAL(1), + /** + * SPHERICAL = 2; + */ + SPHERICAL(2), + UNRECOGNIZED(-1), + ; + + /** + * FULL = 0; + */ + public static final int 
FULL_VALUE = 0; + /** + * DIAGONAL = 1; + */ + public static final int DIAGONAL_VALUE = 1; + /** + * SPHERICAL = 2; + */ + public static final int SPHERICAL_VALUE = 2; + + + public final int getNumber() { + if (this == UNRECOGNIZED) { + throw new java.lang.IllegalArgumentException( + "Can't get the number of an unknown enum value."); + } + return value; + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static CovarianceTypeProto valueOf(int value) { + return forNumber(value); + } + + /** + * @param value The numeric wire value of the corresponding enum entry. + * @return The enum associated with the given numeric wire value. + */ + public static CovarianceTypeProto forNumber(int value) { + switch (value) { + case 0: return FULL; + case 1: return DIAGONAL; + case 2: return SPHERICAL; + default: return null; + } + } + + public static com.google.protobuf.Internal.EnumLiteMap + internalGetValueMap() { + return internalValueMap; + } + private static final com.google.protobuf.Internal.EnumLiteMap< + CovarianceTypeProto> internalValueMap = + new com.google.protobuf.Internal.EnumLiteMap() { + public CovarianceTypeProto findValueByNumber(int number) { + return CovarianceTypeProto.forNumber(number); + } + }; + + public final com.google.protobuf.Descriptors.EnumValueDescriptor + getValueDescriptor() { + if (this == UNRECOGNIZED) { + throw new java.lang.IllegalStateException( + "Can't get the descriptor of an unrecognized enum value."); + } + return getDescriptor().getValues().get(ordinal()); + } + public final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptorForType() { + return getDescriptor(); + } + public static final com.google.protobuf.Descriptors.EnumDescriptor + getDescriptor() { + return 
org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.getDescriptor().getEnumTypes().get(0); + } + + private static final CovarianceTypeProto[] VALUES = values(); + + public static CovarianceTypeProto valueOf( + com.google.protobuf.Descriptors.EnumValueDescriptor desc) { + if (desc.getType() != getDescriptor()) { + throw new java.lang.IllegalArgumentException( + "EnumValueDescriptor is not for this type."); + } + if (desc.getIndex() == -1) { + return UNRECOGNIZED; + } + return VALUES[desc.getIndex()]; + } + + private final int value; + + private CovarianceTypeProto(int value) { + this.value = value; + } + + // @@protoc_insertion_point(enum_scope:tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto) + } + + private int bitField0_; + public static final int METADATA_FIELD_NUMBER = 1; + private org.tribuo.protos.core.ModelDataProto metadata_; + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + @java.lang.Override + public boolean hasMetadata() { + return ((bitField0_ & 0x00000001) != 0); + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + @java.lang.Override + public org.tribuo.protos.core.ModelDataProto getMetadata() { + return metadata_ == null ? org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + @java.lang.Override + public org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder() { + return metadata_ == null ? org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } + + public static final int MIXING_DISTRIBUTION_FIELD_NUMBER = 2; + private org.tribuo.math.protos.TensorProto mixingDistribution_; + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return Whether the mixingDistribution field is set. 
+ */ + @java.lang.Override + public boolean hasMixingDistribution() { + return ((bitField0_ & 0x00000002) != 0); + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return The mixingDistribution. + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getMixingDistribution() { + return mixingDistribution_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : mixingDistribution_; + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getMixingDistributionOrBuilder() { + return mixingDistribution_ == null ? org.tribuo.math.protos.TensorProto.getDefaultInstance() : mixingDistribution_; + } + + public static final int MEAN_VECTORS_FIELD_NUMBER = 3; + @SuppressWarnings("serial") + private java.util.List meanVectors_; + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + @java.lang.Override + public java.util.List getMeanVectorsList() { + return meanVectors_; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + @java.lang.Override + public java.util.List + getMeanVectorsOrBuilderList() { + return meanVectors_; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + @java.lang.Override + public int getMeanVectorsCount() { + return meanVectors_.size(); + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getMeanVectors(int index) { + return meanVectors_.get(index); + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getMeanVectorsOrBuilder( + int index) { + return meanVectors_.get(index); + } + + public static final int COVARIANCE_MATRICES_FIELD_NUMBER = 4; + @SuppressWarnings("serial") + private java.util.List covarianceMatrices_; + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + 
@java.lang.Override + public java.util.List getCovarianceMatricesList() { + return covarianceMatrices_; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + @java.lang.Override + public java.util.List + getCovarianceMatricesOrBuilderList() { + return covarianceMatrices_; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + @java.lang.Override + public int getCovarianceMatricesCount() { + return covarianceMatrices_.size(); + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProto getCovarianceMatrices(int index) { + return covarianceMatrices_.get(index); + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + @java.lang.Override + public org.tribuo.math.protos.TensorProtoOrBuilder getCovarianceMatricesOrBuilder( + int index) { + return covarianceMatrices_.get(index); + } + + public static final int COVARIANCE_TYPE_FIELD_NUMBER = 5; + private int covarianceType_ = 0; + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The enum numeric value on the wire for covarianceType. + */ + @java.lang.Override public int getCovarianceTypeValue() { + return covarianceType_; + } + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The covarianceType. + */ + @java.lang.Override public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto getCovarianceType() { + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto result = org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.forNumber(covarianceType_); + return result == null ? 
org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.UNRECOGNIZED : result; + } + + private byte memoizedIsInitialized = -1; + @java.lang.Override + public final boolean isInitialized() { + byte isInitialized = memoizedIsInitialized; + if (isInitialized == 1) return true; + if (isInitialized == 0) return false; + + memoizedIsInitialized = 1; + return true; + } + + @java.lang.Override + public void writeTo(com.google.protobuf.CodedOutputStream output) + throws java.io.IOException { + if (((bitField0_ & 0x00000001) != 0)) { + output.writeMessage(1, getMetadata()); + } + if (((bitField0_ & 0x00000002) != 0)) { + output.writeMessage(2, getMixingDistribution()); + } + for (int i = 0; i < meanVectors_.size(); i++) { + output.writeMessage(3, meanVectors_.get(i)); + } + for (int i = 0; i < covarianceMatrices_.size(); i++) { + output.writeMessage(4, covarianceMatrices_.get(i)); + } + if (covarianceType_ != org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.FULL.getNumber()) { + output.writeEnum(5, covarianceType_); + } + getUnknownFields().writeTo(output); + } + + @java.lang.Override + public int getSerializedSize() { + int size = memoizedSize; + if (size != -1) return size; + + size = 0; + if (((bitField0_ & 0x00000001) != 0)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(1, getMetadata()); + } + if (((bitField0_ & 0x00000002) != 0)) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(2, getMixingDistribution()); + } + for (int i = 0; i < meanVectors_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(3, meanVectors_.get(i)); + } + for (int i = 0; i < covarianceMatrices_.size(); i++) { + size += com.google.protobuf.CodedOutputStream + .computeMessageSize(4, covarianceMatrices_.get(i)); + } + if (covarianceType_ != org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.FULL.getNumber()) { + size += 
com.google.protobuf.CodedOutputStream + .computeEnumSize(5, covarianceType_); + } + size += getUnknownFields().getSerializedSize(); + memoizedSize = size; + return size; + } + + @java.lang.Override + public boolean equals(final java.lang.Object obj) { + if (obj == this) { + return true; + } + if (!(obj instanceof org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto)) { + return super.equals(obj); + } + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto other = (org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto) obj; + + if (hasMetadata() != other.hasMetadata()) return false; + if (hasMetadata()) { + if (!getMetadata() + .equals(other.getMetadata())) return false; + } + if (hasMixingDistribution() != other.hasMixingDistribution()) return false; + if (hasMixingDistribution()) { + if (!getMixingDistribution() + .equals(other.getMixingDistribution())) return false; + } + if (!getMeanVectorsList() + .equals(other.getMeanVectorsList())) return false; + if (!getCovarianceMatricesList() + .equals(other.getCovarianceMatricesList())) return false; + if (covarianceType_ != other.covarianceType_) return false; + if (!getUnknownFields().equals(other.getUnknownFields())) return false; + return true; + } + + @java.lang.Override + public int hashCode() { + if (memoizedHashCode != 0) { + return memoizedHashCode; + } + int hash = 41; + hash = (19 * hash) + getDescriptor().hashCode(); + if (hasMetadata()) { + hash = (37 * hash) + METADATA_FIELD_NUMBER; + hash = (53 * hash) + getMetadata().hashCode(); + } + if (hasMixingDistribution()) { + hash = (37 * hash) + MIXING_DISTRIBUTION_FIELD_NUMBER; + hash = (53 * hash) + getMixingDistribution().hashCode(); + } + if (getMeanVectorsCount() > 0) { + hash = (37 * hash) + MEAN_VECTORS_FIELD_NUMBER; + hash = (53 * hash) + getMeanVectorsList().hashCode(); + } + if (getCovarianceMatricesCount() > 0) { + hash = (37 * hash) + COVARIANCE_MATRICES_FIELD_NUMBER; + hash = (53 * hash) + getCovarianceMatricesList().hashCode(); + } + 
hash = (37 * hash) + COVARIANCE_TYPE_FIELD_NUMBER; + hash = (53 * hash) + covarianceType_; + hash = (29 * hash) + getUnknownFields().hashCode(); + memoizedHashCode = hash; + return hash; + } + + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + java.nio.ByteBuffer data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + java.nio.ByteBuffer data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + com.google.protobuf.ByteString data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + com.google.protobuf.ByteString data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom(byte[] data) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + byte[] data, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + return PARSER.parseFrom(data, extensionRegistry); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static 
org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseDelimitedFrom(java.io.InputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input); + } + + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseDelimitedFrom( + java.io.InputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseDelimitedWithIOException(PARSER, input, extensionRegistry); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + com.google.protobuf.CodedInputStream input) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input); + } + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto parseFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + return com.google.protobuf.GeneratedMessageV3 + .parseWithIOException(PARSER, input, extensionRegistry); + } + + @java.lang.Override + public Builder newBuilderForType() { return newBuilder(); } + public static Builder newBuilder() { + return DEFAULT_INSTANCE.toBuilder(); + } + public static Builder newBuilder(org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto prototype) { + return DEFAULT_INSTANCE.toBuilder().mergeFrom(prototype); + } + @java.lang.Override + public Builder toBuilder() { + return this == DEFAULT_INSTANCE + ? 
new Builder() : new Builder().mergeFrom(this); + } + + @java.lang.Override + protected Builder newBuilderForType( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + Builder builder = new Builder(parent); + return builder; + } + /** + *
+   * <pre>
+   *GaussianMixtureModel proto
+   * </pre>
+ * + * Protobuf type {@code tribuo.clustering.gmm.GaussianMixtureModelProto} + */ + public static final class Builder extends + com.google.protobuf.GeneratedMessageV3.Builder implements + // @@protoc_insertion_point(builder_implements:tribuo.clustering.gmm.GaussianMixtureModelProto) + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProtoOrBuilder { + public static final com.google.protobuf.Descriptors.Descriptor + getDescriptor() { + return org.tribuo.clustering.gmm.protos.TribuoClusteringGmm.internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor; + } + + @java.lang.Override + protected com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internalGetFieldAccessorTable() { + return org.tribuo.clustering.gmm.protos.TribuoClusteringGmm.internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_fieldAccessorTable + .ensureFieldAccessorsInitialized( + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.class, org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.Builder.class); + } + + // Construct using org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.newBuilder() + private Builder() { + maybeForceBuilderInitialization(); + } + + private Builder( + com.google.protobuf.GeneratedMessageV3.BuilderParent parent) { + super(parent); + maybeForceBuilderInitialization(); + } + private void maybeForceBuilderInitialization() { + if (com.google.protobuf.GeneratedMessageV3 + .alwaysUseFieldBuilders) { + getMetadataFieldBuilder(); + getMixingDistributionFieldBuilder(); + getMeanVectorsFieldBuilder(); + getCovarianceMatricesFieldBuilder(); + } + } + @java.lang.Override + public Builder clear() { + super.clear(); + bitField0_ = 0; + metadata_ = null; + if (metadataBuilder_ != null) { + metadataBuilder_.dispose(); + metadataBuilder_ = null; + } + mixingDistribution_ = null; + if (mixingDistributionBuilder_ != null) { + mixingDistributionBuilder_.dispose(); + mixingDistributionBuilder_ = null; + } + if (meanVectorsBuilder_ == 
null) { + meanVectors_ = java.util.Collections.emptyList(); + } else { + meanVectors_ = null; + meanVectorsBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000004); + if (covarianceMatricesBuilder_ == null) { + covarianceMatrices_ = java.util.Collections.emptyList(); + } else { + covarianceMatrices_ = null; + covarianceMatricesBuilder_.clear(); + } + bitField0_ = (bitField0_ & ~0x00000008); + covarianceType_ = 0; + return this; + } + + @java.lang.Override + public com.google.protobuf.Descriptors.Descriptor + getDescriptorForType() { + return org.tribuo.clustering.gmm.protos.TribuoClusteringGmm.internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor; + } + + @java.lang.Override + public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto getDefaultInstanceForType() { + return org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.getDefaultInstance(); + } + + @java.lang.Override + public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto build() { + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto result = buildPartial(); + if (!result.isInitialized()) { + throw newUninitializedMessageException(result); + } + return result; + } + + @java.lang.Override + public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto buildPartial() { + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto result = new org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto(this); + buildPartialRepeatedFields(result); + if (bitField0_ != 0) { buildPartial0(result); } + onBuilt(); + return result; + } + + private void buildPartialRepeatedFields(org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto result) { + if (meanVectorsBuilder_ == null) { + if (((bitField0_ & 0x00000004) != 0)) { + meanVectors_ = java.util.Collections.unmodifiableList(meanVectors_); + bitField0_ = (bitField0_ & ~0x00000004); + } + result.meanVectors_ = meanVectors_; + } else { + result.meanVectors_ = meanVectorsBuilder_.build(); + } + if 
(covarianceMatricesBuilder_ == null) { + if (((bitField0_ & 0x00000008) != 0)) { + covarianceMatrices_ = java.util.Collections.unmodifiableList(covarianceMatrices_); + bitField0_ = (bitField0_ & ~0x00000008); + } + result.covarianceMatrices_ = covarianceMatrices_; + } else { + result.covarianceMatrices_ = covarianceMatricesBuilder_.build(); + } + } + + private void buildPartial0(org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto result) { + int from_bitField0_ = bitField0_; + int to_bitField0_ = 0; + if (((from_bitField0_ & 0x00000001) != 0)) { + result.metadata_ = metadataBuilder_ == null + ? metadata_ + : metadataBuilder_.build(); + to_bitField0_ |= 0x00000001; + } + if (((from_bitField0_ & 0x00000002) != 0)) { + result.mixingDistribution_ = mixingDistributionBuilder_ == null + ? mixingDistribution_ + : mixingDistributionBuilder_.build(); + to_bitField0_ |= 0x00000002; + } + if (((from_bitField0_ & 0x00000010) != 0)) { + result.covarianceType_ = covarianceType_; + } + result.bitField0_ |= to_bitField0_; + } + + @java.lang.Override + public Builder clone() { + return super.clone(); + } + @java.lang.Override + public Builder setField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.setField(field, value); + } + @java.lang.Override + public Builder clearField( + com.google.protobuf.Descriptors.FieldDescriptor field) { + return super.clearField(field); + } + @java.lang.Override + public Builder clearOneof( + com.google.protobuf.Descriptors.OneofDescriptor oneof) { + return super.clearOneof(oneof); + } + @java.lang.Override + public Builder setRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + int index, java.lang.Object value) { + return super.setRepeatedField(field, index, value); + } + @java.lang.Override + public Builder addRepeatedField( + com.google.protobuf.Descriptors.FieldDescriptor field, + java.lang.Object value) { + return super.addRepeatedField(field, value); + } + 
@java.lang.Override + public Builder mergeFrom(com.google.protobuf.Message other) { + if (other instanceof org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto) { + return mergeFrom((org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto)other); + } else { + super.mergeFrom(other); + return this; + } + } + + public Builder mergeFrom(org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto other) { + if (other == org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.getDefaultInstance()) return this; + if (other.hasMetadata()) { + mergeMetadata(other.getMetadata()); + } + if (other.hasMixingDistribution()) { + mergeMixingDistribution(other.getMixingDistribution()); + } + if (meanVectorsBuilder_ == null) { + if (!other.meanVectors_.isEmpty()) { + if (meanVectors_.isEmpty()) { + meanVectors_ = other.meanVectors_; + bitField0_ = (bitField0_ & ~0x00000004); + } else { + ensureMeanVectorsIsMutable(); + meanVectors_.addAll(other.meanVectors_); + } + onChanged(); + } + } else { + if (!other.meanVectors_.isEmpty()) { + if (meanVectorsBuilder_.isEmpty()) { + meanVectorsBuilder_.dispose(); + meanVectorsBuilder_ = null; + meanVectors_ = other.meanVectors_; + bitField0_ = (bitField0_ & ~0x00000004); + meanVectorsBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? 
+ getMeanVectorsFieldBuilder() : null; + } else { + meanVectorsBuilder_.addAllMessages(other.meanVectors_); + } + } + } + if (covarianceMatricesBuilder_ == null) { + if (!other.covarianceMatrices_.isEmpty()) { + if (covarianceMatrices_.isEmpty()) { + covarianceMatrices_ = other.covarianceMatrices_; + bitField0_ = (bitField0_ & ~0x00000008); + } else { + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.addAll(other.covarianceMatrices_); + } + onChanged(); + } + } else { + if (!other.covarianceMatrices_.isEmpty()) { + if (covarianceMatricesBuilder_.isEmpty()) { + covarianceMatricesBuilder_.dispose(); + covarianceMatricesBuilder_ = null; + covarianceMatrices_ = other.covarianceMatrices_; + bitField0_ = (bitField0_ & ~0x00000008); + covarianceMatricesBuilder_ = + com.google.protobuf.GeneratedMessageV3.alwaysUseFieldBuilders ? + getCovarianceMatricesFieldBuilder() : null; + } else { + covarianceMatricesBuilder_.addAllMessages(other.covarianceMatrices_); + } + } + } + if (other.covarianceType_ != 0) { + setCovarianceTypeValue(other.getCovarianceTypeValue()); + } + this.mergeUnknownFields(other.getUnknownFields()); + onChanged(); + return this; + } + + @java.lang.Override + public final boolean isInitialized() { + return true; + } + + @java.lang.Override + public Builder mergeFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws java.io.IOException { + if (extensionRegistry == null) { + throw new java.lang.NullPointerException(); + } + try { + boolean done = false; + while (!done) { + int tag = input.readTag(); + switch (tag) { + case 0: + done = true; + break; + case 10: { + input.readMessage( + getMetadataFieldBuilder().getBuilder(), + extensionRegistry); + bitField0_ |= 0x00000001; + break; + } // case 10 + case 18: { + input.readMessage( + getMixingDistributionFieldBuilder().getBuilder(), + extensionRegistry); + bitField0_ |= 0x00000002; + break; + } // case 18 + case 26: { + 
org.tribuo.math.protos.TensorProto m = + input.readMessage( + org.tribuo.math.protos.TensorProto.parser(), + extensionRegistry); + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + meanVectors_.add(m); + } else { + meanVectorsBuilder_.addMessage(m); + } + break; + } // case 26 + case 34: { + org.tribuo.math.protos.TensorProto m = + input.readMessage( + org.tribuo.math.protos.TensorProto.parser(), + extensionRegistry); + if (covarianceMatricesBuilder_ == null) { + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.add(m); + } else { + covarianceMatricesBuilder_.addMessage(m); + } + break; + } // case 34 + case 40: { + covarianceType_ = input.readEnum(); + bitField0_ |= 0x00000010; + break; + } // case 40 + default: { + if (!super.parseUnknownField(input, extensionRegistry, tag)) { + done = true; // was an endgroup tag + } + break; + } // default: + } // switch (tag) + } // while (!done) + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.unwrapIOException(); + } finally { + onChanged(); + } // finally + return this; + } + private int bitField0_; + + private org.tribuo.protos.core.ModelDataProto metadata_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder> metadataBuilder_; + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + public boolean hasMetadata() { + return ((bitField0_ & 0x00000001) != 0); + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + public org.tribuo.protos.core.ModelDataProto getMetadata() { + if (metadataBuilder_ == null) { + return metadata_ == null ? 
org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } else { + return metadataBuilder_.getMessage(); + } + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder setMetadata(org.tribuo.protos.core.ModelDataProto value) { + if (metadataBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + metadata_ = value; + } else { + metadataBuilder_.setMessage(value); + } + bitField0_ |= 0x00000001; + onChanged(); + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder setMetadata( + org.tribuo.protos.core.ModelDataProto.Builder builderForValue) { + if (metadataBuilder_ == null) { + metadata_ = builderForValue.build(); + } else { + metadataBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000001; + onChanged(); + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder mergeMetadata(org.tribuo.protos.core.ModelDataProto value) { + if (metadataBuilder_ == null) { + if (((bitField0_ & 0x00000001) != 0) && + metadata_ != null && + metadata_ != org.tribuo.protos.core.ModelDataProto.getDefaultInstance()) { + getMetadataBuilder().mergeFrom(value); + } else { + metadata_ = value; + } + } else { + metadataBuilder_.mergeFrom(value); + } + if (metadata_ != null) { + bitField0_ |= 0x00000001; + onChanged(); + } + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public Builder clearMetadata() { + bitField0_ = (bitField0_ & ~0x00000001); + metadata_ = null; + if (metadataBuilder_ != null) { + metadataBuilder_.dispose(); + metadataBuilder_ = null; + } + onChanged(); + return this; + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public org.tribuo.protos.core.ModelDataProto.Builder getMetadataBuilder() { + bitField0_ |= 0x00000001; + onChanged(); + return getMetadataFieldBuilder().getBuilder(); + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + public 
org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder() { + if (metadataBuilder_ != null) { + return metadataBuilder_.getMessageOrBuilder(); + } else { + return metadata_ == null ? + org.tribuo.protos.core.ModelDataProto.getDefaultInstance() : metadata_; + } + } + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder> + getMetadataFieldBuilder() { + if (metadataBuilder_ == null) { + metadataBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.protos.core.ModelDataProto, org.tribuo.protos.core.ModelDataProto.Builder, org.tribuo.protos.core.ModelDataProtoOrBuilder>( + getMetadata(), + getParentForChildren(), + isClean()); + metadata_ = null; + } + return metadataBuilder_; + } + + private org.tribuo.math.protos.TensorProto mixingDistribution_; + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> mixingDistributionBuilder_; + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return Whether the mixingDistribution field is set. + */ + public boolean hasMixingDistribution() { + return ((bitField0_ & 0x00000002) != 0); + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return The mixingDistribution. + */ + public org.tribuo.math.protos.TensorProto getMixingDistribution() { + if (mixingDistributionBuilder_ == null) { + return mixingDistribution_ == null ? 
org.tribuo.math.protos.TensorProto.getDefaultInstance() : mixingDistribution_; + } else { + return mixingDistributionBuilder_.getMessage(); + } + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public Builder setMixingDistribution(org.tribuo.math.protos.TensorProto value) { + if (mixingDistributionBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + mixingDistribution_ = value; + } else { + mixingDistributionBuilder_.setMessage(value); + } + bitField0_ |= 0x00000002; + onChanged(); + return this; + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public Builder setMixingDistribution( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (mixingDistributionBuilder_ == null) { + mixingDistribution_ = builderForValue.build(); + } else { + mixingDistributionBuilder_.setMessage(builderForValue.build()); + } + bitField0_ |= 0x00000002; + onChanged(); + return this; + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public Builder mergeMixingDistribution(org.tribuo.math.protos.TensorProto value) { + if (mixingDistributionBuilder_ == null) { + if (((bitField0_ & 0x00000002) != 0) && + mixingDistribution_ != null && + mixingDistribution_ != org.tribuo.math.protos.TensorProto.getDefaultInstance()) { + getMixingDistributionBuilder().mergeFrom(value); + } else { + mixingDistribution_ = value; + } + } else { + mixingDistributionBuilder_.mergeFrom(value); + } + if (mixingDistribution_ != null) { + bitField0_ |= 0x00000002; + onChanged(); + } + return this; + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public Builder clearMixingDistribution() { + bitField0_ = (bitField0_ & ~0x00000002); + mixingDistribution_ = null; + if (mixingDistributionBuilder_ != null) { + mixingDistributionBuilder_.dispose(); + mixingDistributionBuilder_ = null; + } + onChanged(); + return this; + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public 
org.tribuo.math.protos.TensorProto.Builder getMixingDistributionBuilder() { + bitField0_ |= 0x00000002; + onChanged(); + return getMixingDistributionFieldBuilder().getBuilder(); + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getMixingDistributionOrBuilder() { + if (mixingDistributionBuilder_ != null) { + return mixingDistributionBuilder_.getMessageOrBuilder(); + } else { + return mixingDistribution_ == null ? + org.tribuo.math.protos.TensorProto.getDefaultInstance() : mixingDistribution_; + } + } + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + private com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getMixingDistributionFieldBuilder() { + if (mixingDistributionBuilder_ == null) { + mixingDistributionBuilder_ = new com.google.protobuf.SingleFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + getMixingDistribution(), + getParentForChildren(), + isClean()); + mixingDistribution_ = null; + } + return mixingDistributionBuilder_; + } + + private java.util.List meanVectors_ = + java.util.Collections.emptyList(); + private void ensureMeanVectorsIsMutable() { + if (!((bitField0_ & 0x00000004) != 0)) { + meanVectors_ = new java.util.ArrayList(meanVectors_); + bitField0_ |= 0x00000004; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> meanVectorsBuilder_; + + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public java.util.List getMeanVectorsList() { + if (meanVectorsBuilder_ == null) { + return java.util.Collections.unmodifiableList(meanVectors_); + } else { + return meanVectorsBuilder_.getMessageList(); + } + } + 
/** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public int getMeanVectorsCount() { + if (meanVectorsBuilder_ == null) { + return meanVectors_.size(); + } else { + return meanVectorsBuilder_.getCount(); + } + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public org.tribuo.math.protos.TensorProto getMeanVectors(int index) { + if (meanVectorsBuilder_ == null) { + return meanVectors_.get(index); + } else { + return meanVectorsBuilder_.getMessage(index); + } + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder setMeanVectors( + int index, org.tribuo.math.protos.TensorProto value) { + if (meanVectorsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureMeanVectorsIsMutable(); + meanVectors_.set(index, value); + onChanged(); + } else { + meanVectorsBuilder_.setMessage(index, value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder setMeanVectors( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + meanVectors_.set(index, builderForValue.build()); + onChanged(); + } else { + meanVectorsBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder addMeanVectors(org.tribuo.math.protos.TensorProto value) { + if (meanVectorsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureMeanVectorsIsMutable(); + meanVectors_.add(value); + onChanged(); + } else { + meanVectorsBuilder_.addMessage(value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder addMeanVectors( + int index, org.tribuo.math.protos.TensorProto value) { + if (meanVectorsBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + 
ensureMeanVectorsIsMutable(); + meanVectors_.add(index, value); + onChanged(); + } else { + meanVectorsBuilder_.addMessage(index, value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder addMeanVectors( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + meanVectors_.add(builderForValue.build()); + onChanged(); + } else { + meanVectorsBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder addMeanVectors( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + meanVectors_.add(index, builderForValue.build()); + onChanged(); + } else { + meanVectorsBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder addAllMeanVectors( + java.lang.Iterable values) { + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, meanVectors_); + onChanged(); + } else { + meanVectorsBuilder_.addAllMessages(values); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder clearMeanVectors() { + if (meanVectorsBuilder_ == null) { + meanVectors_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000004); + onChanged(); + } else { + meanVectorsBuilder_.clear(); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public Builder removeMeanVectors(int index) { + if (meanVectorsBuilder_ == null) { + ensureMeanVectorsIsMutable(); + meanVectors_.remove(index); + onChanged(); + } else { + meanVectorsBuilder_.remove(index); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto 
mean_vectors = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder getMeanVectorsBuilder( + int index) { + return getMeanVectorsFieldBuilder().getBuilder(index); + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getMeanVectorsOrBuilder( + int index) { + if (meanVectorsBuilder_ == null) { + return meanVectors_.get(index); } else { + return meanVectorsBuilder_.getMessageOrBuilder(index); + } + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public java.util.List + getMeanVectorsOrBuilderList() { + if (meanVectorsBuilder_ != null) { + return meanVectorsBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(meanVectors_); + } + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder addMeanVectorsBuilder() { + return getMeanVectorsFieldBuilder().addBuilder( + org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public org.tribuo.math.protos.TensorProto.Builder addMeanVectorsBuilder( + int index) { + return getMeanVectorsFieldBuilder().addBuilder( + index, org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + public java.util.List + getMeanVectorsBuilderList() { + return getMeanVectorsFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getMeanVectorsFieldBuilder() { + if (meanVectorsBuilder_ == null) { + meanVectorsBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + meanVectors_, + ((bitField0_ & 0x00000004) != 
0), + getParentForChildren(), + isClean()); + meanVectors_ = null; + } + return meanVectorsBuilder_; + } + + private java.util.List covarianceMatrices_ = + java.util.Collections.emptyList(); + private void ensureCovarianceMatricesIsMutable() { + if (!((bitField0_ & 0x00000008) != 0)) { + covarianceMatrices_ = new java.util.ArrayList(covarianceMatrices_); + bitField0_ |= 0x00000008; + } + } + + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> covarianceMatricesBuilder_; + + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public java.util.List getCovarianceMatricesList() { + if (covarianceMatricesBuilder_ == null) { + return java.util.Collections.unmodifiableList(covarianceMatrices_); + } else { + return covarianceMatricesBuilder_.getMessageList(); + } + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public int getCovarianceMatricesCount() { + if (covarianceMatricesBuilder_ == null) { + return covarianceMatrices_.size(); + } else { + return covarianceMatricesBuilder_.getCount(); + } + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public org.tribuo.math.protos.TensorProto getCovarianceMatrices(int index) { + if (covarianceMatricesBuilder_ == null) { + return covarianceMatrices_.get(index); + } else { + return covarianceMatricesBuilder_.getMessage(index); + } + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder setCovarianceMatrices( + int index, org.tribuo.math.protos.TensorProto value) { + if (covarianceMatricesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.set(index, value); + onChanged(); + } else { + covarianceMatricesBuilder_.setMessage(index, value); + } + return this; + } + /** + * repeated 
.tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder setCovarianceMatrices( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (covarianceMatricesBuilder_ == null) { + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.set(index, builderForValue.build()); + onChanged(); + } else { + covarianceMatricesBuilder_.setMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder addCovarianceMatrices(org.tribuo.math.protos.TensorProto value) { + if (covarianceMatricesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.add(value); + onChanged(); + } else { + covarianceMatricesBuilder_.addMessage(value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder addCovarianceMatrices( + int index, org.tribuo.math.protos.TensorProto value) { + if (covarianceMatricesBuilder_ == null) { + if (value == null) { + throw new NullPointerException(); + } + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.add(index, value); + onChanged(); + } else { + covarianceMatricesBuilder_.addMessage(index, value); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder addCovarianceMatrices( + org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (covarianceMatricesBuilder_ == null) { + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.add(builderForValue.build()); + onChanged(); + } else { + covarianceMatricesBuilder_.addMessage(builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder addCovarianceMatrices( + int index, org.tribuo.math.protos.TensorProto.Builder builderForValue) { + if (covarianceMatricesBuilder_ == null) { + 
ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.add(index, builderForValue.build()); + onChanged(); + } else { + covarianceMatricesBuilder_.addMessage(index, builderForValue.build()); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder addAllCovarianceMatrices( + java.lang.Iterable values) { + if (covarianceMatricesBuilder_ == null) { + ensureCovarianceMatricesIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, covarianceMatrices_); + onChanged(); + } else { + covarianceMatricesBuilder_.addAllMessages(values); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder clearCovarianceMatrices() { + if (covarianceMatricesBuilder_ == null) { + covarianceMatrices_ = java.util.Collections.emptyList(); + bitField0_ = (bitField0_ & ~0x00000008); + onChanged(); + } else { + covarianceMatricesBuilder_.clear(); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public Builder removeCovarianceMatrices(int index) { + if (covarianceMatricesBuilder_ == null) { + ensureCovarianceMatricesIsMutable(); + covarianceMatrices_.remove(index); + onChanged(); + } else { + covarianceMatricesBuilder_.remove(index); + } + return this; + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public org.tribuo.math.protos.TensorProto.Builder getCovarianceMatricesBuilder( + int index) { + return getCovarianceMatricesFieldBuilder().getBuilder(index); + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public org.tribuo.math.protos.TensorProtoOrBuilder getCovarianceMatricesOrBuilder( + int index) { + if (covarianceMatricesBuilder_ == null) { + return covarianceMatrices_.get(index); } else { + return covarianceMatricesBuilder_.getMessageOrBuilder(index); + } + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public 
java.util.List + getCovarianceMatricesOrBuilderList() { + if (covarianceMatricesBuilder_ != null) { + return covarianceMatricesBuilder_.getMessageOrBuilderList(); + } else { + return java.util.Collections.unmodifiableList(covarianceMatrices_); + } + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public org.tribuo.math.protos.TensorProto.Builder addCovarianceMatricesBuilder() { + return getCovarianceMatricesFieldBuilder().addBuilder( + org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public org.tribuo.math.protos.TensorProto.Builder addCovarianceMatricesBuilder( + int index) { + return getCovarianceMatricesFieldBuilder().addBuilder( + index, org.tribuo.math.protos.TensorProto.getDefaultInstance()); + } + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + public java.util.List + getCovarianceMatricesBuilderList() { + return getCovarianceMatricesFieldBuilder().getBuilderList(); + } + private com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder> + getCovarianceMatricesFieldBuilder() { + if (covarianceMatricesBuilder_ == null) { + covarianceMatricesBuilder_ = new com.google.protobuf.RepeatedFieldBuilderV3< + org.tribuo.math.protos.TensorProto, org.tribuo.math.protos.TensorProto.Builder, org.tribuo.math.protos.TensorProtoOrBuilder>( + covarianceMatrices_, + ((bitField0_ & 0x00000008) != 0), + getParentForChildren(), + isClean()); + covarianceMatrices_ = null; + } + return covarianceMatricesBuilder_; + } + + private int covarianceType_ = 0; + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The enum numeric value on the wire for covarianceType. 
+ */ + @java.lang.Override public int getCovarianceTypeValue() { + return covarianceType_; + } + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @param value The enum numeric value on the wire for covarianceType to set. + * @return This builder for chaining. + */ + public Builder setCovarianceTypeValue(int value) { + covarianceType_ = value; + bitField0_ |= 0x00000010; + onChanged(); + return this; + } + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The covarianceType. + */ + @java.lang.Override + public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto getCovarianceType() { + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto result = org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.forNumber(covarianceType_); + return result == null ? org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto.UNRECOGNIZED : result; + } + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @param value The covarianceType to set. + * @return This builder for chaining. + */ + public Builder setCovarianceType(org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto value) { + if (value == null) { + throw new NullPointerException(); + } + bitField0_ |= 0x00000010; + covarianceType_ = value.getNumber(); + onChanged(); + return this; + } + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return This builder for chaining. 
+ */ + public Builder clearCovarianceType() { + bitField0_ = (bitField0_ & ~0x00000010); + covarianceType_ = 0; + onChanged(); + return this; + } + @java.lang.Override + public final Builder setUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.setUnknownFields(unknownFields); + } + + @java.lang.Override + public final Builder mergeUnknownFields( + final com.google.protobuf.UnknownFieldSet unknownFields) { + return super.mergeUnknownFields(unknownFields); + } + + + // @@protoc_insertion_point(builder_scope:tribuo.clustering.gmm.GaussianMixtureModelProto) + } + + // @@protoc_insertion_point(class_scope:tribuo.clustering.gmm.GaussianMixtureModelProto) + private static final org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto DEFAULT_INSTANCE; + static { + DEFAULT_INSTANCE = new org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto(); + } + + public static org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto getDefaultInstance() { + return DEFAULT_INSTANCE; + } + + private static final com.google.protobuf.Parser + PARSER = new com.google.protobuf.AbstractParser() { + @java.lang.Override + public GaussianMixtureModelProto parsePartialFrom( + com.google.protobuf.CodedInputStream input, + com.google.protobuf.ExtensionRegistryLite extensionRegistry) + throws com.google.protobuf.InvalidProtocolBufferException { + Builder builder = newBuilder(); + try { + builder.mergeFrom(input, extensionRegistry); + } catch (com.google.protobuf.InvalidProtocolBufferException e) { + throw e.setUnfinishedMessage(builder.buildPartial()); + } catch (com.google.protobuf.UninitializedMessageException e) { + throw e.asInvalidProtocolBufferException().setUnfinishedMessage(builder.buildPartial()); + } catch (java.io.IOException e) { + throw new com.google.protobuf.InvalidProtocolBufferException(e) + .setUnfinishedMessage(builder.buildPartial()); + } + return builder.buildPartial(); + } + }; + + public static com.google.protobuf.Parser 
parser() { + return PARSER; + } + + @java.lang.Override + public com.google.protobuf.Parser getParserForType() { + return PARSER; + } + + @java.lang.Override + public org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto getDefaultInstanceForType() { + return DEFAULT_INSTANCE; + } + +} + diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProtoOrBuilder.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProtoOrBuilder.java new file mode 100644 index 000000000..3b4527f49 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/GaussianMixtureModelProtoOrBuilder.java @@ -0,0 +1,99 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-clustering-gmm.proto + +// Protobuf Java Version: 3.25.3 +package org.tribuo.clustering.gmm.protos; + +public interface GaussianMixtureModelProtoOrBuilder extends + // @@protoc_insertion_point(interface_extends:tribuo.clustering.gmm.GaussianMixtureModelProto) + com.google.protobuf.MessageOrBuilder { + + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return Whether the metadata field is set. + */ + boolean hasMetadata(); + /** + * .tribuo.core.ModelDataProto metadata = 1; + * @return The metadata. + */ + org.tribuo.protos.core.ModelDataProto getMetadata(); + /** + * .tribuo.core.ModelDataProto metadata = 1; + */ + org.tribuo.protos.core.ModelDataProtoOrBuilder getMetadataOrBuilder(); + + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return Whether the mixingDistribution field is set. + */ + boolean hasMixingDistribution(); + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + * @return The mixingDistribution. 
+ */ + org.tribuo.math.protos.TensorProto getMixingDistribution(); + /** + * .tribuo.math.TensorProto mixing_distribution = 2; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getMixingDistributionOrBuilder(); + + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + java.util.List + getMeanVectorsList(); + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + org.tribuo.math.protos.TensorProto getMeanVectors(int index); + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + int getMeanVectorsCount(); + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + java.util.List + getMeanVectorsOrBuilderList(); + /** + * repeated .tribuo.math.TensorProto mean_vectors = 3; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getMeanVectorsOrBuilder( + int index); + + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + java.util.List + getCovarianceMatricesList(); + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + org.tribuo.math.protos.TensorProto getCovarianceMatrices(int index); + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + int getCovarianceMatricesCount(); + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + java.util.List + getCovarianceMatricesOrBuilderList(); + /** + * repeated .tribuo.math.TensorProto covariance_matrices = 4; + */ + org.tribuo.math.protos.TensorProtoOrBuilder getCovarianceMatricesOrBuilder( + int index); + + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The enum numeric value on the wire for covarianceType. + */ + int getCovarianceTypeValue(); + /** + * .tribuo.clustering.gmm.GaussianMixtureModelProto.CovarianceTypeProto covariance_type = 5; + * @return The covarianceType. 
+ */ + org.tribuo.clustering.gmm.protos.GaussianMixtureModelProto.CovarianceTypeProto getCovarianceType(); +} diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/TribuoClusteringGmm.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/TribuoClusteringGmm.java new file mode 100644 index 000000000..bbed5dd42 --- /dev/null +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/protos/TribuoClusteringGmm.java @@ -0,0 +1,64 @@ +// Generated by the protocol buffer compiler. DO NOT EDIT! +// source: tribuo-clustering-gmm.proto + +// Protobuf Java Version: 3.25.3 +package org.tribuo.clustering.gmm.protos; + +public final class TribuoClusteringGmm { + private TribuoClusteringGmm() {} + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistryLite registry) { + } + + public static void registerAllExtensions( + com.google.protobuf.ExtensionRegistry registry) { + registerAllExtensions( + (com.google.protobuf.ExtensionRegistryLite) registry); + } + static final com.google.protobuf.Descriptors.Descriptor + internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor; + static final + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable + internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_fieldAccessorTable; + + public static com.google.protobuf.Descriptors.FileDescriptor + getDescriptor() { + return descriptor; + } + private static com.google.protobuf.Descriptors.FileDescriptor + descriptor; + static { + java.lang.String[] descriptorData = { + "\n\033tribuo-clustering-gmm.proto\022\025tribuo.cl" + + "ustering.gmm\032\021tribuo-core.proto\032\021tribuo-" + + "math.proto\"\205\003\n\031GaussianMixtureModelProto" + + "\022-\n\010metadata\030\001 \001(\0132\033.tribuo.core.ModelDa" + + "taProto\0225\n\023mixing_distribution\030\002 \001(\0132\030.t" + + "ribuo.math.TensorProto\022.\n\014mean_vectors\030\003" + + " \003(\0132\030.tribuo.math.TensorProto\0225\n\023covari" + + 
"ance_matrices\030\004 \003(\0132\030.tribuo.math.Tensor" + + "Proto\022]\n\017covariance_type\030\005 \001(\0162D.tribuo." + + "clustering.gmm.GaussianMixtureModelProto" + + ".CovarianceTypeProto\"<\n\023CovarianceTypePr" + + "oto\022\010\n\004FULL\020\000\022\014\n\010DIAGONAL\020\001\022\r\n\tSPHERICAL" + + "\020\002B$\n org.tribuo.clustering.gmm.protosP\001" + + "b\006proto3" + }; + descriptor = com.google.protobuf.Descriptors.FileDescriptor + .internalBuildGeneratedFileFrom(descriptorData, + new com.google.protobuf.Descriptors.FileDescriptor[] { + org.tribuo.protos.core.TribuoCore.getDescriptor(), + org.tribuo.math.protos.TribuoMath.getDescriptor(), + }); + internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor = + getDescriptor().getMessageTypes().get(0); + internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_fieldAccessorTable = new + com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( + internal_static_tribuo_clustering_gmm_GaussianMixtureModelProto_descriptor, + new java.lang.String[] { "Metadata", "MixingDistribution", "MeanVectors", "CovarianceMatrices", "CovarianceType", }); + org.tribuo.protos.core.TribuoCore.getDescriptor(); + org.tribuo.math.protos.TribuoMath.getDescriptor(); + } + + // @@protoc_insertion_point(outer_class_scope) +} diff --git a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java index 5527525b4..021d68016 100644 --- a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java +++ b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java @@ -39,14 +39,14 @@ import static org.junit.jupiter.api.Assertions.assertThrows; /** - * Smoke tests for k-means. + * Smoke tests for GMM. 
*/ public class TestGMM { - private static final GMMTrainer t = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + private static final GMMTrainer t = new GMMTrainer(5, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); - private static final GMMTrainer plusPlus = new GMMTrainer(4, 10, MultivariateNormalDistribution.CovarianceType.FULL, + private static final GMMTrainer plusPlus = new GMMTrainer(5, 10, MultivariateNormalDistribution.CovarianceType.FULL, GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); @BeforeAll diff --git a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java index 56cf32b26..d5c1febf0 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java @@ -185,6 +185,19 @@ public static DenseMatrix createDenseMatrix(SGDVector[] vectors) { return new DenseMatrix(newValues); } + /** + * Creates an identity matrix of the specified size. + * @param dimension The matrix dimension. + * @return The identity matrix. + */ + public static DenseMatrix createIdentity(int dimension) { + double[][] newValues = new double[dimension][dimension]; + for (int i = 0; i < dimension; i++) { + newValues[i][i] = 1.0; + } + return new DenseMatrix(newValues); + } + /** * Deserialization factory. * @param version The serialized object version. 
diff --git a/pom.xml b/pom.xml index dc07926a9..d5b0f42a4 100644 --- a/pom.xml +++ b/pom.xml @@ -57,8 +57,8 @@ 5.9.1 - 5.7.1 - 3.19.6 + 5.9 + 3.25.3 2.14.3 @@ -85,12 +85,6 @@ Oracle Labs https://labs.oracle.com - - Kate Silverstein - kate.silverstein@oracle.com - Oracle Labs - https://labs.oracle.com - Stephen Green stephen.x.green@oracle.com @@ -154,58 +148,57 @@ org.apache.maven.plugins maven-surefire-plugin - 3.0.0-M5 + 3.2.5 org.apache.maven.plugins maven-resources-plugin - - 2.7 + 3.3.1 org.apache.maven.plugins maven-source-plugin - 3.2.1 + 3.3.1 org.apache.maven.plugins maven-javadoc-plugin - 3.4.1 + 3.6.3 org.apache.maven.plugins maven-release-plugin - 3.0.0-M4 + 3.0.1 org.apache.maven.plugins maven-dependency-plugin - 3.2.0 + 3.6.1 org.apache.maven.plugins maven-deploy-plugin - 3.0.0-M1 + 3.1.1 org.apache.maven.plugins maven-assembly-plugin - 3.3.0 + 3.7.1 org.apache.maven.plugins maven-enforcer-plugin - 3.0.0 + 3.4.1 org.apache.maven.plugins maven-site-plugin - 3.9.1 + 3.12.1 org.apache.maven.plugins maven-install-plugin - 3.0.0-M1 + 3.1.1
@@ -213,7 +206,6 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.4.1 -Xmaxerrs @@ -245,7 +237,7 @@ ./Core/src/main/javadoc/overview.html - Copyright © 2015–2022 Oracle and/or its affiliates. All rights reserved. + Copyright © 2015–2024 Oracle and/or its affiliates. All rights reserved. Core Packages @@ -308,7 +300,7 @@ false true ./Core/src/main/javadoc/overview.html - Copyright © 2015–2022 Oracle and/or its affiliates. All rights reserved. + Copyright © 2015–2024 Oracle and/or its affiliates. All rights reserved. Core Packages @@ -357,14 +349,10 @@ org.apache.maven.plugins maven-compiler-plugin - 3.8.1 + 3.13.0 17 - 17 - 17 17 - 17 - 17 -Xlint:all @@ -373,7 +361,6 @@ org.apache.maven.plugins maven-enforcer-plugin - 3.2.1 enforce @@ -394,14 +381,13 @@ org.apache.maven.plugins maven-assembly-plugin - 3.3.0
- + arm true @@ -415,7 +401,6 @@ org.apache.maven.plugins maven-source-plugin - 3.2.1 attach-sources @@ -436,9 +421,8 @@ org.apache.maven.plugins maven-javadoc-plugin - 3.4.1 - Copyright © 2015–2022 Oracle and/or its affiliates. All rights reserved. + Copyright © 2015–2024 Oracle and/or its affiliates. All rights reserved. From b8a9b3b918f637272697052661b27225673563e5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 28 Apr 2024 21:26:42 -0400 Subject: [PATCH 10/15] Fix bugs in MultivariateNormalDistribution, Cholesky.determinant, LU.determinant, SparseVector.subtract. --- .../MultivariateNormalDistribution.java | 2 +- .../java/org/tribuo/math/la/DenseMatrix.java | 4 ++-- .../java/org/tribuo/math/la/DenseVector.java | 20 +++++++++++++++---- .../java/org/tribuo/math/la/SparseVector.java | 3 ++- .../org/tribuo/math/la/SparseVectorTest.java | 8 ++++++++ 5 files changed, 29 insertions(+), 8 deletions(-) diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index c36c9e16d..89878751f 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -319,7 +319,7 @@ public static double logProbability(SGDVector input, DenseVector mean, Tensor pr yield diff.sum(); } }; - return scalar + (0.5 * distance); + return scalar - (0.5 * distance); } @Override diff --git a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java index d5c1febf0..1dca0e8b2 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseMatrix.java @@ -1540,7 +1540,7 @@ public int dim2() { */ @Override public double determinant() { - double det = 0.0; + double det = 1.0; for (int i = 0; i < lMatrix.dim1; i++) { det *= 
lMatrix.values[i][i] * lMatrix.values[i][i]; } @@ -1711,7 +1711,7 @@ public int dim2() { */ @Override public double determinant() { - double det = 0.0; + double det = 1.0; for (int i = 0; i < upper.dim1; i++) { det *= upper.values[i][i]; } diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index 6902dfcde..f9d9a05fd 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -363,8 +363,14 @@ public DenseVector add(SGDVector other) { throw new IllegalArgumentException("Can't add two vectors of different dimension, this = " + elements.length + ", other = " + other.size()); } double[] newValues = toArray(); - for (VectorTuple tuple : other) { - newValues[tuple.index] += tuple.value; + if (other instanceof DenseVector otherDense) { + for (int i = 0; i < newValues.length; i++) { + newValues[i] += otherDense.get(i); + } + } else { + for (VectorTuple tuple : other) { + newValues[tuple.index] += tuple.value; + } } return new DenseVector(newValues); } @@ -380,8 +386,14 @@ public DenseVector subtract(SGDVector other) { throw new IllegalArgumentException("Can't subtract two vectors of different dimension, this = " + elements.length + ", other = " + other.size()); } double[] newValues = toArray(); - for (VectorTuple tuple : other) { - newValues[tuple.index] -= tuple.value; + if (other instanceof DenseVector otherDense) { + for (int i = 0; i < newValues.length; i++) { + newValues[i] -= otherDense.get(i); + } + } else { + for (VectorTuple tuple : other) { + newValues[tuple.index] -= tuple.value; + } } return new DenseVector(newValues); } diff --git a/Math/src/main/java/org/tribuo/math/la/SparseVector.java b/Math/src/main/java/org/tribuo/math/la/SparseVector.java index e973a8b3b..febc63cb2 100644 --- a/Math/src/main/java/org/tribuo/math/la/SparseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/SparseVector.java @@ -467,8 +467,9 
@@ public SGDVector subtract(SGDVector other) { } if (other instanceof DenseVector) { DenseVector output = ((DenseVector)other).copy(); + output.scaleInPlace(-1.0); for (VectorTuple tuple : this) { - output.set(tuple.index,tuple.value-output.get(tuple.index)); + output.set(tuple.index,tuple.value+output.get(tuple.index)); } return output; } else if (other instanceof SparseVector) { diff --git a/Math/src/test/java/org/tribuo/math/la/SparseVectorTest.java b/Math/src/test/java/org/tribuo/math/la/SparseVectorTest.java index bcdf497d5..c644fa094 100644 --- a/Math/src/test/java/org/tribuo/math/la/SparseVectorTest.java +++ b/Math/src/test/java/org/tribuo/math/la/SparseVectorTest.java @@ -91,6 +91,11 @@ private SparseVector generateVectorASubC() { return SparseVector.createSparseVector(10,indices,values); } + private DenseVector generateVectorASubOnes() { + double[] values = new double[]{0.0, 1.0, -1.0, -1.0, 2.0, 3.0, -1.0, -1.0, 4.0, -1.0}; + return DenseVector.createDenseVector(values); + } + private SparseVector generateVectorBSubA() { int[] indices = new int[]{0,1,4,5,8}; double[] values = new double[]{-2.0,0.0,-6.0,0.0,-10.0}; @@ -290,6 +295,9 @@ public void subtract() { assertEquals(bSubC, b.subtract(c), "B - C"); assertEquals(cSubA, c.subtract(a), "C - A"); assertEquals(cSubB, c.subtract(b), "C - B"); + + DenseVector ones = new DenseVector(10, 1.0); + assertEquals(generateVectorASubOnes(), a.subtract(ones), "A - Ones"); } public static SparseVector invert(SparseVector input) { From a818f75d91d8b72ba8748f744ef1d40dac4cddf9 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 28 Apr 2024 21:27:01 -0400 Subject: [PATCH 11/15] Fixing bugs in GMM. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 16 +++++---- .../clustering/gmm/GaussianMixtureModel.java | 7 ++-- .../org/tribuo/clustering/gmm/TestGMM.java | 36 +++++++++++-------- 3 files changed, 36 insertions(+), 23 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index f4b940345..b257fb594 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -376,13 +376,14 @@ public GaussianMixtureModel train(Dataset examples, Map examples, Map examples, Map examples, Map 1e-5) { + throw new IllegalStateException("Invalid protobuf, covariance was not spherical, diagonal elements not all equal, found " + covariance + "."); + } } covariances[i] = covariance; } else { diff --git a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java index 021d68016..2938ecd8a 100644 --- a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java +++ b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java @@ -43,10 +43,13 @@ */ public class TestGMM { - private static final GMMTrainer t = new GMMTrainer(5, 10, MultivariateNormalDistribution.CovarianceType.DIAGONAL, + private static final GMMTrainer diagonal = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.DIAGONAL, GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); - private static final GMMTrainer plusPlus = new GMMTrainer(5, 10, MultivariateNormalDistribution.CovarianceType.FULL, + private static final GMMTrainer plusPlusFull = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL, + GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); + + private static final GMMTrainer plusPlusSpherical = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.SPHERICAL, 
GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); @BeforeAll @@ -59,12 +62,17 @@ public static void setup() { @Test public void testEvaluation() { - runEvaluation(t); + runEvaluation(diagonal); + } + + @Test + public void testPlusPlusSphericalEvaluation() { + runEvaluation(plusPlusSpherical); } @Test - public void testPlusPlusEvaluation() { - runEvaluation(plusPlus); + public void testPlusPlusFullEvaluation() { + runEvaluation(plusPlusFull); } public static void runEvaluation(GMMTrainer trainer) { @@ -100,13 +108,13 @@ public static Model runDenseData(GMMTrainer trainer) { @Test public void testDenseData() { - Model model = runDenseData(t); + Model model = runDenseData(diagonal); Helpers.testModelSerialization(model,ClusterID.class); } @Test public void testPlusPlusDenseData() { - runDenseData(plusPlus); + runDenseData(plusPlusFull); } public void runSparseData(GMMTrainer trainer) { @@ -116,12 +124,12 @@ public void runSparseData(GMMTrainer trainer) { @Test public void testSparseData() { - runSparseData(t); + runSparseData(diagonal); } @Test public void testPlusPlusSparseData() { - runSparseData(plusPlus); + runSparseData(plusPlusFull); } public void runInvalidExample(GMMTrainer trainer) { @@ -134,12 +142,12 @@ public void runInvalidExample(GMMTrainer trainer) { @Test public void testInvalidExample() { - runInvalidExample(t); + runInvalidExample(diagonal); } @Test public void testPlusPlusInvalidExample() { - runInvalidExample(plusPlus); + runInvalidExample(plusPlusFull); } @@ -153,19 +161,19 @@ public void runEmptyExample(GMMTrainer trainer) { @Test public void testEmptyExample() { - runEmptyExample(t); + runEmptyExample(diagonal); } @Test public void testPlusPlusEmptyExample() { - runEmptyExample(plusPlus); + runEmptyExample(plusPlusFull); } @Test public void testPlusPlusTooManyCentroids() { assertThrows(IllegalArgumentException.class, () -> { Dataset data = ClusteringDataGenerator.gaussianClusters(3, 1L); - plusPlus.train(data); + plusPlusFull.train(data); }); } 
From ceaeb218e25c1af4df14d2687fbf82fbce30fee4 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 28 Apr 2024 21:58:03 -0400 Subject: [PATCH 12/15] Small tidy ups to Math. --- .../distributions/MultivariateNormalDistribution.java | 9 ++++++++- Math/src/main/java/org/tribuo/math/la/DenseVector.java | 2 +- Math/src/main/java/org/tribuo/math/la/SparseVector.java | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index 89878751f..b05e8c892 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -361,6 +361,13 @@ public int value() { return value; } + /** + * Convert enum value into enum instance, used for serialization. + *

+ * Throws {@link IllegalArgumentException} if the enum value is out of range. + * @param value The enum value. + * @return The enum type. + */ public static CovarianceType fromValue(int value) { CovarianceType[] values = CovarianceType.values(); for (CovarianceType t : values) { @@ -369,7 +376,7 @@ public static CovarianceType fromValue(int value) { } } // Failed to find the enum. - throw new IllegalStateException("Invalid CovarianceType enum value, found " + value); + throw new IllegalArgumentException("Invalid CovarianceType enum value, found " + value); } } } diff --git a/Math/src/main/java/org/tribuo/math/la/DenseVector.java b/Math/src/main/java/org/tribuo/math/la/DenseVector.java index f9d9a05fd..2e473d74c 100644 --- a/Math/src/main/java/org/tribuo/math/la/DenseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/DenseVector.java @@ -580,7 +580,7 @@ public void set(int index, double value) { * @param value The value to set things to. */ public void set(double value) { - Arrays.fill(elements, 0); + Arrays.fill(elements, value); } /** diff --git a/Math/src/main/java/org/tribuo/math/la/SparseVector.java b/Math/src/main/java/org/tribuo/math/la/SparseVector.java index febc63cb2..6b6b46fce 100644 --- a/Math/src/main/java/org/tribuo/math/la/SparseVector.java +++ b/Math/src/main/java/org/tribuo/math/la/SparseVector.java @@ -469,7 +469,7 @@ public SGDVector subtract(SGDVector other) { DenseVector output = ((DenseVector)other).copy(); output.scaleInPlace(-1.0); for (VectorTuple tuple : this) { - output.set(tuple.index,tuple.value+output.get(tuple.index)); + output.set(tuple.index, tuple.value + output.get(tuple.index)); } return output; } else if (other instanceof SparseVector) { From 3f4614b6c8d1b25652a86bbc9460d29b74744e09 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 28 Apr 2024 22:19:31 -0400 Subject: [PATCH 13/15] Fixing diagonal and spherical covariance estimation. 
--- .../org/tribuo/clustering/gmm/GMMTrainer.java | 35 ++++++------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index b257fb594..9571b4348 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -375,18 +375,10 @@ public GaussianMixtureModel train(Dataset examples, Map a * a); + diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j)); + curCov.intersectAndAddInPlace(diff); } return input; }; @@ -394,19 +386,12 @@ public GaussianMixtureModel train(Dataset examples, Map a * a); + diff.scaleInPlace(v.responsibility.get(j) / newMixingDistribution.get(j)); + double mean = diff.sum() / numFeatures; + diff.set(mean); + curCov.intersectAndAddInPlace(diff); } return input; }; From 855592672b6dcb39333b2b11c3d95c975cde4a29 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 1 May 2024 20:28:42 -0400 Subject: [PATCH 14/15] Adding a mixture distribution and a distribution interface. 
--- Core/src/main/java/org/tribuo/util/Util.java | 19 +++ .../math/distributions/Distribution.java | 51 ++++++++ .../distributions/MixtureDistribution.java | 113 ++++++++++++++++++ .../MultivariateNormalDistribution.java | 12 +- 4 files changed, 186 insertions(+), 9 deletions(-) create mode 100644 Math/src/main/java/org/tribuo/math/distributions/Distribution.java create mode 100644 Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java diff --git a/Core/src/main/java/org/tribuo/util/Util.java b/Core/src/main/java/org/tribuo/util/Util.java index 2db850193..06db14fd3 100644 --- a/Core/src/main/java/org/tribuo/util/Util.java +++ b/Core/src/main/java/org/tribuo/util/Util.java @@ -429,6 +429,25 @@ public static double[] generateCDF(double[] pmf) { return cumulativeSum(pmf); } + /** + * Validates that the supplied double array is a probability mass function. + *

+ * That is, each element is bounded 0,1 and all elements sum to 1. + * @param pmf The PMF to check. + * @return True if it's a valid pmf. + */ + public static boolean validatePMF(double[] pmf) { + double total = 0.0; + for (double v : pmf) { + if ((v < 0) || (v > 1.0)) { + return false; + } else { + total += v; + } + } + return !(Math.abs(total - 1.0) > 1e-10); + } + /** * Produces a cumulative sum array. * @param input The input to sum. diff --git a/Math/src/main/java/org/tribuo/math/distributions/Distribution.java b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java new file mode 100644 index 000000000..944e0d5f4 --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/Distribution.java @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.math.distributions; + +import org.tribuo.math.la.DenseVector; + +import java.util.random.RandomGenerator; + +/** + * Interface for probability distributions which can be sampled from. + *

+ * The vector sampled represents a single sample from that (possibly multivariate) + * distribution rather than a sequence of samples. + */ +public interface Distribution { + + /** + * Sample a single vector from this probability distribution. + * @return A vector sampled from the distribution. + */ + DenseVector sampleVector(); + + /** + * Sample a single vector from this probability distribution using the supplied RNG. + * @param otherRNG The RNG to use. + * @return A vector sampled from this distribution. + */ + DenseVector sampleVector(RandomGenerator otherRNG); + + /** + * Sample a vector from this probability distribution and return it as an array. + * @return An array sampled from this distribution. + */ + default double[] sampleArray() { + return sampleVector().toArray(); + } +} diff --git a/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java new file mode 100644 index 000000000..7e1f7f70a --- /dev/null +++ b/Math/src/main/java/org/tribuo/math/distributions/MixtureDistribution.java @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.tribuo.math.distributions; + +import org.tribuo.math.la.DenseVector; +import org.tribuo.util.Util; + +import java.util.Arrays; +import java.util.List; +import java.util.SplittableRandom; +import java.util.random.RandomGenerator; + +/** + * A mixture distribution which samples from a set of internal distributions mixed by some probability distribution. + * @param The inner distribution type. + */ +public final class MixtureDistribution implements Distribution { + + private final List dists; + + private final double[] mixingDistribution; + + private final double[] cdf; + + private final RandomGenerator rng; + + private final long seed; + + /** + * Construct a mixture distribution over the supplied components. + * @param distributions The distribution components. + * @param mixingDistribution The mixing distribution, must be a valid PMF. + * @param seed The RNG seed. + */ + public MixtureDistribution(List distributions, DenseVector mixingDistribution, long seed) { + this(distributions, mixingDistribution.toArray(), seed); + } + + /** + * Construct a mixture distribution over the supplied components. + * @param distributions The distribution components. + * @param mixingDistribution The mixing distribution, must be a valid PMF. + * @param seed The RNG seed. 
+ */ + public MixtureDistribution(List distributions, double[] mixingDistribution, long seed) { + this.dists = List.copyOf(distributions); + this.mixingDistribution = Arrays.copyOf(mixingDistribution, mixingDistribution.length); + this.seed = seed; + this.rng = new SplittableRandom(seed); + if (dists.size() != this.mixingDistribution.length) { + throw new IllegalArgumentException("Invalid distribution, expected the same number of components as probabilities, found " + dists.size() + " components, and " + this.mixingDistribution.length + " probabilities"); + } + if (!Util.validatePMF(this.mixingDistribution)) { + throw new IllegalArgumentException("Invalid mixing distribution, was not a valid PMF, found " + Arrays.toString(this.mixingDistribution)); + } + this.cdf = Util.generateCDF(this.mixingDistribution); + } + + /** + * Returns the number of distributions. + * @return The number of distributions. + */ + public int getNumComponents() { + return dists.size(); + } + + /** + * Return a mixture component. + * @param i The index of the mixture component. + * @return The ith component. + */ + public T getComponent(int i) { + return dists.get(i); + } + + /** + * Returns a copy of the mixing distribution. + * @return A copy of the mixing distribution. 
+ */ + public double[] getMixingDistribution() { + return Arrays.copyOf(mixingDistribution, mixingDistribution.length); + } + + @Override + public DenseVector sampleVector() { + return sampleVector(rng); + } + + @Override + public DenseVector sampleVector(RandomGenerator otherRNG) { + int idx = Util.sampleFromCDF(cdf, otherRNG); + return dists.get(idx).sampleVector(); + } + + @Override + public String toString() { + return "Mixture(seed="+seed+",mixingDistribution="+ Arrays.toString(mixingDistribution) +",components="+dists+")"; + } +} diff --git a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java index b05e8c892..bd4e87fd3 100644 --- a/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java +++ b/Math/src/main/java/org/tribuo/math/distributions/MultivariateNormalDistribution.java @@ -30,7 +30,7 @@ /** * A class for sampling from multivariate normal distributions. */ -public final class MultivariateNormalDistribution { +public final class MultivariateNormalDistribution implements Distribution { private final long seed; private final Random rng; @@ -231,6 +231,7 @@ public MultivariateNormalDistribution(DenseVector means, Tensor covariance, Cova * Sample a vector from this multivariate normal distribution. * @return A sample from this distribution. */ + @Override public DenseVector sampleVector() { return sampleVector(rng); } @@ -239,6 +240,7 @@ public DenseVector sampleVector() { * Sample a vector from this multivariate normal distribution. * @return A sample from this distribution. */ + @Override public DenseVector sampleVector(RandomGenerator otherRNG) { DenseVector sampled = new DenseVector(means.size()); for (int i = 0; i < means.size(); i++) { @@ -256,14 +258,6 @@ public DenseVector sampleVector(RandomGenerator otherRNG) { return sampled; } - /** - * Sample a vector from this multivariate normal distribution. 
- * @return A sample from this distribution. - */ - public double[] sampleArray() { - return sampleVector().toArray(); - } - /** * Gets a copy of the mean vector. * @return A copy of the mean vector. From 356a355ed174dcb0cd9905923b3f94f141e25fe5 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 19 May 2024 20:07:52 -0400 Subject: [PATCH 15/15] Fixing parallel reduction by converting it into collect. --- .../org/tribuo/clustering/gmm/GMMTrainer.java | 36 ++++++++----------- .../org/tribuo/clustering/gmm/TestGMM.java | 11 ++++-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java index 9571b4348..abcc6fafe 100644 --- a/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java +++ b/Clustering/GMM/src/main/java/org/tribuo/clustering/gmm/GMMTrainer.java @@ -51,9 +51,11 @@ import java.util.SplittableRandom; import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; +import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.BinaryOperator; import java.util.function.Function; +import java.util.function.Supplier; import java.util.function.ToDoubleFunction; import java.util.logging.Level; import java.util.logging.Logger; @@ -342,24 +344,24 @@ public GaussianMixtureModel train(Dataset examples, Map dataMStream = Arrays.stream(data); Stream resMStream = Arrays.stream(responsibilities); Stream zipMStream = StreamUtil.zip(dataMStream, resMStream, Vectors::new); - Tensor[] zeroTensorArr = switch (covarianceType) { - case FULL -> { + Supplier zeroTensor = switch (covarianceType) { + case FULL -> () -> { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { output[j] = new DenseMatrix(numFeatures, numFeatures); } - yield output; - } - case DIAGONAL, SPHERICAL -> { + return output; + }; + case DIAGONAL, 
SPHERICAL -> () -> { Tensor[] output = new Tensor[numGaussians]; for (int j = 0; j < numGaussians; j++) { output[j] = new DenseVector(numFeatures); } - yield output; - } + return output; + }; }; // Fix parallel behaviour - BiFunction mStep = switch (covarianceType) { + BiConsumer mStep = switch (covarianceType) { case FULL -> (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { // Compute covariance contribution from current input @@ -369,7 +371,6 @@ public GaussianMixtureModel train(Dataset examples, Map (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { @@ -380,7 +381,6 @@ public GaussianMixtureModel train(Dataset examples, Map (Tensor[] input, Vectors v) -> { for (int j = 0; j < numGaussians; j++) { @@ -393,33 +393,27 @@ public GaussianMixtureModel train(Dataset examples, Map combineTensor = (Tensor[] a, Tensor[] b) -> { - Tensor[] output = new Tensor[a.length]; + BiConsumer combineTensor = (Tensor[] a, Tensor[] b) -> { for (int j = 0; j < a.length; j++) { if (a[j] instanceof DenseMatrix aMat && b[j] instanceof DenseMatrix bMat) { - output[j] = aMat.add(bMat); + aMat.intersectAndAddInPlace(bMat); } else if (a[j] instanceof DenseVector aVec && b[j] instanceof DenseVector bVec) { - output[j] = aVec.add(bVec); + aVec.intersectAndAddInPlace(bVec); } else { throw new IllegalStateException("Invalid types in reduce, expected both DenseMatrix or DenseVector, found " + a[j].getClass() + " and " + b[j].getClass()); } } - return output; }; if (parallel) { - throw new RuntimeException("Parallel mstep not implemented"); - /* try { - covariances = fjp.submit(() -> zipMStream.parallel().reduce(zeroTensorArr, mStep, combineTensor)).get(); + covariances = fjp.submit(() -> zipMStream.parallel().collect(zeroTensor, mStep, combineTensor)).get(); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException("Parallel execution failed", e); } - */ } else { - covariances = zipMStream.reduce(zeroTensorArr, mStep, 
combineTensor); + covariances = zipMStream.collect(zeroTensor, mStep, combineTensor); } // renormalize mixing distribution diff --git a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java index 2938ecd8a..a2fc3c13c 100644 --- a/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java +++ b/Clustering/GMM/src/test/java/org/tribuo/clustering/gmm/TestGMM.java @@ -46,6 +46,9 @@ public class TestGMM { private static final GMMTrainer diagonal = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.DIAGONAL, GMMTrainer.Initialisation.RANDOM, 1e-3, 1, 1); + private static final GMMTrainer fullParallel = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL, + GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 4, 1); + private static final GMMTrainer plusPlusFull = new GMMTrainer(5, 50, MultivariateNormalDistribution.CovarianceType.FULL, GMMTrainer.Initialisation.PLUSPLUS, 1e-3, 1, 1); @@ -75,6 +78,11 @@ public void testPlusPlusFullEvaluation() { runEvaluation(plusPlusFull); } + @Test + public void testParallelEvaluation() { + runEvaluation(fullParallel); + } + public static void runEvaluation(GMMTrainer trainer) { Dataset data = new MutableDataset<>(new GaussianClusterDataSource(500, 1L)); Dataset test = ClusteringDataGenerator.gaussianClusters(500, 2L); @@ -150,7 +158,6 @@ public void testPlusPlusInvalidExample() { runInvalidExample(plusPlusFull); } - public void runEmptyExample(GMMTrainer trainer) { assertThrows(IllegalArgumentException.class, () -> { Pair, Dataset> p = ClusteringDataGenerator.denseTrainTest(); @@ -186,7 +193,7 @@ public void testSetInvocationCount() { // The number of times to call train before final training. 
// Original trainer will be trained numOfInvocations + 1 times - // New trainer will have it's invocation count set to numOfInvocations then trained once + // New trainer will have its invocation count set to numOfInvocations then trained once int numOfInvocations = 2; // Create the first model and train it numOfInvocations + 1 times