Skip to content

Commit

Permalink
Support parsing sub-dicts in assign_params_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
Snowdar committed Jun 10, 2020
1 parent a757bdb commit 56a1348
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pytorch/launcher/runResnetXvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
"training":True, "extracted_embedding":"far",
"resnet_params":{
"head_conv":True, "head_conv_params":{"kernel_size":3, "stride":1, "padding":1},
"head_maxpool":True, "head_maxpool_params":{"kernel_size":3, "stride":2, "padding":1},
"head_maxpool":False, "head_maxpool_params":{"kernel_size":3, "stride":2, "padding":1},
"block":"BasicBlock",
"layers":[3, 4, 6, 3],
"planes":[32, 64, 128, 256],
Expand Down
16 changes: 10 additions & 6 deletions pytorch/libs/nnet/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def step(self, epoch, iter):


class GruAffine(torch.nn.Module):
"""xmuspeech (Author: LZ) 2020-02-05
A GRU affine component.
"""A GRU affine component.
Author: Zheng Li xmuspeech 2020-02-05
"""
def __init__(self, input_dim, output_dim):
super(GruAffine, self).__init__()
Expand Down Expand Up @@ -452,7 +452,7 @@ def extra_repr(self):

class AttentiveStatisticsPooling(torch.nn.Module):
""" An attentive statistics pooling layer according to [Okabe et al., 2018, "Attentive Statistics Pooling for Deep Speaker Embedding"]"""
def __init__(self, input_dim, hidden_size=64, context=[0], stddev=True, eps=1.0e-10):
def __init__(self, input_dim, hidden_size=64, context=[0], stddev=True, stddev_attention=False, eps=1.0e-10):
super(AttentiveStatisticsPooling, self).__init__()

self.stddev = stddev
Expand All @@ -464,7 +464,7 @@ def __init__(self, input_dim, hidden_size=64, context=[0], stddev=True, eps=1.0e
self.output_dim = input_dim

self.eps = eps

self.stddev_attention = stddev_attention
self.attention = AttentionAlphaComponent(input_dim, hidden_size, context)

def forward(self, inputs):
Expand All @@ -480,8 +480,12 @@ def forward(self, inputs):
mean = torch.sum(alpha * inputs, dim=2, keepdim=True)

if self.stddev :
var = torch.sum(alpha * inputs**2, dim=2, keepdim=True) - mean**2
std = torch.sqrt(var.clamp(min=self.eps))
if self.stddev_attention:
var = torch.sum(alpha * inputs**2, dim=2, keepdim=True) - mean**2
std = torch.sqrt(var.clamp(min=self.eps))
else:
var = torch.mean((inputs - mean)**2, dim=2)
std = torch.unsqueeze(torch.sqrt(var.clamp(min=self.eps)), dim=2)
return torch.cat((mean, std), dim=1)
else :
return mean
Expand Down
7 changes: 4 additions & 3 deletions pytorch/libs/nnet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def step(self, lambda_factor):
self.lambda_factor = lambda_factor

def extra_repr(self):
return '(~affine): ~(input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \
return '(~affine): (input_dim={input_dim}, num_targets={num_targets}, method={method}, double={double}, ' \
'margin={m}, s={s}, t={t}, feature_normalize={feature_normalize}, mhe_loss={mhe_loss}, mhe_w={mhe_w}, ' \
'eps={eps})'.format(**self.__dict__)

Expand All @@ -294,9 +294,10 @@ class CurricularMarginComponent(torch.nn.Module):
Reference: Huang, Yuge, Yuhan Wang, Ying Tai, Xiaoming Liu, Pengcheng Shen, Shaoxin Li, Jilin Li,
and Feiyue Huang. 2020. “CurricularFace: Adaptive Curriculum Learning Loss for Deep Face
Recognition.” ArXiv E-Prints arXiv:2004.00288.
Github: https://github.com/HuangYG123/CurricularFace.
Github: https://github.com/HuangYG123/CurricularFace. Note: the momentum used in that repository is wrong w.r.t.
the above paper. The momentum 't' should not increase so fast, so it has been corrected as follows.
"""
def __init__(self, momentum=0.99):
def __init__(self, momentum=0.01):
super(CurricularMarginComponent, self).__init__()
self.momentum = momentum
self.register_buffer('t', torch.zeros(1))
Expand Down
7 changes: 6 additions & 1 deletion pytorch/libs/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,12 @@ def assign_params_dict(default_params:dict, params:dict, force_check=False, supp
for k, v in default_params.items():
if k in params_keys:
if isinstance(v, type(params[k])):
default_params[k] = params[k]
if isinstance(v, dict):
# To parse a sub-dict.
sub_params = assign_params_dict(v, params[k], force_check, support_unknow)
default_params[k] = sub_params
else:
default_params[k] = params[k]
elif isinstance(v, float) and isinstance(params[k], int):
default_params[k] = params[k] * 1.0
elif v is None or params[k] is None:
Expand Down
2 changes: 1 addition & 1 deletion pytorch/model/resnet-xvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def init(self, inputs_dim, num_targets, aug_dropout=0., tail_dropout=0., trainin
## Params.
default_resnet_params = {
"head_conv":True, "head_conv_params":{"kernel_size":3, "stride":1, "padding":1},
"head_maxpool":True, "head_maxpool_params":{"kernel_size":3, "stride":1, "padding":1},
"head_maxpool":False, "head_maxpool_params":{"kernel_size":3, "stride":1, "padding":1},
"block":"BasicBlock",
"layers":[3, 4, 6, 3],
"planes":[32, 64, 128, 256], # a.k.a channels.
Expand Down
2 changes: 1 addition & 1 deletion pytorch/model/snowdar-xvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def init(self, inputs_dim, num_targets, extend=False, skip_connection=False,
SE=False, se_ratio=4,
tdnn_layer_params={},
tdnn6=True, tdnn7_params={},
attentive_pooling=False, attentive_pooling_params={"hidden_size":64},
attentive_pooling=False, attentive_pooling_params={"hidden_size":64, "stddev_attention":False},
LDE_pooling=False, LDE_pooling_params={"c_num":64, "nodes":128},
focal_loss=False, focal_loss_params={"gamma":2},
margin_loss=False, margin_loss_params={},
Expand Down
2 changes: 1 addition & 1 deletion removeUtt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ num=0
if [ "$list" != "" ];then
num=`echo "$list" | wc -l | awk '{print $1}'`
else
echo "Need to removing nothing. It means that your datadir will be recovered form bakeup if you used this script before."
echo "Need to remove nothing. It means that your datadir will be recovered form bakeup if you used this script before."
fi

echo -e "[`echo $list`] $num utts here will be removed."
Expand Down

0 comments on commit 56a1348

Please sign in to comment.