-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdomain_discriminator.py
56 lines (46 loc) · 1.85 KB
/
domain_discriminator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
@author: Junguang Jiang
@contact: [email protected]
"""
from typing import List, Dict
import torch.nn as nn
__all__ = ['DomainDiscriminator']
class DomainDiscriminator(nn.Sequential):
    r"""Domain discriminator model from
    `Domain-Adversarial Training of Neural Networks (ICML 2015) <https://arxiv.org/abs/1505.07818>`_

    Distinguish whether the input features come from the source domain or the target domain.
    The source domain label is 1 and the target domain label is 0.

    The network is a stack of two hidden ``Linear`` blocks followed by a
    ``Linear(hidden_size, 1)`` head and a ``Sigmoid``, so each output is a
    single value in :math:`[0, 1]`.

    Args:
        in_feature (int): dimension of the input feature
        hidden_size (int): dimension of the hidden features
        batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`.
            Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True.

    Shape:
        - Inputs: (minibatch, `in_feature`)
        - Outputs: :math:`(minibatch, 1)`
    """

    def __init__(self, in_feature: int, hidden_size: int, batch_norm: bool = True):
        layers = []
        # Two hidden blocks: the first maps in_feature -> hidden_size, the
        # second hidden_size -> hidden_size. Modules are created in the same
        # order as they appear in the final Sequential, so seeded weight
        # initialisation is unaffected by how the list is assembled.
        for fan_in in (in_feature, hidden_size):
            layers.append(nn.Linear(fan_in, hidden_size))
            if batch_norm:
                # Normalised variant: BatchNorm before a non-inplace ReLU.
                layers.append(nn.BatchNorm1d(hidden_size))
                layers.append(nn.ReLU())
            else:
                # Regularised variant: in-place ReLU followed by 50% dropout.
                layers.append(nn.ReLU(inplace=True))
                layers.append(nn.Dropout(0.5))
        # Binary head: one logit per sample squashed into [0, 1].
        layers.append(nn.Linear(hidden_size, 1))
        layers.append(nn.Sigmoid())
        super().__init__(*layers)

    def get_parameters(self) -> List[Dict]:
        """Return the parameters of this module as a list of optimizer parameter groups.

        Returns:
            A single-element list holding one dict with keys ``"params"`` (all
            parameters of this module) and ``"lr"`` (fixed to ``1.``).
            NOTE(review): ``"lr"`` is presumably a per-group learning-rate
            value/multiplier interpreted by the caller's optimizer setup —
            confirm against the training script.
        """
        # Materialise the parameters: ``self.parameters()`` is a one-shot
        # generator, so storing it directly would leave the group empty after
        # the first consumer (e.g. a parameter count) has iterated over it.
        return [{"params": list(self.parameters()), "lr": 1.}]