Skip to content

Commit 984fd10

Browse files
Merge pull request #18 from msubedar/main
Added support for arbitrary kernel sizes for Bayesian Conv layers
2 parents f6f516e + a8543ad commit 984fd10

File tree

4 files changed

+99
-89
lines changed

4 files changed

+99
-89
lines changed

bayesian_torch/layers/base_variational_layer.py

+6
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929
import torch
3030
import torch.nn as nn
3131
import torch.distributions as distributions
32+
from itertools import repeat
33+
import collections
3234

35+
def get_kernel_size(x, n):
36+
if isinstance(x, collections.abc.Iterable):
37+
return tuple(x)
38+
return tuple(repeat(x, n))
3339

3440
class BaseVariationalLayer_(nn.Module):
3541
def __init__(self):

bayesian_torch/layers/flipout_layers/conv_flipout.py

+46-44
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import torch
3737
import torch.nn as nn
3838
import torch.nn.functional as F
39-
from ..base_variational_layer import BaseVariationalLayer_
39+
from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size
4040

4141
from torch.distributions.normal import Normal
4242
from torch.distributions.uniform import Uniform
@@ -263,28 +263,28 @@ def __init__(self,
263263
self.bias = bias
264264

265265
self.kl = 0
266-
266+
kernel_size = get_kernel_size(kernel_size, 2)
267267
self.mu_kernel = nn.Parameter(
268-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
269-
kernel_size))
268+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
269+
kernel_size[1]))
270270
self.rho_kernel = nn.Parameter(
271-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
272-
kernel_size))
271+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
272+
kernel_size[1]))
273273

274274
self.register_buffer(
275275
'eps_kernel',
276-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
277-
kernel_size),
276+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
277+
kernel_size[1]),
278278
persistent=False)
279279
self.register_buffer(
280280
'prior_weight_mu',
281-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
282-
kernel_size),
281+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
282+
kernel_size[1]),
283283
persistent=False)
284284
self.register_buffer(
285285
'prior_weight_sigma',
286-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
287-
kernel_size),
286+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
287+
kernel_size[1]),
288288
persistent=False)
289289

290290
if self.bias:
@@ -430,27 +430,29 @@ def __init__(self,
430430
self.posterior_mu_init = posterior_mu_init
431431
self.posterior_rho_init = posterior_rho_init
432432

433+
kernel_size = get_kernel_size(kernel_size, 3)
434+
433435
self.mu_kernel = nn.Parameter(
434-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
435-
kernel_size, kernel_size))
436+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
437+
kernel_size[1], kernel_size[2]))
436438
self.rho_kernel = nn.Parameter(
437-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
438-
kernel_size, kernel_size))
439+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
440+
kernel_size[1], kernel_size[2]))
439441

440442
self.register_buffer(
441443
'eps_kernel',
442-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
443-
kernel_size, kernel_size),
444+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
445+
kernel_size[1], kernel_size[2]),
444446
persistent=False)
445447
self.register_buffer(
446448
'prior_weight_mu',
447-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
448-
kernel_size, kernel_size),
449+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
450+
kernel_size[1], kernel_size[2]),
449451
persistent=False)
450452
self.register_buffer(
451453
'prior_weight_sigma',
452-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
453-
kernel_size, kernel_size),
454+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
455+
kernel_size[1], kernel_size[2]),
454456
persistent=False)
455457

456458
if self.bias:
@@ -760,28 +762,28 @@ def __init__(self,
760762
self.prior_variance = prior_variance
761763
self.posterior_mu_init = posterior_mu_init
762764
self.posterior_rho_init = posterior_rho_init
763-
765+
kernel_size = get_kernel_size(kernel_size, 2)
764766
self.mu_kernel = nn.Parameter(
765-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
766-
kernel_size))
767+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
768+
kernel_size[1]))
767769
self.rho_kernel = nn.Parameter(
768-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
769-
kernel_size))
770+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
771+
kernel_size[1]))
770772

771773
self.register_buffer(
772774
'eps_kernel',
773-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
774-
kernel_size),
775+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
776+
kernel_size[1]),
775777
persistent=False)
776778
self.register_buffer(
777779
'prior_weight_mu',
778-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
779-
kernel_size),
780+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
781+
kernel_size[1]),
780782
persistent=False)
781783
self.register_buffer(
782784
'prior_weight_sigma',
783-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
784-
kernel_size),
785+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
786+
kernel_size[1]),
785787
persistent=False)
786788

787789
if self.bias:
@@ -928,28 +930,28 @@ def __init__(self,
928930
self.bias = bias
929931

930932
self.kl = 0
931-
933+
kernel_size = get_kernel_size(kernel_size, 3)
932934
self.mu_kernel = nn.Parameter(
933-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
934-
kernel_size, kernel_size))
935+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
936+
kernel_size[1], kernel_size[2]))
935937
self.rho_kernel = nn.Parameter(
936-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
937-
kernel_size, kernel_size))
938+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
939+
kernel_size[1], kernel_size[2]))
938940

939941
self.register_buffer(
940942
'eps_kernel',
941-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
942-
kernel_size, kernel_size),
943+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
944+
kernel_size[1], kernel_size[2]),
943945
persistent=False)
944946
self.register_buffer(
945947
'prior_weight_mu',
946-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
947-
kernel_size, kernel_size),
948+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
949+
kernel_size[1], kernel_size[2]),
948950
persistent=False)
949951
self.register_buffer(
950952
'prior_weight_sigma',
951-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
952-
kernel_size, kernel_size),
953+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
954+
kernel_size[1], kernel_size[2]),
953955
persistent=False)
954956

