
Commit c3e9a0f

Include example for performing post-training quantization of Bayesian neural network models using the Bayesian-Torch quantization framework.

Signed-off-by: Ranganath Krishnan <[email protected]>
1 parent 52cea4f commit c3e9a0f

File tree

4 files changed: +129 −14 lines changed

README.md (+9 −1)

@@ -44,8 +44,9 @@ Bayesian-Torch is designed to be flexible and enables seamless extension of dete
 
 
 **Key features:**
-* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): An API to convert deterministic deep neural network (dnn) model of any architecture to Bayesian deep neural network (bnn) model, simplifying the model definition i.e. drop-in replacements of Convolutional, Linear and LSTM layers to corresponding Bayesian layers. This will enable seamless conversion of existing topology of larger models to Bayesian deep neural network models for extending towards uncertainty-aware applications.
+* [dnn_to_bnn()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/models/dnn_to_bnn.py#L127): Seamless conversion of a model to be uncertainty-aware. An API to convert a deterministic deep neural network (dnn) model of any architecture to a Bayesian deep neural network (bnn) model, simplifying the model definition with drop-in replacements of Convolutional, Linear and LSTM layers by the corresponding Bayesian layers. This enables seamless conversion of the existing topology of larger models to Bayesian deep neural network models for extension towards uncertainty-aware applications.
 * [MOPED](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/util.py#L72): Specifying weight priors and variational posteriors in Bayesian neural networks with Empirical Bayes [[Krishnan et al. 2020](https://ojs.aaai.org/index.php/AAAI/article/view/5875)]
+* [Quantization](https://github.com/IntelLabs/bayesian-torch/tree/main/bayesian_torch/ao): Post Training Quantization of Bayesian deep neural network models with the simple APIs [enable_prepare()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L134) and [convert()](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/ao/quantization/quantize.py#L160)
 * [AvUC](https://github.com/IntelLabs/bayesian-torch/blob/main/bayesian_torch/utils/avuc_loss.py): Accuracy versus Uncertainty Calibration loss [[Krishnan and Tickoo 2020](https://proceedings.neurips.cc/paper/2020/file/d3d9446802a44259755d38e6d163e820-Paper.pdf)]
 
 ## Installing Bayesian-Torch

@@ -198,6 +199,13 @@ To evaluate deterministic ResNet on CIFAR10, run this command:
 sh scripts/test_deterministic_cifar.sh
 ```
 
+### Post Training Quantization (PTQ)
+
+To quantize Bayesian ResNet (convert to INT8) and evaluate it on CIFAR10, run this command:
+```test
+sh scripts/quantize_bayesian_cifar.sh
+```
+
 ## Citing
 
 If you use this code, please cite as:
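For readers who want the PTQ flow outside the shell script, here is a minimal sketch mirroring the quantize() helper this commit adds in main_bayesian_cifar_dnn2bnn.py; `bnn_model` and `calib_loader` are placeholder names for a Bayesian model (e.g. produced by dnn_to_bnn()) and a small calibration DataLoader, not objects from the repo:

```python
import torch
from bayesian_torch.ao.quantization.quantize import enable_prepare, convert

def post_training_quantize(bnn_model, calib_loader):
    bnn_model.eval()
    bnn_model.cpu()
    # Same backend qconfig the commit sets in quantize.py
    bnn_model.qconfig = torch.quantization.get_default_qconfig("onednn")
    enable_prepare(bnn_model)                        # readies Bayesian layers for prepare()
    prepared = torch.quantization.prepare(bnn_model) # inserts observers
    with torch.no_grad():                            # calibrate activation ranges
        for data, _ in calib_loader:
            prepared(data.cpu())
    return convert(prepared)                         # INT8 torch + quantized Bayesian layers
```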

bayesian_torch/ao/quantization/quantize.py (+2 −2)

@@ -152,12 +152,12 @@ def prepare(model):
     qmodel.load_state_dict(model.state_dict())
     qmodel.eval()
     enable_prepare(qmodel)
-    qmodel.qconfig = torch.quantization.get_default_qconfig("fbgemm")
+    qmodel.qconfig = torch.quantization.get_default_qconfig("onednn")
     qmodel = torch.quantization.prepare(qmodel)
 
     return qmodel
 
 def convert(model):
     qmodel = torch.quantization.convert(model)  # torch layers
     bnn_to_qbnn(qmodel)  # bayesian layers
-    return qmodel
+    return qmodel
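The qconfig switch from "fbgemm" to "onednn" changes which quantized kernel backend the qconfig targets. A hedged sketch of how a caller might align the global quantized engine with it, assuming a PyTorch build that lists onednn among its supported engines:

```python
import torch

# Assumption: this PyTorch build supports the onednn quantized engine.
print(torch.backends.quantized.supported_engines)  # e.g. ['none', 'onednn', 'fbgemm']
if "onednn" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "onednn"  # match get_default_qconfig("onednn")
```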

bayesian_torch/examples/main_bayesian_cifar_dnn2bnn.py (+98 −2)

@@ -2,21 +2,25 @@
 import os
 import shutil
 import time
-
+import random
 import torch
 import torch.nn as nn
 import torch.nn.parallel
 import torch.backends.cudnn as cudnn
 import torch.optim
 import torch.utils.data
 from torch.utils.tensorboard import SummaryWriter
+from torch.utils.data.sampler import SubsetRandomSampler
 import torchvision.transforms as transforms
 import torchvision.datasets as datasets
 
 import bayesian_torch.models.deterministic.resnet as resnet
 import numpy as np
 from bayesian_torch.models.dnn_to_bnn import dnn_to_bnn, get_kl_loss
 
+from bayesian_torch.ao.quantization.quantize import enable_prepare, convert
+from bayesian_torch.models.bnn_to_qbnn import bnn_to_qbnn
+
 model_names = sorted(
     name
     for name in resnet.__dict__

@@ -59,6 +63,13 @@
     default="./checkpoint/bayesian",
     type=str,
 )
+parser.add_argument(
+    "--model-checkpoint",
+    dest="model_checkpoint",
+    help="Saved checkpoint for evaluating model",
+    default="",
+    type=str,
+)
 parser.add_argument(
     "--moped-init-model",
     dest="moped_init_model",

@@ -97,7 +108,7 @@
     type=int,
     default=10,
 )
-parser.add_argument("--mode", type=str, required=True, help="train | test")
+parser.add_argument("--mode", type=str, required=True, help="train | test | ptq | test_ptq")
 
 parser.add_argument(
     "--num_monte_carlo",

@@ -221,6 +232,25 @@ def main():
         pin_memory=True,
     )
 
+    calib_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10(
+            root="./data",
+            train=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    normalize,
+                ]
+            ),
+            download=True,
+        ),
+        batch_size=args.batch_size,
+        sampler=SubsetRandomSampler(random.sample(range(1, 50000), 100)),
+        num_workers=args.workers,
+        pin_memory=True,
+    )
+
     if not os.path.exists(args.save_dir):
         os.makedirs(args.save_dir)

@@ -286,6 +316,57 @@ def main():
         model.load_state_dict(checkpoint["state_dict"])
         evaluate(args, model, val_loader)
 
+    elif args.mode == "ptq":
+        if len(args.model_checkpoint) > 0:
+            checkpoint_file = args.model_checkpoint
+        else:
+            print("please provide valid model-checkpoint")
+        checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
+
+        '''
+        state_dict = checkpoint['state_dict']
+        new_state_dict = OrderedDict()
+        for k, v in state_dict.items():
+            name = k[7:]  # remove `module.`
+            new_state_dict[name] = v
+        print('load checkpoint...')
+        '''
+        model.load_state_dict(checkpoint['state_dict'])
+
+        # post-training quantization
+        model_int8 = quantize(model, calib_loader, args)
+        model_int8.eval()
+        model_int8.cpu()
+
+        for i, (data, target) in enumerate(calib_loader):
+            data = data.cpu()
+
+        with torch.no_grad():
+            traced_model = torch.jit.trace(model_int8, data)
+            traced_model = torch.jit.freeze(traced_model)
+
+        save_path = os.path.join(
+            args.save_dir,
+            'quantized_bayesian_{}_cifar.pth'.format(args.arch))
+        traced_model.save(save_path)
+        print('INT8 model checkpoint saved at ', save_path)
+        print('Evaluating quantized INT8 model....')
+        evaluate(args, traced_model, val_loader)
+
+    elif args.mode == 'test_ptq':
+        print('load model...')
+        if len(args.model_checkpoint) > 0:
+            checkpoint_file = args.model_checkpoint
+        else:
+            print("please provide valid quantized model checkpoint")
+        model_int8 = torch.jit.load(checkpoint_file)
+        model_int8.eval()
+        model_int8.cpu()
+        model_int8 = torch.jit.freeze(model_int8)
+        print('Evaluating the INT8 model....')
+        evaluate(args, model_int8, val_loader)
 
 def train(args, train_loader, model, criterion, optimizer, epoch, tb_writer=None):
     batch_time = AverageMeter()

@@ -482,6 +563,21 @@ def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
     """
     torch.save(state, filename)
 
+def quantize(model, calib_loader, args, **kwargs):
+    model.eval()
+    model.cpu()
+    model.qconfig = torch.quantization.get_default_qconfig("onednn")
+    print('Preparing model for quantization....')
+    enable_prepare(model)
+    prepared_model = torch.quantization.prepare(model)
+    print('Calibrating...')
+    with torch.no_grad():
+        for batch_idx, (data, target) in enumerate(calib_loader):
+            data = data.cpu()
+            _ = prepared_model(data)
+    print('Calibration complete....')
+    quantized_model = convert(prepared_model)
+    return quantized_model
 
 class AverageMeter(object):
     """Computes and stores the average and current value"""

bayesian_torch/models/deterministic/resnet.py (+20 −9)

@@ -43,19 +43,22 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
                                padding=1,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = nn.ReLU(inplace=True)
         self.conv2 = nn.Conv2d(planes,
                                planes,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
         self.bn2 = nn.BatchNorm2d(planes)
-
+        self.skip_add = nn.quantized.FloatFunctional()
         self.shortcut = nn.Sequential()
+        self.relu2 = nn.ReLU(inplace=True)
+
         if stride != 1 or in_planes != planes:
             if option == 'A':
                 self.shortcut = LambdaLayer(lambda x: F.pad(
-                    x[:, :, ::2, ::2],
+                    x[:, :, ::2, ::2].contiguous(),
                     (0, 0, 0, 0, planes // 4, planes // 4), "constant", 0))
             elif option == 'B':
                 self.shortcut = nn.Sequential(

@@ -67,12 +70,18 @@ def __init__(self, in_planes, planes, stride=1, option='A'):
                     nn.BatchNorm2d(self.expansion * planes))
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
-        out = self.bn2(self.conv2(out))
-        out += self.shortcut(x)
-        out = F.relu(out)
+        identity = self.shortcut(x)
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.skip_add.add(out, identity)
+        # out += self.shortcut(x)
+        out = self.relu2(out)
         return out
 
 
 class ResNet(nn.Module):
     def __init__(self, block, num_blocks, num_classes=10):

@@ -86,6 +95,7 @@ def __init__(self, block, num_blocks, num_classes=10):
                                padding=1,
                                bias=False)
         self.bn1 = nn.BatchNorm2d(16)
+        self.relu1 = nn.ReLU(inplace=True)
         self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
         self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
         self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)

@@ -103,7 +113,9 @@ def _make_layer(self, block, planes, num_blocks, stride):
         return nn.Sequential(*layers)
 
     def forward(self, x):
-        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu1(out)
         out = self.layer1(out)
         out = self.layer2(out)
         out = self.layer3(out)

@@ -112,7 +124,6 @@ def forward(self, x):
         out = self.linear(out)
         return out
 
-
 def resnet20():
     return ResNet(BasicBlock, [3, 3, 3])
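The resnet.py rewrite follows the usual eager-mode quantization pattern: in-place `+=` and functional F.relu carry no observers, so the block uses nn.quantized.FloatFunctional and named nn.ReLU modules that quantization passes can instrument (and adds .contiguous() to the strided-slice shortcut). A minimal, self-contained sketch of the pattern; the module below is illustrative, not code from the repo:

```python
import torch
import torch.nn as nn

class QuantFriendlyResidual(nn.Module):
    """Illustrative residual-add block, not part of bayesian-torch."""
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)               # module, so observers can attach
        self.skip_add = nn.quantized.FloatFunctional()  # observable replacement for `+=`

    def forward(self, x):
        out = self.bn(self.conv(x))
        out = self.skip_add.add(out, x)  # instead of out += x
        return self.relu(out)

# Quick check that it runs in float mode:
m = QuantFriendlyResidual(8).eval()
print(m(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])
```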
