Qualcomm AI Engine Direct - Optimize static llama phase 2
Differential Revision: D67755292

Pull Request resolved: #7466
shewu-quic authored Jan 8, 2025
1 parent 54f0786 commit 2e24b4e
Showing 70 changed files with 936 additions and 401 deletions.
90 changes: 55 additions & 35 deletions backends/qualcomm/_passes/annotate_quant_attrs.py
@@ -9,10 +9,16 @@
import torch
from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
from executorch.backends.qualcomm.utils.constants import (
QCOM_AXIS,
QCOM_DTYPE,
QCOM_ENCODING,
QCOM_QUANT_ATTRS,
QCOM_QUANT_MAX,
QCOM_QUANT_MIN,
QCOM_REQUANTIZE,
QCOM_SCALE,
QCOM_SCALES,
QCOM_ZERO_POINT,
QCOM_ZERO_POINTS,
)
from executorch.exir.dialects._ops import ops as exir_ops
@@ -52,60 +58,74 @@ def _expand(self, tensor, dim, axis) -> torch.Tensor:
order[axis], order[0] = order[0], order[axis]
return tensor.permute(order)

# Find the last dq node between regular op nodes
# Find the last dq nodes between regular op nodes
# Return dq2 in the example below when q1 is given as the node parameter:
# ... -> n1 -> q1 -> dq1 -> q2 -> dq2 -> n2 -> ...
def _find_last_dq_node(self, node: torch.fx.node.Node) -> torch.fx.node.Node:
if list(node.users)[0].target in q_ops.union(dq_ops):
return self._find_last_dq_node(list(node.users)[0])
return node
def _find_last_dq_nodes(self, node: torch.fx.node.Node) -> list:
if node is None:
return []

# If the node is the last dq before the next regular op node, return it in a list
if node.target in dq_ops:
if all(user.target not in q_ops for user in node.users):
return [node]

last_dq_nodes = []
for user in list(node.users):
last_dq_nodes.extend(self._find_last_dq_nodes(user))

return last_dq_nodes
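
# Illustrative trace (hypothetical node names): given the chain
#   n1 -> q1 -> dq1 -> q2 -> dq2a -> n2
#                         \-> dq2b -> n3
# _find_last_dq_nodes(q1) recurses through the q/dq users and returns
# [dq2a, dq2b], i.e. every dq whose users contain no further q op.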

def _annotate_requant(self, n):
# Record requant attributes:
# node1 -> q_ui8 -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
# We store quant info for dq_ui8 and q_int32 in node1.meta
# node1 -> q_ui8 (n) -> dq_ui8 -> q_int32 -> dq_int32 -> node2 -> ....
# We store {node2: quant_attr of dq_int32} in node1.meta
if n.target in q_ops and n.args[0].target not in dq_ops:
dq_node = self._find_last_dq_node(n)
dq_nodes = self._find_last_dq_nodes(n)
q_attrs = get_quant_attrs(self.edge_program, n)
dq_attrs = get_quant_attrs(self.edge_program, dq_node)

# TODO: Store multiple pairs of requantize attributes when we have an op builder
# that has multiple outputs that require quant attributes.
if self.skip_advanced_requant:
if q_attrs["dtype"] != dq_attrs["dtype"]:
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
else:
# When dtype is the same but other specs such as scale and offset are different,
# insert requant to improve accuracy.
# Users can turn this feature off if any inference speed drop is observed.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
"scale",
"zero_point",
"quant_min",
"quant_max",
"dtype",
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
n.args[0].meta[QCOM_REQUANTIZE] = dq_attrs
for dq_node in dq_nodes:
dq_attrs = get_quant_attrs(self.edge_program, dq_node)
# TODO: Store multiple pairs of requantize attributes when we have an op builder
# that has multiple outputs that require quant attributes.
if self.skip_advanced_requant:
if q_attrs[QCOM_DTYPE] != dq_attrs[QCOM_DTYPE]:
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs
else:
# When dtype is the same but other specs such as scale and offset are different,
# insert requant to improve accuracy.
# Users can turn this feature off if any inference speed drop is observed.
if any(
q_attrs[attr] != dq_attrs[attr]
for attr in [
QCOM_SCALE,
QCOM_ZERO_POINT,
QCOM_QUANT_MIN,
QCOM_QUANT_MAX,
QCOM_DTYPE,
]
):
dq_attrs[QCOM_ENCODING] = q_attrs[QCOM_ENCODING]
user_node = list(dq_node.users)[0]
n.args[0].meta.setdefault(QCOM_REQUANTIZE, {})
n.args[0].meta[QCOM_REQUANTIZE][user_node.name] = dq_attrs

# Dequant all the fold_quant parameters back to fp32.
# If an operation is not supported by QNN and falls back, it will expect an fp32 param.
def _dequant_fold_params(self, n, quant_attrs, param):
if quant_attrs[QCOM_ENCODING] in [
exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
]:
dim, axis = param.dim(), quant_attrs["axis"]
dim, axis = param.dim(), quant_attrs[QCOM_AXIS]
scales = self._expand(quant_attrs[QCOM_SCALES], dim, axis)
offsets = self._expand(quant_attrs[QCOM_ZERO_POINTS], dim, axis)
param = param.sub(offsets).mul(scales).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)
else:
scale = quant_attrs["scale"]
offset = quant_attrs["zero_point"]
scale = quant_attrs[QCOM_SCALE]
offset = quant_attrs[QCOM_ZERO_POINT]
param = param.sub(offset).mul(scale).to(torch.float32).contiguous()
set_parameter(param, n.args[0], self.edge_program)

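For reference, a minimal sketch (node names and attribute values hypothetical) of the per-consumer requantize record this pass now stores on the producer node, replacing the previous single-attribute entry:

```python
import torch

# Hypothetical sketch: the shape of node1.meta[QCOM_REQUANTIZE] after this
# pass -- one inner dict per consumer node, keyed by the consumer's name.
# (The QCOM_* constants resolve to the string keys shown below.)
requantize_record = {
    "aten_add_tensor": {       # consumer node name (hypothetical)
        "scale": 0.015,        # QCOM_SCALE
        "zero_point": 128,     # QCOM_ZERO_POINT
        "quant_min": 0,        # QCOM_QUANT_MIN
        "quant_max": 255,      # QCOM_QUANT_MAX
        "dtype": torch.uint8,  # QCOM_DTYPE
    },
}
```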
66 changes: 52 additions & 14 deletions backends/qualcomm/_passes/insert_requantize.py
@@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from typing import Dict, List

import torch

from executorch.backends.qualcomm.utils.constants import (
@@ -38,6 +41,42 @@ def __init__(
super(InsertRequantize, self).__init__()
self.edge_program = edge_program

def _make_hashable(self, value):
if isinstance(value, dict):
return tuple(sorted(value.items()))
return value

def _invert_dict(self, requantize_dict):
inverted_dict = defaultdict(list)
for user_node_name, quant_attr in requantize_dict.items():
hashable_quant_attr = self._make_hashable(quant_attr)
inverted_dict[hashable_quant_attr].append(user_node_name)
return inverted_dict

def _insert_to_copy(
self,
graph_module: torch.fx.GraphModule,
node: torch.fx.node,
quant_attr: Dict,
user_nodes: List[str],
):
with graph_module.graph.inserting_after(node):
users = list(node.users.keys())
inserted_n = graph_module.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(node,),
)
inserted_n.meta["val"] = node.meta["val"]
inserted_n.meta[QCOM_QUANT_ATTRS] = quant_attr

# propagate the quantized-io flag and rewire the selected users
if node.meta.get(QCOM_QUANTIZED_IO):
inserted_n.meta[QCOM_QUANTIZED_IO] = node.meta[QCOM_QUANTIZED_IO]

for user in filter(lambda u: u.name in user_nodes, users):
user.replace_input_with(node, inserted_n)

# TODO: Implement this function when we have an op with
# multiple outputs that require quant attributes.
def _multi_output_annotation(self) -> None:
@@ -46,21 +85,20 @@ def _multi_output_annotation(self) -> None:
def _single_output_annotation(
self, gm: torch.fx.GraphModule, n: torch.fx.node
) -> None:
with gm.graph.inserting_after(n):
users = list(n.users.keys())
inserted_n = gm.graph.create_node(
"call_function",
exir_ops.edge.aten._to_copy.default,
(n,),
)

inserted_n.meta["val"] = n.meta["val"]
inserted_n.meta[QCOM_QUANT_ATTRS] = n.meta.pop(QCOM_REQUANTIZE)
if n.meta.get(QCOM_QUANTIZED_IO):
inserted_n.meta[QCOM_QUANTIZED_IO] = n.meta[QCOM_QUANTIZED_IO]
# {user_node_name: quant_attr}
requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
# {quant_attr: user_node_name_list}
group_quant_attr_dict = self._invert_dict(requantize_dict)
# TODO: If the users of the node contain the output node,
# we replace the node with a to_copy op. However, this would
# be a problem when the node has multiple to_copy ops
add_output = len(group_quant_attr_dict) == 1

for user in users:
user.replace_input_with(n, inserted_n)
for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
user_nodes_copy = user_nodes.copy()
if add_output:
user_nodes_copy.append("output")
self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
for n in graph_module.graph.nodes:
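To make the grouping concrete, here is a runnable sketch (node names and attributes hypothetical) of the inversion `_single_output_annotation` relies on: consumers sharing identical quant attributes end up behind a single `_to_copy` node.

```python
from collections import defaultdict

# Hypothetical requantize_dict popped from a node's meta: consumer -> attrs.
requantize_dict = {
    "n2": {"scale": 0.02, "zero_point": 128},
    "n3": {"scale": 0.02, "zero_point": 128},  # same attrs as n2
    "n4": {"scale": 0.50, "zero_point": 0},    # different attrs
}

# Mirror of _invert_dict: make each attr dict hashable, then group consumers.
inverted = defaultdict(list)
for user, attrs in requantize_dict.items():
    inverted[tuple(sorted(attrs.items()))].append(user)

print(dict(inverted))
# {(('scale', 0.02), ('zero_point', 128)): ['n2', 'n3'],  -> one _to_copy
#  (('scale', 0.5), ('zero_point', 0)): ['n4']}           -> another _to_copy
```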
3 changes: 0 additions & 3 deletions backends/qualcomm/_passes/layout_transform.py
@@ -14,7 +14,6 @@
QCOM_INSERTED_PERMUTE,
QCOM_LAYOUT_CHANGE,
QCOM_QUANT_ATTRS,
QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
@@ -133,8 +132,6 @@ def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
# if dimension is not kept, we'll have no clue how to do layout transform
if len(node.args) < 3 or not node.args[2]:
return False
if node.target in self.qdq_opset:
return QCOM_REQUANTIZE in node.meta
return node.target in self.layout_agnostic_ops

def is_edge_condition(self, node):
15 changes: 8 additions & 7 deletions backends/qualcomm/builders/README.md
@@ -206,21 +206,21 @@ Now, we can start to fill in function body step by step:
input_tensor = self.get_tensor(input_node, node)
input_tensor_wrapper = self.define_tensor(
input_node,
node,
input_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=True,
)
```
Through the information in the [Check Operator Spec](#check-operator-spec) section, we can easily extract the desired nodes.<br/>
The `get_tensor` method is responsible for retrieving the torch tensor in the correct axis order if the `layout_transform` pass has been applied.<br/>
The `define_tensor` method generates the tensor object for the QNN API and memoizes it in the aforementioned `nodes_to_wrappers`.<br/>
A few arguments are worth explaining further:
- **node**: current graph node
- **tensor_source_node**: the graph node that is the source of the tensor
- **target_build_node**: the node currently being built, which is important for fixed-point mixed-precision to work properly
- **tensor**: torch tensor emitted by node
- **tensor_type**: type compatible with the QNN SDK; often `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
- **nodes_to_wrappers**: dictionary of graph nodes and their output tensors (note: the tensor here is not a torch tensor but a wrapped object for QNN)
- **is_input_tensor**: flag to tell if current tensor is input activation or parameter, which is important for fixed point mixed-precision to work properly
- **node_name**: (optional) tensor name the user may specify
- **wrapper_idx**: (optional) defaults to zero if the node is not a tuple; otherwise it acts as an index into the output tensors, e.g. when slicing an input tensor into multiple outputs, `wrapper_idx` is necessary for getting the correct wrapped tensor object (see the sketch below)
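
As a hedged illustration (node names hypothetical, following the signature documented above), fetching the second output of a multi-output node would pass `wrapper_idx`:

```python
# Hypothetical: split_node produces a tuple of outputs; wrapper_idx selects
# which wrapped output tensor to define / look up in nodes_to_wrappers.
second_out_wrapper = self.define_tensor(
    split_node,         # tensor_source_node: the node producing the tuple
    node,               # target_build_node: the op currently being built
    second_out_tensor,  # torch tensor for output index 1
    PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
    nodes_to_wrappers,
    wrapper_idx=1,      # selects the second wrapped output tensor
)
```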

@@ -230,23 +230,24 @@ Now, we can start to fill in function body step by step:
weight_tensor = get_parameter(weight_node, self.edge_program)
weight_tensor_wrapper = self.define_tensor(
weight_node,
node,
weight_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
bias_node = node.args[3]
bias_tensor = get_parameter(bias_node, self.edge_program)
bias_tensor_wrapper = self.define_tensor(
bias_node,
node,
bias_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
nodes_to_wrappers,
is_input_tensor=False,
)
```
The logic should be similar and straightforward. Please carefully set the `tensor_type` and `is_input_tensor` arguments according to each tensor's properties.
The logic should be similar and straightforward. Please carefully set the `tensor_type` argument
according to each tensor's properties.
3. Define parameters:
@@ -266,11 +267,11 @@

4. Define output tensor:
```python
output_tensor = self.get_tensor(node, node, 0)
output_tensor_wrapper = self.define_tensor(
node,
node,
output_tensor,
PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
nodes_to_wrappers,
is_input_tensor=False,
)
```
Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with the corresponding types `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`, users are still expected to use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic to the `define_tensor` method.
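
A minimal sketch of what that detection could look like (`is_graph_input` / `is_graph_output` come from `backends/qualcomm/builders/utils`; the exact branching below is an assumption for illustration):

```python
# Assumed sketch of the IO-type detection define_tensor performs internally.
def resolve_tensor_type(node, edge_program, requested_type):
    if is_graph_input(node, edge_program):
        # the application writes input data into this tensor
        return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_WRITE
    if is_graph_output(node):
        # the application reads results back from this tensor
        return PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_APP_READ
    return requested_type  # e.g. QNN_TENSOR_TYPE_NATIVE for intermediates
```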
39 changes: 23 additions & 16 deletions backends/qualcomm/builders/node_visitor.py
@@ -173,16 +173,19 @@ def make_qnn_per_tensor_config(self, quant_attrs: Dict):
)

def get_quant_encoding_conf(
self, node: torch.fx.Node, is_input_tensor: bool = False
self, node: torch.fx.Node, target_node: torch.fx.Node
) -> Tuple[Any, Dict]:
if not node.meta.get(QCOM_QUANT_ATTRS, None):
return (
PyQnnWrapper.Qnn_QuantizationEncoding_t.QNN_QUANTIZATION_ENCODING_UNDEFINED,
{},
)
is_input_tensor = node != target_node
quant_attrs = (
node.meta[QCOM_REQUANTIZE]
if QCOM_REQUANTIZE in node.meta and is_input_tensor
node.meta[QCOM_REQUANTIZE][target_node.name]
if QCOM_REQUANTIZE in node.meta
and is_input_tensor
and target_node.name in node.meta[QCOM_REQUANTIZE]
else node.meta[QCOM_QUANT_ATTRS]
)
if quant_attrs[QCOM_ENCODING] in PER_CHANNEL_ENCODING:
@@ -282,40 +285,44 @@ def define_custom_tensor_wrapper(

def define_tensor(
self,
node: torch.fx.Node,
tensor_source_node: torch.fx.Node,
target_build_node: torch.fx.Node,
tensor: torch.Tensor,
tensor_type: PyQnnWrapper.Qnn_TensorType_t,
nodes_to_wrappers: Dict[str, Dict[int, PyQnnWrapper.TensorWrapper]],
is_input_tensor: bool,
node_name: str = None,
wrapper_idx: int = 0,
) -> PyQnnWrapper.TensorWrapper:
"""
Convert torch.Tensor to TensorWrapper
Args:
node: EdgeIR Node
tensor_source_node: EdgeIR Node
target_build_node: Current node to build
tensor: EdgeIR Tensor
tensor_type: QNN tensor type
nodes_to_wrappers: dictionary mapping graph node names to their tensor wrappers
is_input_tensor: whether the tensor is a fake input tensor relative to
the op builder that is calling this function
"""
if node_name is None:
node_name = node.name
node_name = tensor_source_node.name

if cached := nodes_to_wrappers[node_name].get(wrapper_idx, None):
return cached

tensor_name = f"{node.name}_{wrapper_idx}"
if is_graph_input(node, self.edge_program):
tensor_name = "input_" + str(self.external_ids[node]) + "_" + tensor_name
if is_graph_output(node):
tensor_name = f"{tensor_source_node.name}_{wrapper_idx}"
if is_graph_input(tensor_source_node, self.edge_program):
tensor_name = (
"input_"
+ str(self.external_ids[tensor_source_node])
+ "_"
+ tensor_name
)
if is_graph_output(tensor_source_node):
tensor_name = "output_" + tensor_name
dims = [1] if len(tensor.size()) == 0 else tensor.size()
tensor_type = self.get_tensor_type(node, tensor_type)
tensor_type = self.get_tensor_type(tensor_source_node, tensor_type)
quant_encoding, quant_configs = self.get_quant_encoding_conf(
node, is_input_tensor
tensor_source_node, target_build_node
)
dtype = self.get_data_type(tensor, quant_configs)
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
@@ -334,7 +341,7 @@ def define_tensor(
if quant_configs:
tensor = self.get_quant_tensor_value(
tensor,
node.meta[QCOM_QUANT_ATTRS],
tensor_source_node.meta[QCOM_QUANT_ATTRS],
quant_configs,
)
tensor_wrapper = PyQnnWrapper.TensorWrapper(
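To summarize the key mechanism of this change, a hedged sketch of the attribute selection `get_quant_encoding_conf` now performs (a simplified mirror of the diff above, not the verbatim implementation):

```python
# Simplified mirror of get_quant_encoding_conf's new selection logic:
# when a tensor feeds a specific consumer (node != target_node), any
# consumer-specific requantize entry overrides the producer's quant attrs.
def pick_quant_attrs(node, target_node):
    is_input_tensor = node != target_node
    requant = node.meta.get(QCOM_REQUANTIZE, {})
    if is_input_tensor and target_node.name in requant:
        return requant[target_node.name]
    return node.meta[QCOM_QUANT_ATTRS]
```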