Documentation updates for 4.2 (#205)
* Updating docs for 4.2.

* Migrating StripProvenance over to use the ProvenanceSerialization interface.

* Updating the configuration tutorial.

* Initial readme updates for 4.2

* Adding the start of the onnx export tutorial.

* Adding first draft of release notes for 4.2.

* Fixing the JEP 290 filter.

* Changing Model.castModel so it's not static.

* Adding a reproducibility tutorial and updating the irises and onnx export tutorials.

* Updating gitignore file.

* Updating reproducibility tutorial with a bigger diff example.

* Updating the v4.2 release notes.

* Adding TF-Java PR number.

* Updating docs for the HDBSCAN implementation.

* Updating 4.2 release notes.

* Adding Tribuo v4.1.1 release notes.

* Finishing ONNX export tutorial.

* Updating ONNX export tutorial.

* Docs updates after rebase.

* CastModel fix in OCIModelCLI.

* Javadoc updates after rebase.

* Fixing the circular PR reference in the 4.2 release notes.

* Fixing an accidental deletion in the external models tutorial.
Craigacp authored Dec 18, 2021
1 parent de2e5d4 commit 8058657
Showing 34 changed files with 3,313 additions and 129 deletions.
20 changes: 18 additions & 2 deletions .gitignore
@@ -16,16 +16,32 @@ bin/
.*.swp

# Other files
*.jar
*.class
*.er
*.log
*.bck
*.so
*.patch

# Binaries
*.jar
*.class

# Archives
*.gz
*.zip

# Serialised models
*.ser

# Temporary stuff
junk/*
.DS_Store
.ipynb_checkpoints

# Profiling files
*.jfr
*.iprof
*.jfc

# Tutorial files
tutorials/*.svm
@@ -131,6 +131,7 @@ protected LibSVMTrainer() {}
/**
* Constructs a LibSVMTrainer from the parameters.
* @param parameters The SVM parameters.
* @param seed The RNG seed.
*/
protected LibSVMTrainer(SVMParameters<T> parameters, long seed) {
this.parameters = parameters.getParameters();
15 changes: 8 additions & 7 deletions Core/src/main/java/org/tribuo/Model.java
@@ -300,20 +300,21 @@ public String toString() {

/**
* Casts the model to the specified output type, assuming it is valid.
* <p>
* If it's not valid, throws {@link ClassCastException}.
* @param inputModel The model to cast.
* <p>
* This method is intended for use on a deserialized model to restore its
* generic type in a safe way.
* @param outputType The output type to cast to.
* @param <T> The output type.
* @param <U> The output type.
* @return The model cast to the correct value.
*/
public static <T extends Output<T>> Model<T> castModel(Model<?> inputModel, Class<T> outputType) {
if (inputModel.validate(outputType)) {
public <U extends Output<U>> Model<U> castModel(Class<U> outputType) {
if (validate(outputType)) {
@SuppressWarnings("unchecked") // guarded by validate
Model<T> castedModel = (Model<T>) inputModel;
Model<U> castedModel = (Model<U>) this;
return castedModel;
} else {
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString());
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + this.toString());
}
}
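For readers upgrading, here is a minimal sketch of calling the new instance-method form on a deserialized model; the file path, class name, and the choice of `Label` as the output type are illustrative assumptions, not part of this commit.

```java
import java.io.ObjectInputStream;
import java.nio.file.Files;
import java.nio.file.Paths;

import org.tribuo.Model;
import org.tribuo.classification.Label;

public final class CastModelExample {
    public static void main(String[] args) throws Exception {
        // Deserialize a model; "model.ser" is a placeholder path.
        try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(Paths.get("model.ser")))) {
            Model<?> loaded = (Model<?>) ois.readObject();
            // Restore the generic type via the new instance method; this throws
            // ClassCastException if the model's output type is not Label.
            Model<Label> labelModel = loaded.castModel(Label.class);
            System.out.println(labelModel);
        }
    }
}
```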

2 changes: 2 additions & 0 deletions Core/src/main/java/org/tribuo/ensemble/BaggingTrainer.java
@@ -49,6 +49,7 @@
* "The Elements of Statistical Learning"
* Springer 2001. <a href="http://web.stanford.edu/~hastie/ElemStatLearn/">PDF</a>
* </pre>
* @param <T> The prediction type.
*/
public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {

@@ -177,6 +178,7 @@ public EnsembleModel<T> train(Dataset<T> examples, Map<String, Provenance> runPr
* @param labelIDs The output domain.
* @param randInt A random int from an RNG instance.
* @param runProvenance Provenance for this instance.
* @param invocationCount The invocation count for the inner trainer.
* @return The trained ensemble member.
*/
protected Model<T> trainSingleModel(Dataset<T> examples, ImmutableFeatureMap featureIDs, ImmutableOutputInfo<T> labelIDs, int randInt, Map<String,Provenance> runProvenance, int invocationCount) {
3 changes: 2 additions & 1 deletion Core/src/main/java/org/tribuo/ensemble/EnsembleCombiner.java
@@ -79,9 +79,10 @@ default ONNXNode exportCombiner(ONNXNode input) {
* will be required to provide ONNX support.
* @param input the node to be ensembled according to this implementation.
* @param weight The node of weights for ensembling.
* @param <U> The type of the weights input reference.
* @return The leaf node of the graph of operations added to ensemble input.
*/
default <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
default <U extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, U weight) {
Logger.getLogger(this.getClass().getName()).severe("Tried to export an ensemble combiner to ONNX format, but this is not implemented.");
throw new IllegalStateException("This ensemble cannot be exported as the combiner '" + this.getClass() + "' uses the default implementation of EnsembleCombiner.exportCombiner.");
}
4 changes: 2 additions & 2 deletions Data/src/main/java/org/tribuo/data/sql/SQLDBConfig.java
@@ -72,7 +72,7 @@ private SQLDBConfig() {}
/**
* Constructs a SQL database configuration.
* <p>
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
* @param connectionString The connection string.
* @param username The username.
* @param password The password.
@@ -87,7 +87,7 @@ public SQLDBConfig(String connectionString, String username, String password, Ma
/**
* Constructs a SQL database configuration.
* <p>
* Note it is recommended that wallet based connections are used rather than this constructor using {@link SQLDBConfig(String,Map)}.
* Note it is recommended that wallet based connections are used rather than this constructor using {@link #SQLDBConfig(String,Map)}.
* @param host The host to connect to.
* @param port The port to connect on.
* @param db The db name.
@@ -309,6 +309,7 @@ public static ConfigFileAuthenticationDetailsProvider makeAuthProvider(Path conf
* @param configFile The OCI configuration file, if null use the default file.
* @param endpointURL The endpoint URL.
* @param outputConverter The converter for the specified output type.
* @param <T> The output type.
* @return An OCIModel ready to score new inputs.
*/
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,
@@ -332,6 +333,7 @@ public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T>
* @param profileName The profile name in the OCI configuration file, if null uses the default profile.
* @param endpointURL The endpoint URL.
* @param outputConverter The converter for the specified output type.
* @param <T> The output type.
* @return An OCIModel ready to score new inputs.
*/
public static <T extends Output<T>> OCIModel<T> createOCIModel(OutputFactory<T> factory,
@@ -64,7 +64,7 @@ private static void createModelAndDeploy(OCIModelOptions options) throws IOExcep
// Load the Tribuo model
Model<Label> model;
try (ObjectInputStream ois = new ObjectInputStream(Files.newInputStream(options.modelPath))) {
model = Model.castModel((Model<?>) ois.readObject(),Label.class);
model = ((Model<?>)ois.readObject()).castModel(Label.class);
}
if (!(model instanceof ONNXExportable)) {
throw new IllegalArgumentException("Model not ONNXExportable, received " + model.toString());
@@ -373,6 +373,7 @@ public static <T extends Output<T>, U extends Model<T> & ONNXExportable> String
/**
* Creates the OCI DS model artifact zip file.
* @param onnxFile The ONNX file to create.
* @param config The model artifact configuration.
* @return The path referring to the zip file.
* @throws IOException If the file could not be created or the ONNX file could not be read.
*/
15 changes: 4 additions & 11 deletions Json/src/main/java/org/tribuo/json/StripProvenance.java
@@ -16,18 +16,15 @@

package org.tribuo.json;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.oracle.labs.mlrg.olcut.config.ConfigurationManager;
import com.oracle.labs.mlrg.olcut.config.Option;
import com.oracle.labs.mlrg.olcut.config.Options;
import com.oracle.labs.mlrg.olcut.config.UsageException;
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceModule;
import com.oracle.labs.mlrg.olcut.config.json.JsonProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance;
import com.oracle.labs.mlrg.olcut.util.IOUtil;
import com.oracle.labs.mlrg.olcut.util.LabsLogFormatter;
@@ -315,11 +312,8 @@ public static <T extends Output<T>> void main(String[] args) {
ModelProvenance oldProvenance = input.getProvenance();

logger.info("Marshalling provenance and creating JSON.");
List<ObjectMarshalledProvenance> list = ProvenanceUtil.marshalProvenance(oldProvenance);
ObjectMapper mapper = new ObjectMapper();
mapper.registerModule(new JsonProvenanceModule());
mapper.enable(SerializationFeature.INDENT_OUTPUT);
String jsonResult = mapper.writeValueAsString(list);
JsonProvenanceSerialization jsonProvenanceSerialization = new JsonProvenanceSerialization(true);
String jsonResult = jsonProvenanceSerialization.marshalAndSerialize(oldProvenance);

logger.info("Hashing JSON file");
MessageDigest digest = o.hashType.getDigest();
@@ -340,8 +334,7 @@

ModelProvenance newProvenance = tuple.provenance;
logger.info("Marshalling provenance and creating JSON.");
List<ObjectMarshalledProvenance> newList = ProvenanceUtil.marshalProvenance(newProvenance);
String newJsonResult = mapper.writeValueAsString(newList);
String newJsonResult = jsonProvenanceSerialization.marshalAndSerialize(newProvenance);

logger.info("Old provenance = \n" + jsonResult);
logger.info("New provenance = \n" + newJsonResult);
25 changes: 17 additions & 8 deletions README.md
@@ -36,7 +36,9 @@ architectures on Windows 10, macOS and Linux (RHEL/OL/CentOS 7+), as these are
supported platforms for the native libraries with which we interface. If you're
interested in another platform and wish to use one of the native library
interfaces (ONNX Runtime, TensorFlow, and XGBoost), we recommend reaching out
to the developers of those libraries.
to the developers of those libraries. Note the reproducibility package
requires Java 17, and as such is not part of the `tribuo-all` Maven Central
deployment.

## Documentation

@@ -85,6 +87,7 @@ Tribuo has implementations or interfaces for:
|Algorithm|Implementation|Notes|
|---|---|---|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|
|CART|Tribuo||
|SVM-SGD|Tribuo|An implementation of the Pegasos algorithm|
|Adaboost.SAMME|Tribuo|Can use any Tribuo classification trainer as the base learner|
@@ -109,6 +112,7 @@ output.
|Algorithm|Implementation|Notes|
|---|---|---|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|
|CART|Tribuo||
|Lasso|Tribuo|Using the LARS algorithm|
|Elastic Net|Tribuo|Using the co-ordinate descent algorithm|
@@ -124,6 +128,7 @@ algorithms over time.

|Algorithm|Implementation|Notes|
|---|---|---|
|HDBSCAN\*|Tribuo||
|K-Means|Tribuo|Includes both sequential and parallel backends, and the K-Means++ initialisation algorithm|

### Anomaly Detection
@@ -146,7 +151,9 @@ more multi-label specific implementations over time.
|Algorithm|Implementation|Notes|
|---|---|---|
|Independent wrapper|Tribuo|Converts a multi-class classification algorithm into a multi-label one by producing a separate classifier for each label|
|Classifier Chains|Tribuo|Provides classifier chains and randomized classifier chain ensembles using any of Tribuo's multi-class classification algorithms|
|Linear models|Tribuo|Uses SGD and allows any gradient optimizer|
|Factorization Machines|Tribuo|Uses SGD and allows any gradient optimizer|

### Interfaces

@@ -158,10 +165,10 @@ discuss how it would fit into Tribuo.
Currently we have interfaces to:

* [LibLinear](https://github.com/bwaldvogel/liblinear-java) - via the LibLinear-java port of the original [LibLinear](https://www.csie.ntu.edu.tw/~cjlin/liblinear/) (v2.43).
* [LibSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) - using the pure Java transformed version of the C++ implementation (v3.24).
* [ONNX Runtime](https://onnxruntime.ai) - via the Java API contributed by our group (v1.7.0).
* [TensorFlow](https://tensorflow.org) - Using [TensorFlow Java](https://github.com/tensorflow/java) v0.3.1 (based on TensorFlow v2.4.1). This allows the training and deployment of TensorFlow models entirely in Java.
* [XGBoost](https://xgboost.ai) - via the built in XGBoost4J API (v1.4.1).
* [LibSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) - using the pure Java transformed version of the C++ implementation (v3.25).
* [ONNX Runtime](https://onnxruntime.ai) - via the Java API contributed by our group (v1.9.0).
* [TensorFlow](https://tensorflow.org) - Using [TensorFlow Java](https://github.com/tensorflow/java) v0.4.0 (based on TensorFlow v2.7.0). This allows the training and deployment of TensorFlow models entirely in Java.
* [XGBoost](https://xgboost.ai) - via the built in XGBoost4J API (v1.5.0).

## Binaries

@@ -187,7 +194,7 @@ implementation ("org.tribuo:tribuo-all:4.1.0@pom") {
```

The `tribuo-all` dependency is a pom which depends on all the Tribuo
subprojects.
subprojects except for the reproducibility project which requires Java 17.

Most of Tribuo is pure Java and thus cross-platform, however some of the
interfaces link to libraries which use native code. Those interfaces
@@ -197,11 +204,13 @@ are supplied. If you need support for a specific platform, reach out to the
maintainers of those projects. As of the 4.1 release these native packages
all provide x86\_64 binaries for Windows, macOS and Linux. It is also possible
to compile each package for macOS ARM64 (i.e., Apple Silicon), though there are
no binaries available on Maven Central for that platform.
no binaries available on Maven Central for that platform. When developing
on an ARM platform you can select the `arm` profile in Tribuo's pom.xml to
disable the native library tests.

Individual jars are published for each Tribuo module. It is preferable to
depend only on the modules necessary for the specific project. This prevents
your code from unnecessarily pulling in large dependencies like TensorFlow
your code from unnecessarily pulling in large dependencies like TensorFlow.

## Compiling from source

@@ -56,7 +56,7 @@ public ONNXContext() {
* ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for
* constructing ONNXNodes which all other methods on ONNXContext and {@code ONNXRef} call.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputs A list of names that the output nodes of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
@@ -82,7 +82,7 @@ public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op,
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
@@ -102,7 +102,7 @@ public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> input
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param <T> The ONNXRef type of inputs
* @return An {@link ONNXNode} that is the output nodes of {@code op}.
25 changes: 21 additions & 4 deletions Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java
@@ -36,27 +36,44 @@
* can thus be passed around without needing to pass their governing context as well.
* <p>
* N.B. This class will be sealed once the library is updated past Java 8. Users should not subclass this class.
* @param <T>
* @param <T> The protobuf type this reference generates.
*/
public abstract class ONNXRef<T extends GeneratedMessageV3> {
// Unfortunately there is no other shared supertype for OnnxML protobufs
protected final T backRef;
private final String baseName;
protected final ONNXContext context;


/**
* Creates an ONNXRef for the specified context, protobuf and name.
* @param context The ONNXContext we're operating in.
* @param backRef The protobuf reference.
* @param baseName The name of this reference.
*/
ONNXRef(ONNXContext context, T backRef, String baseName) {
this.context = context;
this.backRef = backRef;
this.baseName = baseName;
}

/**
* Gets the output name of this object.
* @return The output name.
*/
public abstract String getReference();

/**
* The name of this object.
* @return The name.
*/
public String getBaseName() {
return baseName;
}

/**
* The context this reference operates in.
* @return The context.
*/
public ONNXContext onnxContext() {
return context;
}
@@ -66,7 +83,7 @@ public ONNXContext onnxContext() {
* as the first argument to {@code inputs}, with {@code otherInputs} appended as subsequent arguments. The other
* arguments behave as in the analogous method on ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param otherInputs A list of {@ONNXRef}s created by this instance of ONNXContext.
* @param otherInputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputs A list of names that the output nodes of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @return a list of {@link ONNXNode}s that are the output nodes of {@code op}.
@@ -199,7 +216,7 @@ public <Ret extends ONNXRef<?>> Ret assignTo(Ret output) {
/**
* Casts this ONNXRef to a different type using the {@link ONNXOperators#CAST} operation, and returning the output
* node of that op. Currently supports only float, double, int, and long, which are specified by their respective
* {@link Class} objects (eg. {@link float.class}). Throws {@link IllegalArgumentException} when an unsupported cast
* {@link Class} objects (e.g., {@code float.class}). Throws {@link IllegalArgumentException} when an unsupported cast
* is requested.
* @param clazz The class object specifying the type to cast to.
* @return An ONNXRef representing this object cast into the requested type.
@@ -15,7 +15,7 @@
*/

/**
* Interfaces and utilities for writing <a href="https://onnx.ai>ONNX</a> models from Java.
* Interfaces and utilities for writing <a href="https://onnx.ai">ONNX</a> models from Java.
* <p>
* Developed to support <a href="https://tribuo.org">Tribuo</a>, but can be used to export
* other machine learning models from JVM languages.