From 5f47771f9434dcc7cb202e1a3072032193a3dbc8 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Tue, 5 Nov 2024 11:35:14 +0000 Subject: [PATCH] 2024-11-05 nightly release (8ab3385f4187bad56c66dba29152e34c158f368a) --- .../annotate_channels_last_dim_order_pass.py | 5 +- .../_passes/insert_squeeze_after_sum_pass.py | 14 +- .../arm/_passes/size_adjust_conv2d_pass.py | 4 +- backends/arm/operators/op_addmm.py | 38 ++- backends/arm/operators/op_bmm.py | 16 +- backends/arm/operators/op_conv2d.py | 22 +- backends/arm/operators/op_exp.py | 7 +- backends/arm/operators/op_full.py | 11 +- backends/arm/operators/op_hardtanh.py | 13 +- backends/arm/operators/op_log.py | 7 +- backends/arm/operators/op_mm.py | 16 +- backends/arm/operators/op_mul.py | 4 +- backends/arm/operators/op_placeholder.py | 17 +- backends/arm/operators/op_reciprocal.py | 7 +- backends/arm/operators/op_relu.py | 2 +- backends/arm/operators/op_rsqrt.py | 7 +- backends/arm/operators/op_sigmoid.py | 7 +- backends/arm/operators/op_tanh.py | 7 +- backends/arm/quantizer/TARGETS | 3 - .../generic_annotator.py | 3 + .../quantization_annotation/mm_annotator.py | 4 +- backends/arm/test/ops/test_bmm.py | 20 +- backends/arm/test/ops/test_linear.py | 2 +- backends/arm/tosa_quant_utils.py | 214 +++++++++++------ backends/arm/tosa_utils.py | 20 +- extension/android/build.gradle | 1 + .../java/org/pytorch/executorch/EValue.java | 8 +- .../org/pytorch/executorch/EValueTest.java | 218 ++++++++++++++++++ 28 files changed, 518 insertions(+), 179 deletions(-) create mode 100644 extension/android/src/test/java/org/pytorch/executorch/EValueTest.java diff --git a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py index 77def9e7cd..786117e645 100644 --- a/backends/arm/_passes/annotate_channels_last_dim_order_pass.py +++ b/backends/arm/_passes/annotate_channels_last_dim_order_pass.py @@ -14,7 +14,7 @@ get_first_fake_tensor, insert_q_dq_pair, ) -from executorch.backends.arm.tosa_quant_utils import dq_op, q_op +from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -42,6 +42,9 @@ def _transpose_impl(*args, **kwargs): return args[0] +register_passable_op(torch.ops.passthrough_to_tosa._transpose) + + class AnnotateChannelsLastDimOrder(ExportPass): """ Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order diff --git a/backends/arm/_passes/insert_squeeze_after_sum_pass.py b/backends/arm/_passes/insert_squeeze_after_sum_pass.py index 152d5c95f6..adf2b4f491 100644 --- a/backends/arm/_passes/insert_squeeze_after_sum_pass.py +++ b/backends/arm/_passes/insert_squeeze_after_sum_pass.py @@ -8,9 +8,7 @@ import torch import torch.fx -from executorch.backends.arm._passes.arm_pass_utils import create_node, insert_q_dq_pair - -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args, is_quant_node +from executorch.backends.arm._passes.arm_pass_utils import create_node from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -28,8 +26,6 @@ class InsertSqueezeAfterSumPass(ExportPass): sum(dims, keep_dim = False) After pass: sum(dims, keep_dim = True) - (q) - (dq) squeeze(dim = dims) """ @@ -45,12 +41,6 @@ def call(self, graph_module: torch.fx.GraphModule): continue dim_list = cast(list[int], sum_node.args[1]) - quantized = is_quant_node(sum_node) - if quantized: - qparams = get_quant_node_args(sum_node.all_input_nodes[0]) - qparams = qparams + (torch.int8,) - else: - qparams = None # Add keep_dim = True arg to sum node. sum_node.args = sum_node.args[0:2] + (True,) @@ -61,8 +51,6 @@ def call(self, graph_module: torch.fx.GraphModule): ) sum_node.replace_all_uses_with(squeeze_node) squeeze_node.args = (sum_node, dim_list) - if quantized: - sum_node = insert_q_dq_pair(graph_module.graph, sum_node, qparams) graph_module.graph.eliminate_dead_code() graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/arm/_passes/size_adjust_conv2d_pass.py b/backends/arm/_passes/size_adjust_conv2d_pass.py index 980ab09e59..c7bd27dcce 100644 --- a/backends/arm/_passes/size_adjust_conv2d_pass.py +++ b/backends/arm/_passes/size_adjust_conv2d_pass.py @@ -9,7 +9,7 @@ from typing import cast, Optional import torch.fx -from executorch.backends.arm.tosa_quant_utils import is_quant_node +from executorch.backends.arm.tosa_quant_utils import is_node_quantized from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch._ops import OpOverload @@ -113,7 +113,7 @@ def call(self, graph_module: torch.fx.GraphModule): slice_node = graph.create_node( "call_function", self.slice_op, (last_node,) + args ) - if is_quant_node(last_node): + if is_node_quantized(last_node): q_params = last_node.args[1:] dq_node = insert_q_dq_pair( graph_module.graph, slice_node, q_params diff --git a/backends/arm/operators/op_addmm.py b/backends/arm/operators/op_addmm.py index b4f782db4a..64de62767e 100644 --- a/backends/arm/operators/op_addmm.py +++ b/backends/arm/operators/op_addmm.py @@ -14,10 +14,13 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + search_quant_arg_downstream, + search_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import build_reshape -from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp @@ -67,12 +70,7 @@ def define_node( input_zp = 0 if is_quant_node: input_node = node.all_input_nodes[1] - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - else: - quant_node = input_node - input_zp = get_quant_node_args(quant_node).zp + input_zp = search_quant_arg_upstream(input_node).zp attr.ConvAttribute( pad=pad_attr, stride=stride_attr, @@ -107,24 +105,16 @@ def define_node( # Read inputs' parent nodes _, input_node, weight_node = node.all_input_nodes - # rank > 2 linear layer - if input_node.target == exir_ops.edge.aten.view_copy.default: - quant_node = input_node.all_input_nodes[0] - input_scale = get_quant_node_args(quant_node).scale - consumer_node = list(node.users)[0] - consumer_consumer_node = list(consumer_node.users)[0] - quant_args = get_quant_node_args(consumer_consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp - else: - input_scale = get_quant_node_args(input_node).scale - consumer_node = list(node.users)[0] - quant_args = get_quant_node_args(consumer_node) - consumer_node_scale = quant_args.scale - consumer_node_node_zp = quant_args.zp + qargs = search_quant_arg_upstream(input_node) + input_scale = qargs.scale + consumer_node = list(node.users)[0] + quant_args = search_quant_arg_downstream(consumer_node) + + consumer_node_scale = quant_args.scale + consumer_node_node_zp = quant_args.zp weight_node_q_node = weight_node.all_input_nodes[0] - weight_scale = get_quant_node_args(weight_node_q_node).scale + weight_scale = search_quant_arg_upstream(weight_node_q_node).scale output_rescale_scale = (input_scale * weight_scale) / consumer_node_scale diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 161b5d2239..c4067e5a7c 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + search_quant_arg_downstream, + search_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import get_two_inputs from serializer.tosa_serializer import TosaOp @@ -42,8 +46,10 @@ def define_node( # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_q_params = search_quant_arg_upstream(input0) + input1_q_params = search_quant_arg_upstream(input1) + input0_zp = input0_q_params.zp + input1_zp = input1_q_params.zp bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32) bmm_output_name = bmm_result.name else: @@ -63,9 +69,7 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + output_q_params = search_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_conv2d.py b/backends/arm/operators/op_conv2d.py index 64cde0724f..8b2627ceda 100644 --- a/backends/arm/operators/op_conv2d.py +++ b/backends/arm/operators/op_conv2d.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe -from typing import cast, List +from typing import List import serializer.tosa_serializer as ts import torch @@ -15,9 +15,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( build_rescale_conv_output, - get_quant_node_args, + search_quant_arg_downstream, + search_quant_arg_upstream, ) -from executorch.backends.arm.tosa_utils import build_reshape, getNodeArgs, tosa_shape +from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape from serializer.tosa_serializer import TosaOp @@ -82,7 +83,9 @@ def define_node( ) input_zp = ( - get_quant_node_args(node.all_input_nodes[0]).zp if is_quant_node else 0 + search_quant_arg_upstream(node.all_input_nodes[0]).zp + if is_quant_node + else 0 ) attr.ConvAttribute( @@ -158,9 +161,10 @@ def define_node( # integer value domain of the next op. Otherwise return float32 output. if is_quant_node: # Get scale_factor from input, weight, and output. - _, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0])) - _, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1])) - _, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0]) + input_scale = search_quant_arg_upstream(node.all_input_nodes[0]).scale + weight_scale = search_quant_arg_upstream(node.all_input_nodes[1]).scale + output_qargs = search_quant_arg_downstream(list(node.users)[0]) + build_rescale_conv_output( tosa_graph, # pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined. @@ -169,6 +173,6 @@ def define_node( actual_out_type, input_scale, weight_scale, - output_scale, - output_zp, + output_qargs.scale, + output_qargs.zp, ) diff --git a/backends/arm/operators/op_exp.py b/backends/arm/operators/op_exp.py index 0e0a75dcc4..115ee4606c 100644 --- a/backends/arm/operators/op_exp.py +++ b/backends/arm/operators/op_exp.py @@ -17,9 +17,10 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -48,9 +49,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = search_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = search_quant_arg_downstream(output_node) table = exp_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_full.py b/backends/arm/operators/op_full.py index cf67975e0d..b2c14e4d46 100644 --- a/backends/arm/operators/op_full.py +++ b/backends/arm/operators/op_full.py @@ -14,7 +14,10 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + quantize_value, + search_quant_arg_downstream, +) from executorch.backends.arm.tosa_utils import tosa_shape from torch.fx import Node @@ -39,10 +42,8 @@ def define_node( value = inputs[1].number if is_quant_node: - qargs = get_quant_node_args(list(node.users)[0]) - qvalue = np.clip( - np.round(value / qargs.scale) + qargs.zp, qargs.qmin, qargs.qmax - ) + qargs = search_quant_arg_downstream(list(node.users)[0]) + qvalue = quantize_value(value, qargs) dtype = ts.DType.INT8 data = np.full(shape, qvalue, dtype=np.int8) else: diff --git a/backends/arm/operators/op_hardtanh.py b/backends/arm/operators/op_hardtanh.py index 62c0a27f05..184bb8173d 100644 --- a/backends/arm/operators/op_hardtanh.py +++ b/backends/arm/operators/op_hardtanh.py @@ -14,7 +14,10 @@ ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + quantize_value, + search_quant_arg_upstream, +) from serializer.tosa_serializer import TosaOp @@ -37,12 +40,10 @@ def define_node( if is_quant_node: # Get quant parameters - scale, zp, qmin, qmax = get_quant_node_args(node.all_input_nodes[0]) + qargs = search_quant_arg_upstream(node.all_input_nodes[0]) # Convert to quantized representation - clamp_min_qs = round((inputs[1].number / scale) + zp) - clamp_min_qs = max(clamp_min_qs, qmin) - clamp_max_qs = round((inputs[2].number / scale) + zp) - clamp_max_qs = min(clamp_max_qs, qmax) + clamp_min_qs = quantize_value(inputs[1].number, qargs) + clamp_max_qs = quantize_value(inputs[2].number, qargs) # Set fp values to 0.0 since they are not used clamp_min_fp = 0.0 clamp_max_fp = 0.0 diff --git a/backends/arm/operators/op_log.py b/backends/arm/operators/op_log.py index 5276173efa..8512e3eb30 100644 --- a/backends/arm/operators/op_log.py +++ b/backends/arm/operators/op_log.py @@ -17,9 +17,10 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = search_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = search_quant_arg_downstream(output_node) table = log_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_mm.py b/backends/arm/operators/op_mm.py index ebddb3a40e..b59baed69a 100644 --- a/backends/arm/operators/op_mm.py +++ b/backends/arm/operators/op_mm.py @@ -14,7 +14,11 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_quant_utils import build_rescale, get_quant_node_args +from executorch.backends.arm.tosa_quant_utils import ( + build_rescale, + search_quant_arg_downstream, + search_quant_arg_upstream, +) from executorch.backends.arm.tosa_utils import ( build_reshape, expand_dims, @@ -54,8 +58,8 @@ def define_node( # For INT8, we need to get the zero point, otherwise it is 0 input0_zp, input1_zp = 0, 0 if is_quant_node: - input0_zp = get_quant_node_args(input0).zp - input1_zp = get_quant_node_args(input1).zp + input0_zp = search_quant_arg_upstream(input0).zp + input1_zp = search_quant_arg_upstream(input1).zp mat_mul_result = tosa_graph.addIntermediate( output_new_shape, ts.DType.INT32 if is_quant_node else output.dtype @@ -86,9 +90,9 @@ def define_node( # As INT8 accumulates into INT32, we need to rescale it back to INT8 if is_quant_node: - input0_q_params = get_quant_node_args(input0) - input1_q_params = get_quant_node_args(input1) - output_q_params = get_quant_node_args(list(node.users)[0]) + input0_q_params = search_quant_arg_upstream(input0) + input1_q_params = search_quant_arg_upstream(input1) + output_q_params = search_quant_arg_downstream(list(node.users)[0]) final_output_scale = ( input0_q_params.scale * input1_q_params.scale diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py index c152e8759e..8d50756711 100644 --- a/backends/arm/operators/op_mul.py +++ b/backends/arm/operators/op_mul.py @@ -37,10 +37,10 @@ def define_node( if is_quant_node: input_A = inputs[0] input_B = inputs[1] - input_A_qargs = tqutils.get_quant_node_args( + input_A_qargs = tqutils.search_quant_arg_upstream( cast(torch.fx.Node, node.args[0]) ) - input_B_qargs = tqutils.get_quant_node_args( + input_B_qargs = tqutils.search_quant_arg_upstream( cast(torch.fx.Node, node.args[1]) ) diff --git a/backends/arm/operators/op_placeholder.py b/backends/arm/operators/op_placeholder.py index 2618c9e71d..00bebba09d 100644 --- a/backends/arm/operators/op_placeholder.py +++ b/backends/arm/operators/op_placeholder.py @@ -10,13 +10,14 @@ import torch.fx from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_arg_dtype, - get_quant_node_args, - is_quant_arg, + get_quantized_node_output_dtype, + is_node_quantized, + search_quant_arg_upstream, ) from executorch.backends.arm.tosa_utils import ( is_bias_node_for_quantized_addmm, is_bias_node_for_quantized_conv, + map_dtype, tosa_shape, ) from executorch.exir.dialects._ops import ops as exir_ops @@ -41,7 +42,11 @@ def process_inputs( tensor = ts.TosaSerializerTensor( inputs[0].name, tosa_shape(input_shape, input_dim_order), - get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype, + ( + map_dtype(get_quantized_node_output_dtype(node)) + if is_node_quantized(node) + else inputs[0].dtype + ), data=None, placeholderFilename=inputs[0].name + ".npy", ) @@ -75,8 +80,8 @@ def process_quantized_bias( _, ) = consumer_node.all_input_nodes - input_node_scale = get_quant_node_args(input_node).scale - weight_node_scale = get_quant_node_args(weight_node).scale + input_node_scale = search_quant_arg_upstream(input_node).scale + weight_node_scale = search_quant_arg_upstream(weight_node).scale bias_values_quantized = ( (parameter_values / (input_node_scale * weight_node_scale)) .round() diff --git a/backends/arm/operators/op_reciprocal.py b/backends/arm/operators/op_reciprocal.py index 3d43fd8f7d..051d8bf4d7 100644 --- a/backends/arm/operators/op_reciprocal.py +++ b/backends/arm/operators/op_reciprocal.py @@ -15,9 +15,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp @@ -41,8 +42,8 @@ def define_node( if is_quant_node: input = inputs[0] - input_qargs = get_quant_node_args(node.all_input_nodes[0]) - output_qargs = get_quant_node_args(list(node.users)[0]) + input_qargs = search_quant_arg_upstream(node.all_input_nodes[0]) + output_qargs = search_quant_arg_downstream(list(node.users)[0]) div_table = div_table_8bit(input_qargs, output_qargs) diff --git a/backends/arm/operators/op_relu.py b/backends/arm/operators/op_relu.py index 20bba3f654..afc2fd88d6 100644 --- a/backends/arm/operators/op_relu.py +++ b/backends/arm/operators/op_relu.py @@ -38,7 +38,7 @@ def define_node( clamp_min_qs = 0 clamp_max_qs = 0 if is_quant_node: - out_qargs = tqutils.get_quant_node_args(list(node.users)[0]) + out_qargs = tqutils.search_quant_arg_downstream(list(node.users)[0]) clamp_min_qs = tqutils.quantize_value(0, out_qargs) clamp_max_qs = tqutils.quantize_value(float("inf"), out_qargs) diff --git a/backends/arm/operators/op_rsqrt.py b/backends/arm/operators/op_rsqrt.py index 9225c7d938..d256a1c633 100644 --- a/backends/arm/operators/op_rsqrt.py +++ b/backends/arm/operators/op_rsqrt.py @@ -16,9 +16,10 @@ from executorch.backends.arm.tosa_mapping import TosaArg from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp @@ -39,9 +40,9 @@ def define_node( # Assume quantized input is 8 bit. # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = search_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = search_quant_arg_downstream(output_node) table = rsqrt_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() table_attr.TableAttribute(table) diff --git a/backends/arm/operators/op_sigmoid.py b/backends/arm/operators/op_sigmoid.py index 0087b1f7a8..d0e321f6fd 100644 --- a/backends/arm/operators/op_sigmoid.py +++ b/backends/arm/operators/op_sigmoid.py @@ -17,9 +17,10 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = search_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = search_quant_arg_downstream(output_node) table = sigmoid_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/operators/op_tanh.py b/backends/arm/operators/op_tanh.py index 20f343a7f1..7a556a5379 100644 --- a/backends/arm/operators/op_tanh.py +++ b/backends/arm/operators/op_tanh.py @@ -17,9 +17,10 @@ from executorch.backends.arm.tosa_quant_utils import ( dequantize_value, - get_quant_node_args, QuantArgs, quantize_value, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from serializer.tosa_serializer import TosaOp from torch.fx import Node @@ -49,9 +50,9 @@ def define_node( # Create attribute for 8 bit table lookup. input_node = node.all_input_nodes[0] - in_quantargs = get_quant_node_args(input_node) + in_quantargs = search_quant_arg_upstream(input_node) output_node = list(node.users)[0] - out_quantargs = get_quant_node_args(output_node) + out_quantargs = search_quant_arg_downstream(output_node) table = tanh_table_8bit(in_quantargs, out_quantargs) table_attr = ts.TosaSerializerAttribute() diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index 840586488b..a2445f26c0 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -3,7 +3,6 @@ load("@fbcode_macros//build_defs:python_library.bzl", "python_library") python_library( name = "arm_quantizer", srcs = ["arm_quantizer.py"], - typing = True, deps = [ ":arm_quantizer_utils", "//caffe2:torch", @@ -15,7 +14,6 @@ python_library( python_library( name = "quantization_config", srcs = ["quantization_config.py"], - typing = True, deps = [ "//caffe2:torch", ], @@ -24,7 +22,6 @@ python_library( python_library( name = "arm_quantizer_utils", srcs = ["arm_quantizer_utils.py"], - typing = True, deps = [ ":quantization_config", ], diff --git a/backends/arm/quantizer/quantization_annotation/generic_annotator.py b/backends/arm/quantizer/quantization_annotation/generic_annotator.py index f91df1398e..a35f5c0fda 100644 --- a/backends/arm/quantizer/quantization_annotation/generic_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/generic_annotator.py @@ -27,6 +27,9 @@ torch.ops.aten.unsqueeze.default, torch.ops.aten.unsqueeze_copy.default, torch.ops.aten.reshape.default, + torch.ops.aten.repeat.default, + torch.ops.aten.expand_copy.default, + torch.ops.aten.expand.default, # Disabling these as there seems to be an issue with support for complex # datatypes in torch: # torch.ops.aten.view_as_complex.default, diff --git a/backends/arm/quantizer/quantization_annotation/mm_annotator.py b/backends/arm/quantizer/quantization_annotation/mm_annotator.py index b48c6d5990..60d9adb1c3 100644 --- a/backends/arm/quantizer/quantization_annotation/mm_annotator.py +++ b/backends/arm/quantizer/quantization_annotation/mm_annotator.py @@ -24,7 +24,9 @@ def _annotate_mm( quantization_config: QuantizationConfig, filter_fn: Optional[Callable[[Node], bool]] = None, ) -> Optional[List[List[Node]]]: - mm_partitions = get_source_partitions(gm.graph, [torch.mm, torch.bmm], filter_fn) + mm_partitions = get_source_partitions( + gm.graph, [torch.mm, torch.bmm, torch.matmul], filter_fn + ) mm_partitions = list(itertools.chain.from_iterable(mm_partitions.values())) annotated_partitions = [] for mm_partition in mm_partitions: diff --git a/backends/arm/test/ops/test_bmm.py b/backends/arm/test/ops/test_bmm.py index e4e6abb7bb..a61cc7f1c8 100644 --- a/backends/arm/test/ops/test_bmm.py +++ b/backends/arm/test/ops/test_bmm.py @@ -32,6 +32,12 @@ class BMM(torch.nn.Module): def forward(self, x, y): return torch.bmm(x, y) + class MatMul(torch.nn.Module): + test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))] + + def forward(self, x, y): + return torch.matmul(x, y) + class BMMSingleInput(torch.nn.Module): test_parameters = [ (torch.rand(20, 3, 3),), @@ -53,9 +59,9 @@ def _test_bmm_tosa_MI_pipeline( compile_spec=common.get_tosa_compile_spec(), ) .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check_not(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -74,9 +80,9 @@ def _test_bmm_tosa_BI_pipeline( ) .quantize() .export() - .check_count({"torch.ops.aten.bmm.default": 1}) .check(["torch.ops.quantized_decomposed"]) .to_edge() + .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1}) .partition() .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"]) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) @@ -116,6 +122,16 @@ def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor): test_data = (operand1,) self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data) + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data) + + @parameterized.expand(MatMul.test_parameters) + def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): + test_data = (operand1, operand2) + self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data) + @parameterized.expand(BMM.test_parameters) def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor): test_data = (operand1, operand2) diff --git a/backends/arm/test/ops/test_linear.py b/backends/arm/test/ops/test_linear.py index 3f68ab0251..7d46354588 100644 --- a/backends/arm/test/ops/test_linear.py +++ b/backends/arm/test/ops/test_linear.py @@ -151,7 +151,7 @@ def _test_linear_tosa_BI_pipeline( .partition() .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() - .run_method_and_compare_outputs(inputs=test_data, qtol=True) + .run_method_and_compare_outputs(inputs=test_data, qtol=1) ) def _test_linear_tosa_ethosu_BI_pipeline( diff --git a/backends/arm/tosa_quant_utils.py b/backends/arm/tosa_quant_utils.py index fe408e41b3..d195c7f446 100644 --- a/backends/arm/tosa_quant_utils.py +++ b/backends/arm/tosa_quant_utils.py @@ -15,14 +15,31 @@ import serializer.tosa_serializer as ts import torch.fx import tosa.Op as TosaOp -from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg +from executorch.backends.arm.tosa_mapping import TosaArg from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaSerializerTensor from torch.fx import Node + q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default -dq_q_ops = [q_op, dq_op] +dq_q_ops = (q_op, dq_op) +passable_ops = [ + exir_ops.edge.aten.view_copy.default, + exir_ops.edge.aten.permute_copy.default, + exir_ops.edge.aten.squeeze_copy.dims, + exir_ops.edge.aten.unsqueeze_copy.default, + exir_ops.edge.aten.split_with_sizes_copy.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.aten.slice_copy.Tensor, + exir_ops.edge.aten.cat.default, +] + + +def register_passable_op(op): + """We need to be able to add custom ops such as tosa_transpose to the passable_op list after they have been created""" + passable_ops.append(op) class QuantArgs(NamedTuple): @@ -30,6 +47,19 @@ class QuantArgs(NamedTuple): zp: int qmin: int qmax: int + dtype: torch.dtype + + def quantize_value(self, x): + if not isinstance(x, torch.Tensor): + x = torch.Tensor([x]) + return torch.clip( + torch.round(x / self.scale) + self.zp, + self.qmin, + self.qmax, + ).to(self.dtype) + + def dequantize_value(self, qx): + return (qx - self.zp) * self.scale def quantize_value(x, qargs: QuantArgs, dtype=np.int8): @@ -44,81 +74,135 @@ def dequantize_value(qx, qargs: QuantArgs): return (qx - qargs.zp) * qargs.scale -def is_quant_node(node: torch.fx.Node): +def qargs_from_qnode(node: torch.fx.Node): + assert node.target in dq_q_ops, f"Op {node} is not a quant node." - consumer_node_condition = False - if len(list(node.users)) > 0: - consumer_node = list(node.users)[0] + return QuantArgs(*node.args[1:]) - # For Rank > 2 Linear layers, the quant node is after the view_copy - if ( - node.target == exir_ops.edge.aten.addmm.default - and consumer_node.target == exir_ops.edge.aten.view_copy.default - ): - consumer_consumer_node = list(consumer_node.users)[0] - return True if consumer_consumer_node.target == q_op else False - consumer_node_condition = consumer_node.target == q_op - input_node_condition = False - if len(node.all_input_nodes) > 0: - input = node.all_input_nodes[0] - input_node_condition = input.target in dq_q_ops +def get_neighbour_quant_args( + node: torch.fx.Node, +) -> tuple[list[QuantArgs], list[QuantArgs]]: + user_q_args = [] - return node.target in dq_q_ops or consumer_node_condition or input_node_condition + for user in node.users: + q_args = search_quant_arg_downstream(user) + if q_args: + user_q_args.append(q_args) + input_q_nodes = [] + for input_node in node.all_input_nodes: + q_args = search_quant_arg_upstream(input_node) + if q_args: + input_q_nodes.append(q_args) + return user_q_args, input_q_nodes -def get_quant_node_dtype(node: torch.fx.Node): - # pyre-ignore[16]: Undefined attribute. - if "tosa" in node.target.__name__: - return node.meta["val"].dtype - if node.target in dq_q_ops: - return node.args[5] - - # if not a tosa node, nor a q/dq op, walk the graph until we find a q op - consumer_node = list(node.users)[0] - while True: - if consumer_node.target in dq_q_ops: - return consumer_node.args[5] +def all_q_args_equal(q_arg_list: list[QuantArgs]) -> bool: + first_q_arg = q_arg_list[0] + for q_arg in q_arg_list: + if q_arg != first_q_arg: + return False + return True - # Try to move on to the next node - if len(consumer_node.users) == 0: - raise RuntimeError(f"No quantized node found in graph for node {node}") - consumer_node = list(consumer_node.users)[0] +def is_node_quantized(node: torch.fx.Node) -> bool: + if node.target in dq_q_ops: + return True -def is_quant_arg(arg): - consumer_node = list(arg.users)[0] - return consumer_node.target == q_op + user_q_args, input_q_args = get_neighbour_quant_args(node) + # If we did not find any neighbouring quant nodes, we are not quantized. + if len(input_q_args) == 0 and len(user_q_args) == 0: + return False -def get_quant_arg_dtype(node: torch.fx.Node): - consumer_node = list(node.users)[0] + if node.target in passable_ops: + assert all_q_args_equal( + user_q_args + input_q_args + ), f"Node {node} needs same quantization parameters on all inputs and outputs." - # Get type of quant node, args differ from per_tensor and per_channel. - if consumer_node.target == q_op: - if is_quant_arg(node): - return map_dtype(consumer_node.args[5]) - else: - raise RuntimeError("Quantization argument not found") + return True -def get_quant_node_args(node: torch.fx.Node): +def search_quant_arg_downstream(node: torch.fx.Node) -> QuantArgs | None: """ - Get the quantization parameters from a quant node. - - Args: - node: The quant node. - Returns: - QuantArgs: scale, zp, qmin, qmax + Iterates downward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple consumers is encountered, + find QuantArgs for all consumers and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without consumers is encountered, return None. """ - quant_args = [TosaArg(arg) for arg in node.args] - return QuantArgs( - quant_args[1].number, - quant_args[2].number, - quant_args[3].number, - quant_args[4].number, - ) + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + consumer_nodes = list(node.users) + if len(consumer_nodes) == 0: + return None + elif len(consumer_nodes) == 1: + return search_quant_arg_downstream(consumer_nodes[0]) + else: + consumer_qargs: list[QuantArgs] = [] + for input in consumer_nodes: + quant_args = search_quant_arg_downstream(input) + if quant_args: + consumer_qargs.append(quant_args) + if len(quant_args) == 0: + return None + assert all_q_args_equal( + consumer_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different consumers." + return consumer_qargs[0] + + +def search_quant_arg_upstream(node: torch.fx.Node) -> QuantArgs | None: + """ + Iterates upward in the graph passing through 'passable_ops' to find and return a quantization node, + starting with 'node'. + If a passable node with multiple inputs is encountered, + find QuantArgs for all inputs and assert that they are equal. + If a node not in passable_ops is encountered, return None. + If a node without inputs is encountered, return None. + """ + + if node.target in dq_q_ops: + return qargs_from_qnode(node) + if node.target not in passable_ops: + return None + input_nodes = list(node.all_input_nodes) + if len(input_nodes) == 0: + return None + elif len(input_nodes) == 1: + return search_quant_arg_upstream(input_nodes[0]) + else: + input_qargs: list[QuantArgs] = [] + for input in input_nodes: + quant_args = search_quant_arg_upstream(input) + if quant_args: + input_qargs.append(quant_args) + if len(quant_args) == 0: + return None + assert all_q_args_equal( + input_qargs + ), f"Encountered a op, {node}, in passable_ops with different QuantArgs for different inputs." + return input_qargs[0] + + +def get_quantized_node_output_dtype(node: torch.fx.Node): + if hasattr(node.target, "__name__") and "tosa" in node.target.__name__: + return node.meta["val"].dtype + if node.target in dq_q_ops: + return node.args[5] + + # if not a tosa node, nor a q/dq op, walk the graph until we find a q op + user_q_args, input_q_args = get_neighbour_quant_args(node) + if len(user_q_args) > 0: + return user_q_args[0].dtype + elif node.target in passable_ops and len(input_q_args): + return input_q_args[0].dtype + else: + raise RuntimeError("No quantized node found in graph") # Check if scale32 mode is used for given output element type @@ -267,14 +351,14 @@ def rescale_nodes_to_int32( needed by rescale_node_back_to_int8. """ - tensors = [TosaArg(node.args[0]) for node in nodes] + tensors = [TosaArg(node) for node in nodes] # Reshape tensor according to tosa dim order for tensor in tensors: dim_order = tensor.dim_order tensor.shape = [tensor.shape[i] for i in dim_order] - qargs = [get_quant_node_args(node) for node in nodes] + qargs = [search_quant_arg_upstream(node) for node in nodes] # Scale the int8 quantized input to a common scale in the integer # domain @@ -307,7 +391,7 @@ def rescale_node_back_to_int8( scale: the scaling factor used to rescale to int32, from the function 'rescale_nodes_to_int32' tosa_graph: the tosa_graph to manipulate. """ - qargs_out = get_quant_node_args(list(node.users)[0]) + qargs_out = search_quant_arg_downstream(list(node.users)[0]) output_rescale_scale = scale / qargs_out.scale # Rescale Back to INT8 @@ -334,7 +418,7 @@ def build_rescale_conv_output( output_zp, ): # TODO add check to verify if this is a Per-channel quantization. - post_conv2d_scale = (input_scale.number * weight_scale.number) / output_scale.number + post_conv2d_scale = (input_scale * weight_scale) / output_scale # Since we assume the input tensor that is being rescaled is int32 date type, zero point must be 0. build_rescale( @@ -345,6 +429,6 @@ def build_rescale_conv_output( output_type, op.shape, 0, - output_zp.number, + output_zp, ) return diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py index cfafac1676..b3e9f4e1c3 100644 --- a/backends/arm/tosa_utils.py +++ b/backends/arm/tosa_utils.py @@ -16,10 +16,11 @@ from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg from executorch.backends.arm.tosa_quant_utils import ( - get_quant_node_args, - get_quant_node_dtype, - is_quant_node, + get_quantized_node_output_dtype, + is_node_quantized, q_op, + search_quant_arg_downstream, + search_quant_arg_upstream, ) from executorch.exir.dialects._ops import ops as exir_ops from serializer.tosa_serializer import TosaOp @@ -237,8 +238,8 @@ def build_avg_pool_2d_common( output_zp = 0 if is_quant_node: - input_zp = get_quant_node_args(cast(torch.fx.Node, node.args[0])).zp - output_zp = get_quant_node_args(list(node.users)[0]).zp + input_zp = search_quant_arg_upstream(cast(torch.fx.Node, node.args[0])).zp + output_zp = search_quant_arg_downstream(list(node.users)[0]).zp attr = ts.TosaSerializerAttribute() attr.PoolAttribute( @@ -297,6 +298,11 @@ def process_call_function( # Convert output (this node itself) output = TosaArg(node) + is_quant_node = is_node_quantized(node) + if is_quant_node: + output_dtype = map_dtype(get_quantized_node_output_dtype(node)) + else: + output_dtype = output.dtype tosa_graph.currRegion.currBasicBlock.addTensor( output.name, ( @@ -304,7 +310,7 @@ def process_call_function( if is_permute_node_before_addmm(node) else tosa_shape(output.shape, output.dim_order) ), - map_dtype(get_quant_node_dtype(node)) if is_quant_node(node) else output.dtype, + output_dtype, ) # Visiting each Node @@ -316,7 +322,7 @@ def process_call_function( tosa_graph, inputs, output, - is_quant_node(node), + is_quant_node, ) else: raise RuntimeError(f"Unknown operator {node.target}") diff --git a/extension/android/build.gradle b/extension/android/build.gradle index b40f08e0c4..de243154d6 100644 --- a/extension/android/build.gradle +++ b/extension/android/build.gradle @@ -20,5 +20,6 @@ task makeJar(type: Jar) { dependencies { implementation 'com.facebook.fbjni:fbjni-java-only:0.2.2' implementation 'com.facebook.soloader:nativeloader:0.10.5' + testImplementation 'junit:junit:4.13.2' } } diff --git a/extension/android/src/main/java/org/pytorch/executorch/EValue.java b/extension/android/src/main/java/org/pytorch/executorch/EValue.java index 0065d80872..016b6a3e09 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/EValue.java +++ b/extension/android/src/main/java/org/pytorch/executorch/EValue.java @@ -61,7 +61,7 @@ public class EValue { "ListInt", "ListTensor", "ListScalar", - "ListOptionalScalar", + "ListOptionalTensor", }; @DoNotStrip private final int mTypeCode; @@ -267,6 +267,12 @@ public Tensor[] toTensorList() { return (Tensor[]) mData; } + @DoNotStrip + public Optional[] toOptionalTensorList() { + preconditionType(TYPE_CODE_LIST_OPTIONAL_TENSOR, mTypeCode); + return (Optional[]) mData; + } + private void preconditionType(int typeCodeExpected, int typeCode) { if (typeCode != typeCodeExpected) { throw new IllegalStateException( diff --git a/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java b/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java new file mode 100644 index 0000000000..35367883ef --- /dev/null +++ b/extension/android/src/test/java/org/pytorch/executorch/EValueTest.java @@ -0,0 +1,218 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.executorch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; + +import com.facebook.jni.annotations.DoNotStrip; + +import java.util.List; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Locale; +import java.util.Optional; + +import org.pytorch.executorch.Tensor.Tensor_int64; +import org.pytorch.executorch.annotations.Experimental; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link EValue}. */ +@RunWith(JUnit4.class) +public class EValueTest { + + @Test + public void testNone() { + EValue evalue = EValue.optionalNone(); + assertTrue(evalue.isNone()); + } + + @Test + public void testTensorValue() { + long[] data = {1, 2, 3}; + long[] shape = {1, 3}; + EValue evalue = EValue.from(Tensor.fromBlob(data, shape)); + assertTrue(evalue.isTensor()); + assertTrue(Arrays.equals(evalue.toTensor().shape, shape)); + assertTrue(Arrays.equals(evalue.toTensor().getDataAsLongArray(), data)); + } + + @Test + public void testBoolValue() { + EValue evalue = EValue.from(true); + assertTrue(evalue.isBool()); + assertTrue(evalue.toBool()); + } + + @Test + public void testIntValue() { + EValue evalue = EValue.from(1); + assertTrue(evalue.isInt()); + assertEquals(evalue.toInt(), 1); + } + + @Test + public void testDoubleValue() { + EValue evalue = EValue.from(0.1d); + assertTrue(evalue.isDouble()); + assertEquals(evalue.toDouble(), 0.1d, 0.0001d); + } + + @Test + public void testStringValue() { + EValue evalue = EValue.from("a"); + assertTrue(evalue.isString()); + assertEquals(evalue.toStr(), "a"); + } + + @Test + public void testBoolListValue() { + boolean[] value = {true, false, true}; + EValue evalue = EValue.listFrom(value); + assertTrue(evalue.isBoolList()); + assertTrue(Arrays.equals(value, evalue.toBoolList())); + } + + @Test + public void testIntListValue() { + long[] value = {Long.MIN_VALUE, 0, Long.MAX_VALUE}; + EValue evalue = EValue.listFrom(value); + assertTrue(evalue.isIntList()); + assertTrue(Arrays.equals(value, evalue.toIntList())); + } + + @Test + public void testDoubleListValue() { + double[] value = {Double.MIN_VALUE,0.1d, 0.01d, 0.001d, Double.MAX_VALUE}; + EValue evalue = EValue.listFrom(value); + assertTrue(evalue.isDoubleList()); + assertTrue(Arrays.equals(value, evalue.toDoubleList())); + } + + @Test + public void testTensorListValue() { + long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}}; + long[][] shape = {{1, 3}, {2, 3}}; + Tensor[] tensors = {Tensor.fromBlob(data[0], shape[0]), Tensor.fromBlob(data[1], shape[1])}; + + EValue evalue = EValue.listFrom(tensors); + assertTrue(evalue.isTensorList()); + + assertTrue(Arrays.equals(evalue.toTensorList()[0].shape, shape[0])); + assertTrue(Arrays.equals(evalue.toTensorList()[0].getDataAsLongArray(), data[0])); + + assertTrue(Arrays.equals(evalue.toTensorList()[1].shape, shape[1])); + assertTrue(Arrays.equals(evalue.toTensorList()[1].getDataAsLongArray(), data[1])); + } + + @Test + @SuppressWarnings("unchecked") + public void testOptionalTensorListValue() { + long[][] data = {{1, 2, 3}, {1, 2, 3, 4, 5, 6}}; + long[][] shape = {{1, 3}, {2, 3}}; + + EValue evalue = EValue.listFrom( + Optional.empty(), + Optional.of(Tensor.fromBlob(data[0], shape[0])), + Optional.of(Tensor.fromBlob(data[1], shape[1]))); + assertTrue(evalue.isOptionalTensorList()); + + assertTrue(evalue.toOptionalTensorList()[0].isEmpty()); + + assertTrue(evalue.toOptionalTensorList()[1].isPresent()); + assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().shape, shape[0])); + assertTrue(Arrays.equals(evalue.toOptionalTensorList()[1].get().getDataAsLongArray(), data[0])); + + assertTrue(evalue.toOptionalTensorList()[2].isPresent()); + assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().shape, shape[1])); + assertTrue(Arrays.equals(evalue.toOptionalTensorList()[2].get().getDataAsLongArray(), data[1])); + } + + @Test + public void testAllIllegalCast() { + EValue evalue = EValue.optionalNone(); + assertTrue(evalue.isNone()); + + // try Tensor + assertFalse(evalue.isTensor()); + try { + evalue.toTensor(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try bool + assertFalse(evalue.isBool()); + try { + evalue.toBool(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try int + assertFalse(evalue.isInt()); + try { + evalue.toInt(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try double + assertFalse(evalue.isDouble()); + try { + evalue.toDouble(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try string + assertFalse(evalue.isString()); + try { + evalue.toStr(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try bool list + assertFalse(evalue.isBoolList()); + try { + evalue.toBoolList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try int list + assertFalse(evalue.isIntList()); + try { + evalue.toIntList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try double list + assertFalse(evalue.isDoubleList()); + try { + evalue.toBool(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try Tensor list + assertFalse(evalue.isTensorList()); + try { + evalue.toTensorList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + + // try optional Tensor list + assertFalse(evalue.isOptionalTensorList()); + try { + evalue.toOptionalTensorList(); + fail("Should have thrown an exception"); + } catch (IllegalStateException e) {} + } +}