diff --git a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java index ee706522b..00479f468 100644 --- a/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java +++ b/Common/SGD/src/main/java/org/tribuo/common/sgd/AbstractFMModel.java @@ -223,7 +223,6 @@ public ONNXNode writeONNXGraph(ONNXRef input) { // Make feature pow ONNXNode inputSquared = input.apply(ONNXOperators.POW, twoConst); - List embeddingOutputs = new ArrayList<>(); for(int i = 0; i < outputIDInfo.size(); i++) { diff --git a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java index 420850730..f93e6fd28 100644 --- a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXContext.java @@ -33,7 +33,7 @@ * ONNXContext or of {@link ONNXRef}s produced by multiple ONNXContexts is not supported. *

* The ONNXContext has all of the logic needed to produce ONNX graphs, but is typically used explicitly to produce leaf - * nodes of graphs (inputs, outputs, and weight matrices) that have more fluent interfaces to {@link ONNXContext#operation(ONNXOperators, List, List, Map)}. + * nodes of graphs (inputs, outputs, and weight matrices) that have more fluent interfaces to {@link ONNXContext#operation(ONNXOperator, List, List, Map)}. * Produced ONNX protobuf objects are encapsulated by instances of {@link ONNXRef} and its subclasses. */ public final class ONNXContext { @@ -51,18 +51,18 @@ public ONNXContext() { } /** - * Base method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns an instance of ONNXNode + * Base method for creating {@link ONNXNode}s from {@link ONNXOperator} and inputs. Returns an instance of ONNXNode * for each output of the ONNXOperator. The graph elements created by the operation are added to the calling * ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. This is the root method for * constructing ONNXNodes which all other methods on ONNXContext and {@code ONNXRef} call. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext. * @param outputs A list of names that the output nodes of {@code op} should take. - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @param The ONNXRef type of inputs * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public > List operation(ONNXOperators op, + public > List operation(ONNXOperator op, List inputs, List outputs, Map attributes) { @@ -78,36 +78,37 @@ public > List operation(ONNXOperators op, } /** - * Method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns a single ONNXNode and throws + * Method for creating {@link ONNXNode}s from {@link ONNXOperator} and inputs. Returns a single ONNXNode and throws * IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added * to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext. * @param outputName Name that the output node of {@code op} should take. - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @param The ONNXRef type of inputs * @return An {@link ONNXNode} that is the output nodes of {@code op}. */ - public > ONNXNode operation(ONNXOperators op, List inputs, String outputName, Map attributes) { + public > ONNXNode operation(ONNXOperator op, List inputs, String outputName, Map attributes) { List opOutputs = operation(op, inputs, Collections.singletonList(outputName), attributes); if(opOutputs.get(0).backRef.getOutputList().size() > 1) { - throw new IllegalStateException("Requested a single output from operation " + op.opName + " which produced " + opOutputs.get(0).backRef.getOutputList().size() + " outputs"); + throw new IllegalStateException("Requested a single output from operation " + op.getOpName() + " which produced " + opOutputs.get(0).backRef.getOutputList().size() + " outputs"); } else { return opOutputs.get(0); } } /** - * Method for creating {@link ONNXNode}s from {@link ONNXOperators} and inputs. Returns a single ONNXNode and throws - * IllegalStateException if the operator has multiple outputs. The graph elements created by the operation are added - * to the calling ONNXContext instance. All inputs must belong to the calling instance of ONNXContext. + * Method for creating {@link ONNXNode}s from {@link ONNXOperator} instances and inputs. Returns a single ONNXNode + * and throws IllegalStateException if the operator has multiple outputs. The graph elements created by the + * operation are added to the calling ONNXContext instance. All inputs must belong to the calling instance of + * ONNXContext. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param inputs A list of {@link ONNXRef}s created by this instance of ONNXContext. * @param outputName Name that the output node of {@code op} should take. * @param The ONNXRef type of inputs * @return An {@link ONNXNode} that is the output nodes of {@code op}. */ - public > ONNXNode operation(ONNXOperators op, List inputs, String outputName) { + public > ONNXNode operation(ONNXOperator op, List inputs, String outputName) { return operation(op, inputs, outputName, Collections.emptyMap()); } diff --git a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperator.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperator.java new file mode 100644 index 000000000..98f72dfa5 --- /dev/null +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperator.java @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.tribuo.util.onnx; + +import ai.onnx.proto.OnnxMl; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; + +import static org.tribuo.util.onnx.ONNXAttribute.VARIADIC_INPUT; + +/** + * An interface for ONNX operators. Usually implemented by an enum representing the opset. + */ +public interface ONNXOperator { + + /** + * The operator name. + */ + public String getOpName(); + + /** + * The number of inputs. + */ + public int getNumInputs(); + + /** + * The number of optional inputs. + */ + public int getNumOptionalInputs(); + + /** + * The number of outputs. + */ + public int getNumOutputs(); + + /** + * The operator attributes. + */ + public Map getAttributes(); + + /** + * The mandatory attribute names. + */ + public Set getMandatoryAttributeNames(); + + /** + * Returns the opset version. + * @return The opset version. + */ + public int getOpVersion(); + + /** + * Returns the opset domain. + *

+ * May be {@code null} if it is the default ONNX domain; + * @return The opset domain. + */ + public String getOpDomain(); + + /** + * Returns the opset proto for these operators. + * @return The opset proto. + */ + default public OnnxMl.OperatorSetIdProto opsetProto() { + return OnnxMl.OperatorSetIdProto.newBuilder().setDomain(getOpDomain()).setVersion(getOpVersion()).build(); + } + + /** + * Builds this node based on the supplied inputs and output. + * Throws {@link IllegalArgumentException} if this operator takes more than a single input or output. + * @param context The onnx context used to ensure this node has a unique name. + * @param input The name of the input. + * @param output The name of the output. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String input, String output) { + return build(context,new String[]{input},new String[]{output}, Collections.emptyMap()); + } + + /** + * Builds this node based on the supplied inputs and output. + * Throws {@link IllegalArgumentException} if this operator takes more than a single input or output. + * May throw {@link UnsupportedOperationException} if the attribute type is not supported. + * @param context The onnx context used to ensure this node has a unique name. + * @param input The names of the input. + * @param output The name of the output. + * @param attributeValues The attribute names and values. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String input, String output, Map attributeValues) { + return build(context,new String[]{input},new String[]{output},attributeValues); + } + + /** + * Builds this node based on the supplied inputs and output. + * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. + * @param context The onnx context used to ensure this node has a unique name. + * @param inputs The names of the inputs. + * @param output The name of the output. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output) { + return build(context,inputs,new String[]{output},Collections.emptyMap()); + } + + /** + * Builds this node based on the supplied inputs and output. + * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. + * May throw {@link UnsupportedOperationException} if the attribute type is not supported. + * @param context The onnx context used to ensure this node has a unique name. + * @param inputs The names of the inputs. + * @param output The name of the output. + * @param attributeValues The attribute names and values. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output, Map attributeValues) { + return build(context,inputs,new String[]{output},attributeValues); + } + + /** + * Builds this node based on the supplied input and outputs. + * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. + * @param context The onnx context used to ensure this node has a unique name. + * @param input The name of the input. + * @param outputs The names of the outputs. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs) { + return build(context,new String[]{input},outputs,Collections.emptyMap()); + } + + /** + * Builds this node based on the supplied input and outputs. + * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. + * May throw {@link UnsupportedOperationException} if the attribute type is not supported. + * @param context The onnx context used to ensure this node has a unique name. + * @param input The name of the input. + * @param outputs The names of the outputs. + * @param attributeValues The attribute names and values. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs, Map attributeValues) { + return build(context,new String[]{input},outputs,attributeValues); + } + + /** + * Builds this node based on the supplied inputs and outputs. + * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. + * @param context The onnx context used to ensure this node has a unique name. + * @param inputs The names of the inputs. + * @param outputs The names of the outputs. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs) { + return build(context,inputs,outputs,Collections.emptyMap()); + } + + /** + * Builds this node based on the supplied inputs and outputs. + * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. + * May throw {@link UnsupportedOperationException} if the attribute type is not supported. + * @param context The onnx context used to ensure this node has a unique name. + * @param inputs The names of the inputs. + * @param outputs The names of the outputs. + * @param attributeValues The attribute names and values. + * @return The NodeProto. + */ + default public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs, Map attributeValues) { + int numInputs = getNumInputs(); + int numOptionalInputs = getNumOptionalInputs(); + int numOutputs = getNumOutputs(); + String opName = getOpName(); + String domain = getOpDomain(); + Map attributes = getAttributes(); + Set mandatoryAttributeNames = getMandatoryAttributeNames(); + + String opStatus = String.format("Building op %s:%s(%d(+%d)) -> %d", domain, opName, numInputs, numOptionalInputs, numOutputs); + + if ((numInputs != VARIADIC_INPUT) && ((inputs.length < numInputs) || (inputs.length > numInputs + numOptionalInputs))) { + throw new IllegalArgumentException(opStatus + ". Expected " + numInputs + " inputs, with " + numOptionalInputs + " optional inputs, but received " + inputs.length); + } else if ((numInputs == VARIADIC_INPUT) && (inputs.length == 0)) { + throw new IllegalArgumentException(opStatus + ". Expected at least one input for variadic input, received zero"); + } + if (outputs.length != numOutputs) { + throw new IllegalArgumentException(opStatus + ". Expected " + numOutputs + " outputs, but received " + outputs.length); + } + if (!attributes.keySet().containsAll(attributeValues.keySet())) { + throw new IllegalArgumentException(opStatus + ". Unexpected attribute found, received " + attributeValues.keySet() + ", expected values from " + attributes.keySet()); + } + if (!attributeValues.keySet().containsAll(mandatoryAttributeNames)) { + throw new IllegalArgumentException(opStatus + ". Expected to find all mandatory attributes, received " + attributeValues.keySet() + ", expected " + mandatoryAttributeNames); + } + + Logger.getLogger("org.tribuo.util.onnx.ONNXOperator").fine(opStatus); + OnnxMl.NodeProto.Builder nodeBuilder = OnnxMl.NodeProto.newBuilder(); + for (String i : inputs) { + nodeBuilder.addInput(i); + } + for (String o : outputs) { + nodeBuilder.addOutput(o); + } + nodeBuilder.setName(context.generateUniqueName(opName)); + nodeBuilder.setOpType(opName); + if (domain != null) { + nodeBuilder.setDomain(domain); + } + for (Map.Entry e : attributeValues.entrySet()) { + ONNXAttribute attr = attributes.get(e.getKey()); + nodeBuilder.addAttribute(attr.build(e.getValue())); + } + return nodeBuilder.build(); + } +} diff --git a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java index 028a62318..b8af005ca 100644 --- a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXOperators.java @@ -29,9 +29,11 @@ import static org.tribuo.util.onnx.ONNXAttribute.VARIADIC_INPUT; /** - * The supported ONNX operators. + * ONNX Opset 13, and ONNX-ML version 1. + *

+ * In a future version of Tribuo this class will be split into two enums, one for ONNX opset 13 and one for ONNX-ML v1. */ -public enum ONNXOperators { +public enum ONNXOperators implements ONNXOperator { /** * Identity. */ @@ -408,141 +410,44 @@ private ONNXOperators(String value, int numInputs, int numOptionalInputs, int nu this.domain = domain; } - /** - * Builds this node based on the supplied inputs and output. - * Throws {@link IllegalArgumentException} if this operator takes more than a single input or output. - * @param context The onnx context used to ensure this node has a unique name. - * @param input The name of the input. - * @param output The name of the output. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String input, String output) { - return build(context,new String[]{input},new String[]{output},Collections.emptyMap()); + @Override + public String getOpName() { + return opName; } - /** - * Builds this node based on the supplied inputs and output. - * Throws {@link IllegalArgumentException} if this operator takes more than a single input or output. - * May throw {@link UnsupportedOperationException} if the attribute type is not supported. - * @param context The onnx context used to ensure this node has a unique name. - * @param input The names of the input. - * @param output The name of the output. - * @param attributeValues The attribute names and values. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String input, String output, Map attributeValues) { - return build(context,new String[]{input},new String[]{output},attributeValues); + @Override + public int getNumInputs() { + return numInputs; } - /** - * Builds this node based on the supplied inputs and output. - * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. - * @param context The onnx context used to ensure this node has a unique name. - * @param inputs The names of the inputs. - * @param output The name of the output. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output) { - return build(context,inputs,new String[]{output},Collections.emptyMap()); + @Override + public int getNumOptionalInputs() { + return numOptionalInputs; } - /** - * Builds this node based on the supplied inputs and output. - * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. - * May throw {@link UnsupportedOperationException} if the attribute type is not supported. - * @param context The onnx context used to ensure this node has a unique name. - * @param inputs The names of the inputs. - * @param output The name of the output. - * @param attributeValues The attribute names and values. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String output, Map attributeValues) { - return build(context,inputs,new String[]{output},attributeValues); + @Override + public int getNumOutputs() { + return numOutputs; } - /** - * Builds this node based on the supplied input and outputs. - * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. - * @param context The onnx context used to ensure this node has a unique name. - * @param input The name of the input. - * @param outputs The names of the outputs. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs) { - return build(context,new String[]{input},outputs,Collections.emptyMap()); + @Override + public Map getAttributes() { + return attributes; } - /** - * Builds this node based on the supplied input and outputs. - * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. - * May throw {@link UnsupportedOperationException} if the attribute type is not supported. - * @param context The onnx context used to ensure this node has a unique name. - * @param input The name of the input. - * @param outputs The names of the outputs. - * @param attributeValues The attribute names and values. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String input, String[] outputs, Map attributeValues) { - return build(context,new String[]{input},outputs,attributeValues); + @Override + public Set getMandatoryAttributeNames() { + return mandatoryAttributeNames; } - /** - * Builds this node based on the supplied inputs and outputs. - * Throws {@link IllegalArgumentException} if the number of inputs or outputs is wrong. - * @param context The onnx context used to ensure this node has a unique name. - * @param inputs The names of the inputs. - * @param outputs The names of the outputs. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs) { - return build(context,inputs,outputs,Collections.emptyMap()); + @Override + public int getOpVersion() { + return getOpsetVersion(); } - /** - * Builds this node based on the supplied inputs and outputs. - * Throws {@link IllegalArgumentException} if the number of inputs, outputs or attributes is wrong. - * May throw {@link UnsupportedOperationException} if the attribute type is not supported. - * @param context The onnx context used to ensure this node has a unique name. - * @param inputs The names of the inputs. - * @param outputs The names of the outputs. - * @param attributeValues The attribute names and values. - * @return The NodeProto. - */ - public OnnxMl.NodeProto build(ONNXContext context, String[] inputs, String[] outputs, Map attributeValues) { - if ((numInputs != VARIADIC_INPUT) && ((inputs.length < numInputs) || (inputs.length > numInputs + numOptionalInputs))) { - throw new IllegalArgumentException("Expected " + numInputs + " inputs, with " + numOptionalInputs + " optional inputs, but received " + inputs.length); - } else if ((numInputs == VARIADIC_INPUT) && (inputs.length == 0)) { - throw new IllegalArgumentException("Expected at least one input for variadic input, received zero"); - } - if (outputs.length != numOutputs) { - throw new IllegalArgumentException("Expected " + numOutputs + " outputs, but received " + outputs.length); - } - if (attributeValues.size() > attributes.size()) { - throw new IllegalArgumentException("Found more attributes than expected, received " + attributeValues.size() + ", expected at most " + attributes.size()); - } - if (!attributes.keySet().containsAll(attributeValues.keySet())) { - throw new IllegalArgumentException("Unexpected attribute found, received " + attributeValues.keySet() + ", expected values from " + attributes.keySet()); - } - if (!attributeValues.keySet().containsAll(mandatoryAttributeNames)) { - throw new IllegalArgumentException("Expected to find all mandatory attributes, received " + attributeValues.keySet() + ", expected " + mandatoryAttributeNames); - } - OnnxMl.NodeProto.Builder nodeBuilder = OnnxMl.NodeProto.newBuilder(); - for (String i : inputs) { - nodeBuilder.addInput(i); - } - for (String o : outputs) { - nodeBuilder.addOutput(o); - } - nodeBuilder.setName(context.generateUniqueName(opName)); - nodeBuilder.setOpType(opName); - if (domain != null) { - nodeBuilder.setDomain(domain); - } - for (Map.Entry e : attributeValues.entrySet()) { - ONNXAttribute attr = attributes.get(e.getKey()); - nodeBuilder.addAttribute(attr.build(e.getValue())); - } - return nodeBuilder.build(); + @Override + public String getOpDomain() { + return domain; } /** diff --git a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java index 197dea39f..8d1ec29eb 100644 --- a/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java +++ b/Util/ONNXExport/src/main/java/org/tribuo/util/onnx/ONNXRef.java @@ -29,7 +29,7 @@ /** * An abstract reference that represents both a node in an ONNX computation graph and a container for a specific ONNX * proto object that denotes that node. In its role as the former it provides a fluent interface for applying - * {@link ONNXOperators} to {@link ONNXRef} instances. ONNXRef instances are ultimately created by an {@link ONNXContext} + * {@link ONNXOperator}s to {@link ONNXRef} instances. ONNXRef instances are ultimately created by an {@link ONNXContext} * instance, and ONNXRefs created by different instances of ONNXContext are incompatible. All ONNX proto objects * produced by calling {@code apply} methods on ONNXRefs are added to a {@link ai.onnx.proto.OnnxMl.GraphProto} field * in their governing ONNXContext. Instances of ONNXRef have a backreference to the ONNXContext that created them and @@ -79,16 +79,16 @@ public ONNXContext onnxContext() { } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, List, Map)}, using this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, List, Map)}, using this ONNXRef * as the first argument to {@code inputs}, with {@code otherInputs} append as subsequent arguments. The other * arguments behave as in the analogous method on ONNXContext. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param otherInputs A list of {@link ONNXRef}s created by this instance of ONNXContext. * @param outputs A list of names that the output nodes of {@code op} should take. - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public List apply(ONNXOperators op, List> otherInputs, List outputs, Map attributes) { + public List apply(ONNXOperator op, List> otherInputs, List outputs, Map attributes) { List> allInputs = new ArrayList<>(); allInputs.add(this); allInputs.addAll(otherInputs); @@ -96,109 +96,109 @@ public List apply(ONNXOperators op, List> otherInputs, List } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, List, Map)}, using this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, List, Map)}, using this ONNXRef * as the argument to {@code inputs}. The other arguments behave as in the analogous method on ONNXContext. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param outputs A list of names that the output nodes of {@code op} should take. - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public List apply(ONNXOperators op, List outputs, Map attributes) { + public List apply(ONNXOperator op, List outputs, Map attributes) { return context.operation(op, Collections.singletonList(this), outputs, attributes); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String)}, using this ONNXRef - * as the argument to {@code inputs}. Output names are generated based on the {@link ONNXOperators#opName} and the + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String)}, using this ONNXRef + * as the argument to {@code inputs}. Output names are generated based on the {@link ONNXOperator#getOpName} and the * name of the input nodes. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op) { - return context.operation(op, Collections.singletonList(this), getBaseName() + "_" + op.opName, Collections.emptyMap()); + public ONNXNode apply(ONNXOperator op) { + return context.operation(op, Collections.singletonList(this), getBaseName() + "_" + op.getOpName(), Collections.emptyMap()); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String)}, using this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String)}, using this ONNXRef * as the argument to {@code inputs}. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param outputName A name that the output node of {@code op} will take. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, String outputName) { + public ONNXNode apply(ONNXOperator op, String outputName) { return context.operation(op, Collections.singletonList(this), outputName, Collections.emptyMap()); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String, Map)}, using this ONNXRef - * as the argument to {@code inputs}. Output names are generated based on the {@link ONNXOperators#opName} and the + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String, Map)}, using this ONNXRef + * as the argument to {@code inputs}. Output names are generated based on the {@link ONNXOperator#getOpName} and the * name of the input nodes. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, Map attributes) { - return context.operation(op, Collections.singletonList(this), getBaseName() + "_" + op.opName, attributes); + public ONNXNode apply(ONNXOperator op, Map attributes) { + return context.operation(op, Collections.singletonList(this), getBaseName() + "_" + op.getOpName(), attributes); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String, Map)}, passing this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String, Map)}, passing this ONNXRef * and {@code other} as a length 2 list to {@code inputs}. The other arguments behave as in the analogous method on - * ONNXContext. Output names are generated based on the {@link ONNXOperators#opName} and the name of the input nodes. + * ONNXContext. Output names are generated based on the {@link ONNXOperator#getOpName} and the name of the input nodes. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param other A second input argument to {@code op} - * @param attributes A map of attributes of the operation, passed to {@link ONNXOperators#build(ONNXContext, String, String, Map)}. + * @param attributes A map of attributes of the operation, passed to {@link ONNXOperator#build(ONNXContext, String, String, Map)}. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, ONNXRef other, Map attributes) { - return context.operation(op, Arrays.asList(this, other), getBaseName() + "_" + op.opName + "_" + other.getBaseName(), attributes); + public ONNXNode apply(ONNXOperator op, ONNXRef other, Map attributes) { + return context.operation(op, Arrays.asList(this, other), getBaseName() + "_" + op.getOpName() + "_" + other.getBaseName(), attributes); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String)}, passing this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String)}, passing this ONNXRef * and {@code other} as a length 2 list to {@code inputs}. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param other A second input argument to {@code op} * @param outputName A name that the output node of {@code op} will take. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, ONNXRef other, String outputName) { + public ONNXNode apply(ONNXOperator op, ONNXRef other, String outputName) { return context.operation(op, Arrays.asList(this, other), outputName, Collections.emptyMap()); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String, Map)}, passing this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String, Map)}, passing this ONNXRef * and {@code other} as a length 2 list to {@code inputs}. Output names are generated based on the - * {@link ONNXOperators#opName} and the name of the input nodes. + * {@link ONNXOperator#getOpName} and the name of the input nodes. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param other A second input argument to {@code op} * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, ONNXRef other) { - return context.operation(op, Arrays.asList(this, other), getBaseName() + "_" + op.opName + "_" + other.getBaseName(), Collections.emptyMap()); + public ONNXNode apply(ONNXOperator op, ONNXRef other) { + return context.operation(op, Arrays.asList(this, other), getBaseName() + "_" + op.getOpName() + "_" + other.getBaseName(), Collections.emptyMap()); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String, Map)}, using this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String, Map)}, using this ONNXRef * as the first argument to {@code inputs}, with {@code otherInputs} append as subsequent arguments. Output names - * are generated based on the {@link ONNXOperators#opName} and the name of the input nodes. + * are generated based on the {@link ONNXOperator#getOpName} and the name of the input nodes. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param others List of ONNXRefs supplied as inputs to {@code op} after this ONNXRef. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, List> others) { + public ONNXNode apply(ONNXOperator op, List> others) { return apply(op, others, Collections.singletonList(getBaseName() + "_" + others.stream().map(ONNXRef::getBaseName).collect(Collectors.joining("_"))), Collections.emptyMap()).get(0); } /** - * Convenience method that calls {@link ONNXContext#operation(ONNXOperators, List, String, Map)}, using this ONNXRef + * Convenience method that calls {@link ONNXContext#operation(ONNXOperator, List, String, Map)}, using this ONNXRef * as the argument to {@code inputs}, with {@code otherInputs} append as subsequent arguments. * @param op An ONNXOperator to add to the graph, taking {@code inputs} as input. * @param others List of ONNXRefs supplied as inputs to {@code op} after this ONNXRef. * @param outputName The name for the constructed node. * @return a list of {@link ONNXNode}s that are the output nodes of {@code op}. */ - public ONNXNode apply(ONNXOperators op, List> others, String outputName) { + public ONNXNode apply(ONNXOperator op, List> others, String outputName) { return apply(op, others, Collections.singletonList(outputName), Collections.emptyMap()).get(0); }