fix: unused param due to dummy_param #113
jyaacoub committed Jul 17, 2024
1 parent cff0bab commit b0e5568
Showing 2 changed files with 24 additions and 6 deletions.
23 changes: 22 additions & 1 deletion playground.py
@@ -9,7 +9,28 @@
smi, lig = next(iter(torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_mol.pt').items()))
pid, pro = next(iter(torch.load('../data/DavisKibaDataset/davis/nomsa_binary_gvp_binary/test/data_pro.pt').items()))

m(pro,lig)
#%%
from copy import deepcopy
s0 = deepcopy(m.state_dict())

# %% train with single sample
from torch import nn
criterion = nn.MSELoss()
optim = torch.optim.Adam(m.parameters(), lr=1)

m.train()
loss = criterion(m(pro, lig), torch.tensor([1.0]))

optim.zero_grad()
loss.backward()
optim.step()
#%%
for k in m.state_dict():
    v1 = s0[k]
    v2 = m.state_dict()[k]
    if torch.allclose(v1, v2):
        print(k)



#%%
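For context, the playground hunk above is a quick unused-parameter probe: take one optimizer step with an intentionally large learning rate, then list every parameter whose weights still match the pre-step snapshot. Anything that passes torch.allclose never received a gradient, which is how the leftover dummy_param entries surface. Below is a minimal, self-contained sketch of the same check; ToyModel is a hypothetical stand-in for the repository model m and random data replaces the Davis/Kiba samples.

# Minimal sketch of the unused-parameter check above (ToyModel is a hypothetical
# stand-in for the repository model `m`; random inputs replace the real dataset).
import torch
from torch import nn
from copy import deepcopy

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 1)                       # participates in forward
        self.dummy_param = nn.Parameter(torch.empty(0))   # never used in forward

    def forward(self, x):
        return self.used(x).squeeze(-1)

m = ToyModel()
s0 = deepcopy(m.state_dict())                    # snapshot before the update

criterion = nn.MSELoss()
optim = torch.optim.Adam(m.parameters(), lr=1)   # large lr so every used weight moves

m.train()
loss = criterion(m(torch.randn(2, 4)), torch.tensor([1.0, 0.0]))
optim.zero_grad()
loss.backward()
optim.step()

# Parameters that are unchanged after the step never received a gradient,
# i.e. they are unused in the forward pass; dummy_param is reported here.
for k, v in m.state_dict().items():
    if torch.allclose(s0[k], v):
        print("unused:", k)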
7 changes: 2 additions & 5 deletions src/models/utils.py
@@ -156,7 +156,6 @@ def __init__(self, in_dims, out_dims, h_dim=None,
        self.ws = nn.Linear(self.si, self.so)

        self.scalar_act, self.vector_act = activations
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
@@ -187,7 +186,7 @@ def forward(self, x):
            s = self.ws(x)
            if self.vo: # vector dim is zero
                v = torch.zeros(s.shape[0], self.vo, 3,
                                device=self.dummy_param.device)
                                device=x.device)
        if self.scalar_act:
            s = self.scalar_act(s)

@@ -201,17 +200,15 @@ class _VDropout(nn.Module):
    def __init__(self, drop_rate):
        super(_VDropout, self).__init__()
        self.drop_rate = drop_rate
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        '''
        :param x: `torch.Tensor` corresponding to vector channels
        '''
        device = self.dummy_param.device
        if not self.training:
            return x
        mask = torch.bernoulli(
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=device)
            (1 - self.drop_rate) * torch.ones(x.shape[:-1], device=x.device)
        ).unsqueeze(-1)
        x = mask * x / (1 - self.drop_rate)
        return x
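The utils.py change itself swaps out the zero-element dummy_param, which existed only so a module could query its own device, for the device of the incoming tensor. A hedged before/after sketch of that pattern in a generic module (the class names below are made up for the example and are not part of the repository):

# Generic illustration of the device-lookup pattern this commit adopts;
# WithDummyParam / WithInputDevice are illustrative names, not repository code.
import torch
from torch import nn

class WithDummyParam(nn.Module):
    """Before: an empty parameter exists only so forward() can look up the
    module's device; it never receives gradients, so it reads as 'unused'."""
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(self, x):
        pad = torch.zeros(x.shape[0], 3, device=self.dummy_param.device)
        return self.lin(x), pad

class WithInputDevice(nn.Module):
    """After: take the device from the input tensor, so no placeholder
    parameter is needed and nothing is reported as unused."""
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        pad = torch.zeros(x.shape[0], 3, device=x.device)
        return self.lin(x), pad

Both versions behave the same on CPU or GPU; the second simply assumes the auxiliary tensor should live on the same device as the input, which is the assumption the commit makes for GVP and _VDropout.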
