-
Notifications
You must be signed in to change notification settings - Fork 82
/
Copy pathinit_like_torch.py
22 lines (20 loc) · 1001 Bytes
/
init_like_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from chainer import links as L
import numpy as np
def _uniform_fill_like_torch(link, fan_in):
    """Overwrite ``link.W`` (and ``link.b`` if present) in place with U(-s, s).

    ``s = 1 / sqrt(fan_in)``. The weight is drawn before the bias so the
    sequence of ``np.random`` draws is the same as filling them one after the
    other by hand.

    Args:
        link: A link exposing ``W`` and ``b`` parameters; ``b`` may be None.
        fan_in: Number of inputs feeding one output unit.
    """
    stdv = 1 / np.sqrt(fan_in)
    weight = link.W.data
    # Slice assignment writes into the existing array, keeping its dtype.
    weight[:] = np.random.uniform(-stdv, stdv, size=weight.shape)
    if link.b is not None:
        bias = link.b.data
        bias[:] = np.random.uniform(-stdv, stdv, size=bias.shape)


def init_like_torch(link):
    """Re-initialize Linear / Convolution2D parameters to mimic torch's defaults.

    For each link yielded by ``link.links()`` that is an ``L.Linear`` or an
    ``L.Convolution2D``, the weight and (if present) the bias are overwritten
    in place with samples from ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``, drawn
    from numpy's global RNG. Other link types are left unchanged.

    fan_in is ``in_channels`` for ``Linear`` (W shaped (out, in)) and
    ``in_channels * kh * kw`` for ``Convolution2D`` (W shaped (out, in, kh, kw)).

    NOTE(review): parameters must already hold arrays; a link whose ``W.data``
    is still None (lazy initialization) would raise AttributeError here --
    confirm callers only pass fully-initialized models.

    Args:
        link: The root link whose sub-links are visited.

    Returns:
        None. Parameters are modified in place.
    """
    # TODO(muupan): Use chainer's initializers when it is merged
    for child in link.links():
        if isinstance(child, L.Linear):
            _, in_channels = child.W.data.shape
            fan_in = in_channels
        elif isinstance(child, L.Convolution2D):
            _, in_channels, kh, kw = child.W.data.shape
            fan_in = in_channels * kh * kw
        else:
            continue
        _uniform_fill_like_torch(child, fan_in)