Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gaussian Mixture Model implementation #369

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -178,36 +178,7 @@ public static ChainViterbiResults viterbi(ChainCliqueValues scores) {
* @return log sum exp input[i].
*/
public static double sumLogProbs(DenseVector input) {
double LOG_TOLERANCE = 30.0;

double maxValue = input.get(0);
int maxIdx = 0;
for (int i = 1; i < input.size(); i++) {
double value = input.get(i);
if (value > maxValue) {
maxValue = value;
maxIdx = i;
}
}
if (maxValue == Double.NEGATIVE_INFINITY) {
return maxValue;
}

boolean anyAdded = false;
double intermediate = 0.0;
double cutoff = maxValue - LOG_TOLERANCE;
for (int i = 0; i < input.size(); i++) {
double value = input.get(i);
if (value >= cutoff && i != maxIdx && !Double.isInfinite(value)) {
anyAdded = true;
intermediate += Math.exp(value - maxValue);
}
}
if (anyAdded) {
return maxValue + Math.log1p(intermediate);
} else {
return maxValue;
}
return input.logSumExp();
}

/**
Expand Down
98 changes: 98 additions & 0 deletions Clustering/GMM/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-clustering</artifactId>
<version>5.0.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<name>Clustering-GMM</name>
<artifactId>tribuo-clustering-gmm</artifactId>
<packaging>jar</packaging>
<properties>
<maven.compiler.release>17</maven.compiler.release>
</properties>

<dependencies>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-core</artifactId>
</dependency>
<dependency>
<groupId>org.tribuo</groupId>
<artifactId>tribuo-data</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-math</artifactId>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-clustering-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-core</artifactId>
</dependency>
<!-- test time dependencies -->
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-core</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-assembly-plugin</artifactId>
<configuration>
<descriptorRefs>
<descriptorRef>jar-with-dependencies</descriptorRef>
</descriptorRefs>
</configuration>
<executions>
<execution>
<id>make-assembly</id> <!-- this is used for inheritance merges -->
<phase>package</phase> <!-- bind to the packaging phase -->
<goals>
<goal>single</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.clustering.gmm;

import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import org.tribuo.Trainer;
import org.tribuo.clustering.gmm.GMMTrainer.Initialisation;
import org.tribuo.math.distributions.MultivariateNormalDistribution.CovarianceType;
import org.tribuo.math.distributions.MultivariateNormalDistribution;

import java.util.logging.Logger;

/**
* OLCUT {@link Options} for the GMM implementation.
*/
public class GMMOptions implements Options {
private static final Logger logger = Logger.getLogger(GMMOptions.class.getName());

/**
* Iterations of the GMM algorithm. Defaults to 10.
*/
@Option(longName = "gmm-interations", usage = "Iterations of the GMM algorithm. Defaults to 10.")
public int iterations = 10;
/**
* Number of centroids/Gaussians in GMM. Defaults to 10.
*/
@Option(longName = "gmm-num-centroids", usage = "Number of centroids in GMM. Defaults to 10.")
public int centroids = 10;
/**
* The covariance type of the Gaussians.
*/
@Option(charName = 'v', longName = "covariance-type", usage = "Set the covariance type.")
public CovarianceType covarianceType = MultivariateNormalDistribution.CovarianceType.DIAGONAL;
/**
* Initialisation function in GMM. Defaults to RANDOM.
*/
@Option(longName = "gmm-initialisation", usage = "Initialisation function in GMM. Defaults to RANDOM.")
public Initialisation initialisation = GMMTrainer.Initialisation.RANDOM;
/**
* Convergence tolerance to terminate EM early.
*/
@Option(longName = "gmm-tolerance", usage = "The convergence threshold.")
public double tolerance = 1e-3f;
/**
* Number of computation threads in GMM. Defaults to 4.
*/
@Option(longName = "gmm-num-threads", usage = "Number of computation threads in GMM. Defaults to 4.")
public int numThreads = 4;
/**
* The RNG seed.
*/
@Option(longName = "gmm-seed", usage = "Sets the random seed for GMM.")
public long seed = Trainer.DEFAULT_SEED;

/**
* Gets the configured GMMTrainer using the options in this object.
* @return A GMMTrainer.
*/
public GMMTrainer getTrainer() {
logger.info("Configuring GMM Trainer");
return new GMMTrainer(centroids, iterations, covarianceType, initialisation, tolerance, numThreads, seed);
}
}
Loading