Skip to content

Commit

Permalink
Make the ONNX export code accept an op interface rather than the enum…
Browse files Browse the repository at this point in the history
… itself (#245)

* Adding an op building interface which ONNXOperators implements, to allow users to supply custom ops or different opsets.

* Fixing comments.

* Tidying up the javadoc.

* Tidying up the error messages when building an ONNX op.
  • Loading branch information
Craigacp authored Jul 25, 2022
1 parent 30f113e commit 4c12343
Show file tree
Hide file tree
Showing 5 changed files with 305 additions and 170 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ public ONNXNode writeONNXGraph(ONNXRef<?> input) {
// Make feature pow
ONNXNode inputSquared = input.apply(ONNXOperators.POW, twoConst);


List<ONNXNode> embeddingOutputs = new ArrayList<>();
for(int i = 0; i < outputIDInfo.size(); i++) {

Expand Down
25 changes: 13 additions & 12 deletions Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
* ONNXContext or of {@link ONNXRef}s produced by multiple ONNXContexts is not supported.
* <p>
* The ONNXContext has all of the logic needed to produce ONNX graphs, but is typically used explicitly to produce leaf
* nodes of graphs (inputs, outputs, and weight matrices) that have more fluent interfaces to {@link ONNXContext#operation(ONNXOperators, List, List, Map)}.
* nodes of graphs (inputs, outputs, and weight matrices) that have more fluent interfaces to {@link ONNXContext#operation(ONNXOperator, List, List, Map)}.
* Produced ONNX protobuf objects are encapsulated by instances of {@link ONNXRef} and its subclasses.
*/
public final class ONNXContext {
Expand All @@ -51,18 +51,18 @@ public ONNXContext() {
}

/**
* Base method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns an instance of ONNXNode
* Base method for creating {@link ONNXNode}s from {@link ONNXOperator} and inputs. Returns an instance of ONNXNode
* for each output of the ONNXOperator. The graph elements created by the operation are added to the calling
* ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for
* constructing ONNXNodes which all other methods on ONNXContext and {@code ONNXRef} call.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputs A list of names that the output nodes of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
* @return a list of {@link ONNXNode}s that are the output nodes of {@code op}.
*/
public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op,
public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperator op,
List<T> inputs,
List<String> outputs,
Map<String, Object> attributes) {
Expand All @@ -78,36 +78,37 @@ public <T extends ONNXRef<?>> List<ONNXNode> operation(ONNXOperators op,
}

/**
* Method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns a single ONNXNode and throws
* Method for creating {@link ONNXNode}s from {@link ONNXOperator} and inputs. Returns a single ONNXNode and throws
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}.
* @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}.
* @param <T> The ONNXRef type of inputs
* @return An {@link ONNXNode} that is the output nodes of {@code op}.
*/
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName, Map<String, Object> attributes) {
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName, Map<String, Object> attributes) {
List<ONNXNode> opOutputs = operation(op, inputs, Collections.singletonList(outputName), attributes);
if(opOutputs.get(0).backRef.getOutputList().size() > 1) {
throw new IllegalStateException("Requested a single output from operation " + op.opName + " which produced " + opOutputs.get(0).backRef.getOutputList().size() + " outputs");
throw new IllegalStateException("Requested a single output from operation " + op.getOpName() + " which produced " + opOutputs.get(0).backRef.getOutputList().size() + " outputs");
} else {
return opOutputs.get(0);
}
}

/**
* Method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns a single ONNXNode and throws
* IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added
* to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext.
* Method for creating {@link ONNXNode}s from {@link ONNXOperator} instances and inputs. Returns a single ONNXNode
* and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the
* operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of
* ONNXContext.
* @param op An ONNXOperator to add to the graph, taking {@code inputs} as input.
* @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext.
* @param outputName Name that the output node of {@code op} should take.
* @param <T> The ONNXRef type of inputs
* @return An {@link ONNXNode} that is the output nodes of {@code op}.
*/
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperators op, List<T> inputs, String outputName) {
public <T extends ONNXRef<?>> ONNXNode operation(ONNXOperator op, List<T> inputs, String outputName) {
return operation(op, inputs, outputName, Collections.emptyMap());
}

Expand Down
230 changes: 230 additions & 0 deletions Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.util.onnx;

import ai.onnx.proto.OnnxMl;

import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

import static org.tribuo.util.onnx.ONNXAttribute.VARIADIC_INPUT;

/**
* An interface for ONNX operators. Usually implemented by an enum representing the opset.
*/
public interface ONNXOperator {

/**
* The operator name.
*/
public String getOpName();

/**
* The number of inputs.
*/
public int getNumInputs();

/**
* The number of optional inputs.
*/
public int getNumOptionalInputs();

/**
* The number of outputs.
*/
public int getNumOutputs();

/**
* The operator attributes.
*/
public Map<String,ONNXAttribute> getAttributes();

/**
* The mandatory attribute names.
*/
public Set<String> getMandatoryAttributeNames();

/**
* Returns the opset version.
* @return The opset version.
*/
public int getOpVersion();

/**
* Returns the opset domain.
* <p>
* May be {@code null} if it is the default ONNX domain;
* @return The opset domain.
*/
public String getOpDomain();

/**
* Returns the opset proto for these operators.
* @return The opset proto.
*/
default public OnnxMl.OperatorSetIdProto opsetProto() {
return OnnxMl.OperatorSetIdProto.newBuilder().setDomain(getOpDomain()).setVersion(getOpVersion()).build();
}

/**
* Builds this node based on the supplied inputs and output.
* Throws {@link IllegalArgumentException} if this operator takes more than a single input or output.
* @param context The onnx context used to ensure this node has a unique name.
* @param input The name of the input.
* @param output The name of the output.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String input, String output) {
return build(context,new String[]{input},new String[]{output}, Collections.emptyMap());
}

/**
* Builds this node based on the supplied inputs and output.
* Throws {@link IllegalArgumentException} if this operator takes more than a single input or output.
* May throw {@link UnsupportedOperationException} if the attribute type is not supported.
* @param context The onnx context used to ensure this node has a unique name.
* @param input The names of the input.
* @param output The name of the output.
* @param attributeValues The attribute names and values.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String input, String output, Map<String,Object> attributeValues) {
return build(context,new String[]{input},new String[]{output},attributeValues);
}

/**
* Builds this node based on the supplied inputs and output.
* Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong.
* @param context The onnx context used to ensure this node has a unique name.
* @param inputs The names of the inputs.
* @param output The name of the output.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output) {
return build(context,inputs,new String[]{output},Collections.emptyMap());
}

/**
* Builds this node based on the supplied inputs and output.
* Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong.
* May throw {@link UnsupportedOperationException} if the attribute type is not supported.
* @param context The onnx context used to ensure this node has a unique name.
* @param inputs The names of the inputs.
* @param output The name of the output.
* @param attributeValues The attribute names and values.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output, Map<String,Object> attributeValues) {
return build(context,inputs,new String[]{output},attributeValues);
}

/**
* Builds this node based on the supplied input and outputs.
* Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong.
* @param context The onnx context used to ensure this node has a unique name.
* @param input The name of the input.
* @param outputs The names of the outputs.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs) {
return build(context,new String[]{input},outputs,Collections.emptyMap());
}

/**
* Builds this node based on the supplied input and outputs.
* Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong.
* May throw {@link UnsupportedOperationException} if the attribute type is not supported.
* @param context The onnx context used to ensure this node has a unique name.
* @param input The name of the input.
* @param outputs The names of the outputs.
* @param attributeValues The attribute names and values.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs, Map<String,Object> attributeValues) {
return build(context,new String[]{input},outputs,attributeValues);
}

/**
* Builds this node based on the supplied inputs and outputs.
* Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong.
* @param context The onnx context used to ensure this node has a unique name.
* @param inputs The names of the inputs.
* @param outputs The names of the outputs.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs) {
return build(context,inputs,outputs,Collections.emptyMap());
}

/**
* Builds this node based on the supplied inputs and outputs.
* Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong.
* May throw {@link UnsupportedOperationException} if the attribute type is not supported.
* @param context The onnx context used to ensure this node has a unique name.
* @param inputs The names of the inputs.
* @param outputs The names of the outputs.
* @param attributeValues The attribute names and values.
* @return The NodeProto.
*/
default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs, Map<String,Object> attributeValues) {
int numInputs = getNumInputs();
int numOptionalInputs = getNumOptionalInputs();
int numOutputs = getNumOutputs();
String opName = getOpName();
String domain = getOpDomain();
Map<String, ONNXAttribute> attributes = getAttributes();
Set<String> mandatoryAttributeNames = getMandatoryAttributeNames();

String opStatus = String.format("Building op %s:%s(%d(+%d)) -> %d", domain, opName, numInputs, numOptionalInputs, numOutputs);

if ((numInputs != VARIADIC_INPUT) && ((inputs.length < numInputs) || (inputs.length > numInputs + numOptionalInputs))) {
throw new IllegalArgumentException(opStatus + ". Expected " + numInputs + " inputs, with " + numOptionalInputs + " optional inputs, but received " + inputs.length);
} else if ((numInputs == VARIADIC_INPUT) && (inputs.length == 0)) {
throw new IllegalArgumentException(opStatus + ". Expected at least one input for variadic input, received zero");
}
if (outputs.length != numOutputs) {
throw new IllegalArgumentException(opStatus + ". Expected " + numOutputs + " outputs, but received " + outputs.length);
}
if (!attributes.keySet().containsAll(attributeValues.keySet())) {
throw new IllegalArgumentException(opStatus + ". Unexpected attribute found, received " + attributeValues.keySet() + ", expected values from " + attributes.keySet());
}
if (!attributeValues.keySet().containsAll(mandatoryAttributeNames)) {
throw new IllegalArgumentException(opStatus + ". Expected to find all mandatory attributes, received " + attributeValues.keySet() + ", expected " + mandatoryAttributeNames);
}

Logger.getLogger("org.tribuo.util.onnx.ONNXOperator").fine(opStatus);
OnnxMl.NodeProto.Builder nodeBuilder = OnnxMl.NodeProto.newBuilder();
for (String i : inputs) {
nodeBuilder.addInput(i);
}
for (String o : outputs) {
nodeBuilder.addOutput(o);
}
nodeBuilder.setName(context.generateUniqueName(opName));
nodeBuilder.setOpType(opName);
if (domain != null) {
nodeBuilder.setDomain(domain);
}
for (Map.Entry<String,Object> e : attributeValues.entrySet()) {
ONNXAttribute attr = attributes.get(e.getKey());
nodeBuilder.addAttribute(attr.build(e.getValue()));
}
return nodeBuilder.build();
}
}
Loading

0 comments on commit 4c12343

Please sign in to comment.