Add --quantize option to export_nanogpt.py
GregoryComer committed Apr 5, 2024
1 parent 3815a72 commit 7f0a908
Showing 2 changed files with 66 additions and 4 deletions.
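
This change lets the exporter apply dynamic, per-channel quantization before lowering to XNNPACK. A quick usage sketch (flag spellings taken from the argument parser in the diff below; the output path shown is the default):

    python nanogpt/export_nanogpt.py --quantize --output_file nanogpt.pte

Without --quantize, passing --backend xnnpack still lowers the float model to XNNPACK as before.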
3 changes: 2 additions & 1 deletion nanogpt/CMakeLists.txt
@@ -50,4 +50,5 @@ target_link_libraries(
   nanogpt_runner
   PRIVATE
   extension_module
-  portable_ops_lib)
+  portable_ops_lib
+  xnnpack_backend)
67 changes: 64 additions & 3 deletions nanogpt/export_nanogpt.py
@@ -9,7 +9,13 @@
 from typing import Optional
 
 import torch
-from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
+from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
+    DuplicateDynamicQuantChainPass,
+)
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackPartitioner,
+    XnnpackDynamicallyQuantizedPartitioner,
+)
 from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
 from executorch.exir import (
     EdgeCompileConfig,
@@ -25,10 +31,48 @@
 from test_utils import check_executorch_output_consistency, ErrorLimits
 from torch import nn
 from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.export import export
 
 from torch.nn.attention import SDPBackend
 
+# Number of tokens to generate. This balances generating a sentence with
+# sufficient information against keeping the computation time reasonable;
+# here it is set to 20 tokens.
+GENERATE_SEQ_LENGTH = 20
+
+
+# A wrapper class around the NanoGPT model, for demonstration purposes. Its
+# custom forward function generates a sequence of a specified length from a
+# given tokenized prompt in a single forward call. Note that this wrapper is
+# resource-intensive, since it generates the sequence in a Python loop; for
+# more efficient sequence generation, see the implementation in the llama runner.
+class NanoGPT(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.model = GPT.from_pretrained("gpt2")  # use the gpt2 weights as pretrained weights
+
+    def forward(self, idx):
+        for _ in range(GENERATE_SEQ_LENGTH):
+            # If the sequence context is growing too long, crop it at block_size.
+            idx_cond = (
+                idx
+                if idx.size(1) <= self.model.config.block_size
+                else idx[:, -self.model.config.block_size :]
+            )
+            # Forward the model to get the logits for the last index in the sequence.
+            logits, _ = self.model(idx_cond)
+            # Choose the highest-probability token as the next index in the sequence.
+            idx_next = torch.argmax(logits).view(1, 1)
+            # Append the chosen index to the running sequence and continue.
+            idx = torch.cat((idx, idx_next), dim=1)
+
+        return idx
+
 
 def main(args):
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Prep ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -41,7 +85,20 @@ def main(args):
     with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
         m = capture_pre_autograd_graph(model, example_inputs)
 
-    if args.backend == "XnnPack":
+    if args.quantize:
+        # Use dynamic, per-channel quantization.
+        xnnpack_quant_config = get_symmetric_quantization_config(
+            is_per_channel=True, is_dynamic=True
+        )
+        xnnpack_quantizer = XNNPACKQuantizer()
+        xnnpack_quantizer.set_global(xnnpack_quant_config)
+
+        m = prepare_pt2e(m, xnnpack_quantizer)
+        m(*example_inputs)
+        m = convert_pt2e(m, fold_quantize=False)
+        DuplicateDynamicQuantChainPass()(m)
+
+    if args.backend.lower() == "xnnpack":
         edge_config = get_xnnpack_edge_compile_config()
     else:
         edge_config = EdgeCompileConfig(_check_ir_validity=False)
@@ -56,6 +113,9 @@ def main(args):
     if args.backend == "xnnpack":
         print("Lowering to XnnPack...")
         edge_manager = edge_manager.to_backend(XnnpackPartitioner())
+    elif args.quantize:  # Note that using XnnpackPartitioner for everything should also work for quantization.
+        print("Lowering to XNNPACK (quantized)...")
+        edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())
 
     print("Creating ExecuTorch program...")
     et_program: ExecutorchProgramManager = edge_manager.to_executorch()
@@ -86,11 +146,12 @@ def main(args):
 if __name__ == "__main__":
     # Parse command-line arguments
     parser = argparse.ArgumentParser(description="NanoGPT example.")
-    parser.add_argument("--backend", type=str, choices=["XnnPack", None], default=None)
+    parser.add_argument("--backend", type=str.lower, choices=["xnnpack", None], default=None)
     parser.add_argument("--output_file", type=str, default="nanogpt.pte")
     parser.add_argument("--verify_runtime", action="store_true", default=False)
     parser.add_argument("--atol", type=float, default=None)
     parser.add_argument("--rtol", type=float, default=None)
+    parser.add_argument("--quantize", action="store_true", default=False)
 
     args = parser.parse_args()
     main(args)

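For reference, a minimal sketch of the new quantization path, assembled from the hunks above. The call names and their order come from the diff; the example prompt tensor and the standalone wiring around it are assumptions for illustration:

import torch
from executorch.backends.transforms.duplicate_dynamic_quant_chain import (
    DuplicateDynamicQuantChainPass,
)
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

model = NanoGPT()  # the wrapper defined in the diff above
example_inputs = (torch.zeros((1, 1), dtype=torch.long),)  # assumed placeholder prompt

# Capture the eager model as a graph before quantization.
m = capture_pre_autograd_graph(model, example_inputs)

# Configure dynamic, per-channel symmetric quantization for XNNPACK.
quantizer = XNNPACKQuantizer()
quantizer.set_global(
    get_symmetric_quantization_config(is_per_channel=True, is_dynamic=True)
)

# Insert observers, run the example inputs through to calibrate, then convert.
m = prepare_pt2e(m, quantizer)
m(*example_inputs)
m = convert_pt2e(m, fold_quantize=False)

# Duplicate the dynamic quantization chain so the XNNPACK partitioner can
# match it for each consumer, as done in main() above.
DuplicateDynamicQuantChainPass()(m)

From there, main() builds the edge program and delegates the quantized graph via XnnpackDynamicallyQuantizedPartitioner before serializing the .pte file.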