|
2 | 2 | import os
|
3 | 3 | import shutil
|
4 | 4 | import time
|
5 |
| - |
| 5 | +import random |
6 | 6 | import torch
|
7 | 7 | import torch.nn as nn
|
8 | 8 | import torch.nn.parallel
|
9 | 9 | import torch.backends.cudnn as cudnn
|
10 | 10 | import torch.optim
|
11 | 11 | import torch.utils.data
|
12 | 12 | from torch.utils.tensorboard import SummaryWriter
|
| 13 | +from torch.utils.data.sampler import SubsetRandomSampler |
13 | 14 | import torchvision.transforms as transforms
|
14 | 15 | import torchvision.datasets as datasets
|
15 | 16 |
|
16 | 17 | import bayesian_torch.models.deterministic.resnet as resnet
|
17 | 18 | import numpy as np
|
18 | 19 | from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
|
19 | 20 |
|
| 21 | +from bayesian_torch.ao.quantization.quantize import enable_prepare, convert |
| 22 | +from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn |
| 23 | + |
20 | 24 | model_names = sorted(
|
21 | 25 | name
|
22 | 26 | for name in resnet.__dict__
|
|
59 | 63 | default="./checkpoint/bayesian",
|
60 | 64 | type=str,
|
61 | 65 | )
|
| 66 | +parser.add_argument( |
| 67 | + "--model-checkpoint", |
| 68 | + dest="model_checkpoint", |
| 69 | + help="Saved checkpoint for evaluating model", |
| 70 | + default="", |
| 71 | + type=str, |
| 72 | +) |
62 | 73 | parser.add_argument(
|
63 | 74 | "--moped-init-model",
|
64 | 75 | dest="moped_init_model",
|
|
97 | 108 | type=int,
|
98 | 109 | default=10,
|
99 | 110 | )
|
100 |
| -parser.add_argument("--mode", type=str, required=True, help="train | test") |
| 111 | +parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq") |
101 | 112 |
|
102 | 113 | parser.add_argument(
|
103 | 114 | "--num_monte_carlo",
|
@@ -221,6 +232,25 @@ def main():
|
221 | 232 | pin_memory=True,
|
222 | 233 | )
|
223 | 234 |
|
| 235 | + calib_loader = torch.utils.data.DataLoader( |
| 236 | + datasets.CIFAR10( |
| 237 | + root="./data", |
| 238 | + train=True, |
| 239 | + transform=transforms.Compose( |
| 240 | + [ |
| 241 | + transforms.ToTensor(), |
| 242 | + normalize, |
| 243 | + ] |
| 244 | + ), |
| 245 | + download=True, |
| 246 | + ), |
| 247 | + batch_size=args.batch_size, |
| 248 | + sampler=SubsetRandomSampler(random.sample(range(1, 50000), 100)), |
| 249 | + num_workers=args.workers, |
| 250 | + pin_memory=True, |
| 251 | + ) |
| 252 | + |
| 253 | + |
224 | 254 | if not os.path.exists(args.save_dir):
|
225 | 255 | os.makedirs(args.save_dir)
|
226 | 256 |
|
@@ -286,6 +316,57 @@ def main():
|
286 | 316 | model.load_state_dict(checkpoint["state_dict"])
|
287 | 317 | evaluate(args, model, val_loader)
|
288 | 318 |
|
| 319 | + elif args.mode == "ptq": |
| 320 | + if len(args.model_checkpoint) > 0: |
| 321 | + checkpoint_file = args.model_checkpoint |
| 322 | + else: |
| 323 | + print("please provide valid model-checkpoint") |
| 324 | + checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu")) |
| 325 | + |
| 326 | + ''' |
| 327 | + state_dict = checkpoint['state_dict'] |
| 328 | + new_state_dict = OrderedDict() |
| 329 | + for k, v in state_dict.items(): |
| 330 | + name = k[7:] # remove `module.` |
| 331 | + new_state_dict[name] = v |
| 332 | + print('load checkpoint...') |
| 333 | + ''' |
| 334 | + model.load_state_dict(checkpoint['state_dict']) |
| 335 | + |
| 336 | + |
| 337 | + # post-training quantization |
| 338 | + model_int8 = quantize(model, calib_loader, args) |
| 339 | + model_int8.eval() |
| 340 | + model_int8.cpu() |
| 341 | + |
| 342 | + for i, (data, target) in enumerate(calib_loader): |
| 343 | + data = data.cpu() |
| 344 | + |
| 345 | + with torch.no_grad(): |
| 346 | + traced_model = torch.jit.trace(model_int8, data) |
| 347 | + traced_model = torch.jit.freeze(traced_model) |
| 348 | + |
| 349 | + save_path = os.path.join( |
| 350 | + args.save_dir, |
| 351 | + 'quantized_bayesian_{}_cifar.pth'.format(args.arch)) |
| 352 | + traced_model.save(save_path) |
| 353 | + print('INT8 model checkpoint saved at ', save_path) |
| 354 | + print('Evaluating quantized INT8 model....') |
| 355 | + evaluate(args, traced_model, val_loader) |
| 356 | + |
| 357 | + elif args.mode =='test_ptq': |
| 358 | + print('load model...') |
| 359 | + if len(args.model_checkpoint) > 0: |
| 360 | + checkpoint_file = args.model_checkpoint |
| 361 | + else: |
| 362 | + print("please provide valid quantized model checkpoint") |
| 363 | + model_int8 = torch.jit.load(checkpoint_file) |
| 364 | + model_int8.eval() |
| 365 | + model_int8.cpu() |
| 366 | + model_int8 = torch.jit.freeze(model_int8) |
| 367 | + print('Evaluating the INT8 model....') |
| 368 | + evaluate(args, model_int8, val_loader) |
| 369 | + |
289 | 370 |
|
290 | 371 | def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
|
291 | 372 | batch_time = AverageMeter()
|
@@ -482,6 +563,21 @@ def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
|
482 | 563 | """
|
483 | 564 | torch.save(state, filename)
|
484 | 565 |
|
| 566 | +def quantize(model, calib_loader, args, **kwargs): |
| 567 | + model.eval() |
| 568 | + model.cpu() |
| 569 | + model.qconfig = torch.quantization.get_default_qconfig("onednn") |
| 570 | + print('Preparing model for quantization....') |
| 571 | + enable_prepare(model) |
| 572 | + prepared_model = torch.quantization.prepare(model) |
| 573 | + print('Calibrating...') |
| 574 | + with torch.no_grad(): |
| 575 | + for batch_idx, (data, target) in enumerate(calib_loader): |
| 576 | + data = data.cpu() |
| 577 | + _ = prepared_model(data) |
| 578 | + print('Calibration complete....') |
| 579 | + quantized_model = convert(prepared_model) |
| 580 | + return quantized_model |
485 | 581 |
|
486 | 582 | class AverageMeter(object):
|
487 | 583 | """Computes and stores the average and current value"""
|
|
0 commit comments