initializer.py
from typing import Tuple

import torch

import utils

# Registry mapping initializer names to initializer functions.
__INITIALIZER__ = {}


def register_initializer(name: str):
    """Decorator that registers an initializer function under `name`."""
    def wrapper(func):
        if __INITIALIZER__.get(name, None):
            raise NameError(f"Name {name} is already registered.")
        __INITIALIZER__[name] = func
        return func
    return wrapper


def get_initializer(name: str):
    """Look up an initializer previously registered under `name`."""
    if __INITIALIZER__.get(name, None) is None:
        raise NameError(f"Name {name} is not defined.")
    return __INITIALIZER__[name]
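
# Example usage of the registry (illustrative; the name and shape below are
# placeholders, not values prescribed by this module):
#   init_fn = get_initializer('gaussian')
#   z0 = init_fn((1, 256, 256))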


@register_initializer(name='gaussian')
def gaussian_initializer(shape: Tuple[int, ...]) -> torch.Tensor:
    """Standard Gaussian initialization of the given shape."""
    return torch.randn(shape)


@register_initializer(name='spectral')
def spectral_initializer(amplitude: torch.Tensor, power_iteration: int) -> torch.Tensor:
    """Spectral initialization: power iteration followed by intensity scaling."""
    intensity = amplitude ** 2
    # Random unit-norm starting vector on the same device as the measurements.
    z0 = torch.randn(amplitude.shape, device=amplitude.device)
    z0 = z0 / torch.norm(z0)
    z0 = power_method_for_spectral_init(z0, power_iteration)
    # Scale the unit-norm eigenvector by the mean measured intensity.
    z0 *= torch.sqrt(intensity.mean())
    return z0


# =================
# Helper functions
# =================
def power_method_for_spectral_init(z0: torch.Tensor, iteration: int) -> torch.Tensor:
    """Power iteration: repeatedly apply the FFT-based operator and renormalize."""
    for _ in range(iteration):
        z0 = utils.ifft2d(utils.fft2d(z0)) * torch.numel(z0)
        z0 = z0 / torch.norm(z0)
    return z0
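

# ---------------------------------------------------------------------------
# Illustrative usage sketch. It assumes that `utils.fft2d` / `utils.ifft2d`
# are 2-D FFT wrappers (e.g. around torch.fft.fft2 / torch.fft.ifft2) and that
# `amplitude` holds measured magnitudes of a 2-D signal; both are assumptions
# about the `utils` module, not guarantees.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Gaussian initialization for a hypothetical 1x64x64 signal.
    gaussian_init = get_initializer('gaussian')
    z_gauss = gaussian_init((1, 64, 64))
    print("gaussian init:", z_gauss.shape, z_gauss.norm().item())

    # Spectral initialization from dummy amplitudes with a few power iterations;
    # real code would pass actual measured amplitudes.
    spectral_init = get_initializer('spectral')
    amplitude = torch.rand(1, 64, 64)
    z_spec = spectral_init(amplitude, power_iteration=20)
    print("spectral init:", z_spec.shape, z_spec.norm().item())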