Skip to content

Commit

Permalink
Merge pull request js-ish#14 from js-ish/feat-listrank
Browse files Browse the repository at this point in the history
improve learn to rank
  • Loading branch information
yandy authored Jun 13, 2024
2 parents ea089c5 + 290064b commit 9155056
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 38 deletions.
14 changes: 9 additions & 5 deletions dooc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,16 @@ def __call__(
Mutations(Individual Sample) and Smiles Interaction
MutSmiReg
MutSmis{Pair/List}
MutsSmi{Pair/List}
MutSmis{Pair/List}wiseRank
MutsSmi{Pair/List}wiseRank
"""


class MutSmiReg(_DrugcellAdamr2MutSmi):
pass


class MutSmisPairwise(_DrugcellAdamr2MutSmis):
class MutSmisPairwiseRank(_DrugcellAdamr2MutSmis):
def __call__(
self,
muts: typing.Sequence[list],
Expand All @@ -163,6 +163,10 @@ def __call__(
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
mut_x, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
out = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
return mut_x, smi_tgt, out


class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
pass
13 changes: 13 additions & 0 deletions dooc/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import torch
import torch.nn as nn


class ListNetLoss(nn.Module):
def __init__(self, reduction: str = 'mean') -> None:
super().__init__()
assert reduction in ['mean', 'sum']
self.reduction = reduction

def forward(self, predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
out = - (target.softmax(dim=-1) * predict.log_softmax(dim=-1))
return getattr(out, self.reduction)()
25 changes: 14 additions & 11 deletions dooc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
Mutations(Individual Sample) and Smiles Interaction
MutSmiReg
MutSmis{Pair/List}
MutsSmi{Pair/List}
MutSmisRank
MutsSmiRank
"""


Expand All @@ -20,22 +20,25 @@ def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CO
super().__init__(mut_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
return self.reg(super().forward(*args, **kwargs)) # [b, 1]
def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(mut_x, smi_tgt)) # [b, 1]


class MutSmisPairwise(dnets.DrugcellAdamr2MutSmisXattn):
class MutSmisRank(dnets.DrugcellAdamr2MutSmisXattn):

def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(mut_conf, smi_conf)
self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
return self.pairwise_rank(super().forward(*args, **kwargs)) # [b, 2]
def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(mut_x, smi_tgt)).squeeze(-1) # [b, n]

def forward_cmp(self, *args, **kwargs) -> float:
def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
"""
for infer, no batch dim
"""
out = self.forward(*args, **kwargs)
return (out[1] - out[0]).item()
assert mut_x.dim() == 1 and smi_tgt.dim() == 2
out = self.forward(mut_x, smi_tgt) # [2]
return (out[0] - out[1]).item()
19 changes: 0 additions & 19 deletions dooc/nets/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,3 @@ def __init__(self, in_features: int):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)


class PairwiseRankHead(nn.Module):
def __init__(self, d_features: int):
super().__init__()
self.mlp = nn.Sequential(
nn.Flatten(-2),
nn.Linear(d_features * 2, d_features),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_features, 2)
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
x: [b, 2, d_features]
"""
assert x.size(-2) == 2
return self.mlp(x) # [b, 2] 1: x1 > x2, 0: x1 <= x2
18 changes: 16 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ def test_MutSmiReg(smi_tkz):
assert out.shape == (2, 1)


def test_MutSmisPairwise(smi_tkz):
ds = datasets.MutSmisPairwise(smi_tkz)
def test_MutSmisPairwiseRank(smi_tkz):
ds = datasets.MutSmisPairwiseRank(smi_tkz)
lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
lvals = [[0.88, 0.89], [0.82, 0.9]]
muts = [[random.choice([0, 1]) for _ in range(52)],
Expand All @@ -87,3 +87,17 @@ def test_MutSmisPairwise(smi_tkz):
assert smi_tgt.shape == (2, 2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2,)


def test_MutSmisListwiseRank(smi_tkz):
ds = datasets.MutSmisListwiseRank(smi_tkz)
lsmis = [["CC[N+]CCBr", "Cc1ccc1", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1", "Cc1ccc1"]]
lvals = [[0.88, 0.89, 0.89], [0.82, 0.9, 0.9]]
muts = [[random.choice([0, 1]) for _ in range(52)],
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, lsmis, lvals[:1])
mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
assert smi_tgt.shape == (2, 3, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2, 3)
11 changes: 11 additions & 0 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch
from dooc import loss

def test_ListNetLoss():
predict = torch.randn(5, 3)
target = torch.randn(5, 3)
loss_mean = loss.ListNetLoss(reduction='mean')
mean = loss_mean(predict, target)
loss_sum = loss.ListNetLoss(reduction='sum')
sum = loss_sum(predict, target)
assert sum / 15 == mean
2 changes: 1 addition & 1 deletion tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Pointwise(pipelines._MutSmi, pipelines._MutSmisRank):
assert len(out) == 3
assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br"

model = models.MutSmisPairwise()
model = models.MutSmisRank()
pipeline = pipelines.MutSmisRank(smi_tokenizer=smi_tkz, model=model)
out = pipeline(mutation, smiles)
assert isinstance(out, list)
Expand Down

0 comments on commit 9155056

Please sign in to comment.