955957
if self.bias:

bayesian_torch/layers/variational_layers/conv_variational.py

+46-44
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
import torch.nn as nn
4747
import torch.nn.functional as F
4848
from torch.nn import Parameter
49-
from ..base_variational_layer import BaseVariationalLayer_
49+
from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size
5050
import math
5151

5252
__all__ = [
@@ -255,26 +255,28 @@ def __init__(self,
255255
self.posterior_rho_init = posterior_rho_init,
256256
self.bias = bias
257257

258+
kernel_size = get_kernel_size(kernel_size, 2)
259+
258260
self.mu_kernel = Parameter(
259-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
260-
kernel_size))
261+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
262+
kernel_size[1]))
261263
self.rho_kernel = Parameter(
262-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
263-
kernel_size))
264+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
265+
kernel_size[1]))
264266
self.register_buffer(
265267
'eps_kernel',
266-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
267-
kernel_size),
268+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
269+
kernel_size[1]),
268270
persistent=False)
269271
self.register_buffer(
270272
'prior_weight_mu',
271-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
272-
kernel_size),
273+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
274+
kernel_size[1]),
273275
persistent=False)
274276
self.register_buffer(
275277
'prior_weight_sigma',
276-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
277-
kernel_size),
278+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
279+
kernel_size[1]),
278280
persistent=False)
279281

280282
if self.bias:
@@ -403,27 +405,27 @@ def __init__(self,
403405
# variance of weight --> sigma = log (1 + exp(rho))
404406
self.posterior_rho_init = posterior_rho_init,
405407
self.bias = bias
406-
408+
kernel_size = get_kernel_size(kernel_size, 3)
407409
self.mu_kernel = Parameter(
408-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
409-
kernel_size, kernel_size))
410+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
411+
kernel_size[1], kernel_size[2]))
410412
self.rho_kernel = Parameter(
411-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
412-
kernel_size, kernel_size))
413+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
414+
kernel_size[1], kernel_size[2]))
413415
self.register_buffer(
414416
'eps_kernel',
415-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
416-
kernel_size, kernel_size),
417+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
418+
kernel_size[1], kernel_size[2]),
417419
persistent=False)
418420
self.register_buffer(
419421
'prior_weight_mu',
420-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
421-
kernel_size, kernel_size),
422+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
423+
kernel_size[1], kernel_size[2]),
422424
persistent=False)
423425
self.register_buffer(
424426
'prior_weight_sigma',
425-
torch.Tensor(out_channels, in_channels // groups, kernel_size,
426-
kernel_size, kernel_size),
427+
torch.Tensor(out_channels, in_channels // groups, kernel_size[0],
428+
kernel_size[1], kernel_size[2]),
427429
persistent=False)
428430

429431
if self.bias:
@@ -698,27 +700,27 @@ def __init__(self,
698700
# variance of weight --> sigma = log (1 + exp(rho))
699701
self.posterior_rho_init = posterior_rho_init,
700702
self.bias = bias
701-
703+
kernel_size = get_kernel_size(kernel_size, 2)
702704
self.mu_kernel = Parameter(
703-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
704-
kernel_size))
705+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
706+
kernel_size[1]))
705707
self.rho_kernel = Parameter(
706-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
707-
kernel_size))
708+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
709+
kernel_size[1]))
708710
self.register_buffer(
709711
'eps_kernel',
710-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
711-
kernel_size),
712+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
713+
kernel_size[1]),
712714
persistent=False)
713715
self.register_buffer(
714716
'prior_weight_mu',
715-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
716-
kernel_size),
717+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
718+
kernel_size[1]),
717719
persistent=False)
718720
self.register_buffer(
719721
'prior_weight_sigma',
720-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
721-
kernel_size),
722+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
723+
kernel_size[1]),
722724
persistent=False)
723725

724726
if self.bias:
@@ -850,27 +852,27 @@ def __init__(self,
850852
# variance of weight --> sigma = log (1 + exp(rho))
851853
self.posterior_rho_init = posterior_rho_init,
852854
self.bias = bias
853-
855+
kernel_size = get_kernel_size(kernel_size, 3)
854856
self.mu_kernel = Parameter(
855-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
856-
kernel_size, kernel_size))
857+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
858+
kernel_size[1], kernel_size[2]))
857859
self.rho_kernel = Parameter(
858-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
859-
kernel_size, kernel_size))
860+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
861+
kernel_size[1], kernel_size[2]))
860862
self.register_buffer(
861863
'eps_kernel',
862-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
863-
kernel_size, kernel_size),
864+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
865+
kernel_size[1], kernel_size[2]),
864866
persistent=False)
865867
self.register_buffer(
866868
'prior_weight_mu',
867-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
868-
kernel_size, kernel_size),
869+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
870+
kernel_size[1], kernel_size[2]),
869871
persistent=False)
870872
self.register_buffer(
871873
'prior_weight_sigma',
872-
torch.Tensor(in_channels, out_channels // groups, kernel_size,
873-
kernel_size, kernel_size),
874+
torch.Tensor(in_channels, out_channels // groups, kernel_size[0],
875+
kernel_size[1], kernel_size[2]),
874876
persistent=False)
875877

876878
if self.bias:

bayesian_torch/models/dnn_to_bnn.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def bnn_conv_layer(params, d):
7979
bnn_layer = layer_fn(
8080
in_channels=d.in_channels,
8181
out_channels=d.out_channels,
82-
kernel_size=d.kernel_size[0],
82+
kernel_size=d.kernel_size,
8383
stride=d.stride,
8484
padding=d.padding,
8585
dilation=d.dilation,

0 commit comments

Comments
 (0)