|
46 | 46 | import torch.nn as nn
|
47 | 47 | import torch.nn.functional as F
|
48 | 48 | from torch.nn import Parameter
|
49 |
| -from ..base_variational_layer import BaseVariationalLayer_ |
| 49 | +from ..base_variational_layer import BaseVariationalLayer_, get_kernel_size |
50 | 50 | import math
|
51 | 51 |
|
52 | 52 | __all__ = [
|
@@ -255,26 +255,28 @@ def __init__(self,
|
255 | 255 | self.posterior_rho_init = posterior_rho_init,
|
256 | 256 | self.bias = bias
|
257 | 257 |
|
| 258 | + kernel_size = get_kernel_size(kernel_size, 2) |
| 259 | + |
258 | 260 | self.mu_kernel = Parameter(
|
259 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
260 |
| - kernel_size)) |
| 261 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 262 | + kernel_size[1])) |
261 | 263 | self.rho_kernel = Parameter(
|
262 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
263 |
| - kernel_size)) |
| 264 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 265 | + kernel_size[1])) |
264 | 266 | self.register_buffer(
|
265 | 267 | 'eps_kernel',
|
266 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
267 |
| - kernel_size), |
| 268 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 269 | + kernel_size[1]), |
268 | 270 | persistent=False)
|
269 | 271 | self.register_buffer(
|
270 | 272 | 'prior_weight_mu',
|
271 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
272 |
| - kernel_size), |
| 273 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 274 | + kernel_size[1]), |
273 | 275 | persistent=False)
|
274 | 276 | self.register_buffer(
|
275 | 277 | 'prior_weight_sigma',
|
276 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
277 |
| - kernel_size), |
| 278 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 279 | + kernel_size[1]), |
278 | 280 | persistent=False)
|
279 | 281 |
|
280 | 282 | if self.bias:
|
@@ -403,27 +405,27 @@ def __init__(self,
|
403 | 405 | # variance of weight --> sigma = log (1 + exp(rho))
|
404 | 406 | self.posterior_rho_init = posterior_rho_init,
|
405 | 407 | self.bias = bias
|
406 |
| - |
| 408 | + kernel_size = get_kernel_size(kernel_size, 3) |
407 | 409 | self.mu_kernel = Parameter(
|
408 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
409 |
| - kernel_size, kernel_size)) |
| 410 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 411 | + kernel_size[1], kernel_size[2])) |
410 | 412 | self.rho_kernel = Parameter(
|
411 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
412 |
| - kernel_size, kernel_size)) |
| 413 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 414 | + kernel_size[1], kernel_size[2])) |
413 | 415 | self.register_buffer(
|
414 | 416 | 'eps_kernel',
|
415 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
416 |
| - kernel_size, kernel_size), |
| 417 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 418 | + kernel_size[1], kernel_size[2]), |
417 | 419 | persistent=False)
|
418 | 420 | self.register_buffer(
|
419 | 421 | 'prior_weight_mu',
|
420 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
421 |
| - kernel_size, kernel_size), |
| 422 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 423 | + kernel_size[1], kernel_size[2]), |
422 | 424 | persistent=False)
|
423 | 425 | self.register_buffer(
|
424 | 426 | 'prior_weight_sigma',
|
425 |
| - torch.Tensor(out_channels, in_channels // groups, kernel_size, |
426 |
| - kernel_size, kernel_size), |
| 427 | + torch.Tensor(out_channels, in_channels // groups, kernel_size[0], |
| 428 | + kernel_size[1], kernel_size[2]), |
427 | 429 | persistent=False)
|
428 | 430 |
|
429 | 431 | if self.bias:
|
@@ -698,27 +700,27 @@ def __init__(self,
|
698 | 700 | # variance of weight --> sigma = log (1 + exp(rho))
|
699 | 701 | self.posterior_rho_init = posterior_rho_init,
|
700 | 702 | self.bias = bias
|
701 |
| - |
| 703 | + kernel_size = get_kernel_size(kernel_size, 2) |
702 | 704 | self.mu_kernel = Parameter(
|
703 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
704 |
| - kernel_size)) |
| 705 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 706 | + kernel_size[1])) |
705 | 707 | self.rho_kernel = Parameter(
|
706 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
707 |
| - kernel_size)) |
| 708 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 709 | + kernel_size[1])) |
708 | 710 | self.register_buffer(
|
709 | 711 | 'eps_kernel',
|
710 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
711 |
| - kernel_size), |
| 712 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 713 | + kernel_size[1]), |
712 | 714 | persistent=False)
|
713 | 715 | self.register_buffer(
|
714 | 716 | 'prior_weight_mu',
|
715 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
716 |
| - kernel_size), |
| 717 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 718 | + kernel_size[1]), |
717 | 719 | persistent=False)
|
718 | 720 | self.register_buffer(
|
719 | 721 | 'prior_weight_sigma',
|
720 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
721 |
| - kernel_size), |
| 722 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 723 | + kernel_size[1]), |
722 | 724 | persistent=False)
|
723 | 725 |
|
724 | 726 | if self.bias:
|
@@ -850,27 +852,27 @@ def __init__(self,
|
850 | 852 | # variance of weight --> sigma = log (1 + exp(rho))
|
851 | 853 | self.posterior_rho_init = posterior_rho_init,
|
852 | 854 | self.bias = bias
|
853 |
| - |
| 855 | + kernel_size = get_kernel_size(kernel_size, 3) |
854 | 856 | self.mu_kernel = Parameter(
|
855 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
856 |
| - kernel_size, kernel_size)) |
| 857 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 858 | + kernel_size[1], kernel_size[2])) |
857 | 859 | self.rho_kernel = Parameter(
|
858 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
859 |
| - kernel_size, kernel_size)) |
| 860 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 861 | + kernel_size[1], kernel_size[2])) |
860 | 862 | self.register_buffer(
|
861 | 863 | 'eps_kernel',
|
862 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
863 |
| - kernel_size, kernel_size), |
| 864 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 865 | + kernel_size[1], kernel_size[2]), |
864 | 866 | persistent=False)
|
865 | 867 | self.register_buffer(
|
866 | 868 | 'prior_weight_mu',
|
867 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
868 |
| - kernel_size, kernel_size), |
| 869 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 870 | + kernel_size[1], kernel_size[2]), |
869 | 871 | persistent=False)
|
870 | 872 | self.register_buffer(
|
871 | 873 | 'prior_weight_sigma',
|
872 |
| - torch.Tensor(in_channels, out_channels // groups, kernel_size, |
873 |
| - kernel_size, kernel_size), |
| 874 | + torch.Tensor(in_channels, out_channels // groups, kernel_size[0], |
| 875 | + kernel_size[1], kernel_size[2]), |
874 | 876 | persistent=False)
|
875 | 877 |
|
876 | 878 | if self.bias:
|
|
0 commit comments