Skip to content

Commit

Permalink
Revert "Merge pull request js-ish#12 from origyZ/feat-adamr2"
Browse files Browse the repository at this point in the history
This reverts commit ea089c5, reversing
changes made to 574980b.
  • Loading branch information
zzzseeu committed Jun 12, 2024
1 parent ea089c5 commit f37c1d4
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 239 deletions.
55 changes: 9 additions & 46 deletions dooc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,53 +92,16 @@ def __call__(
return mut_x, smi_src, smi_tgt, out


class _DrugcellAdamr2Base(_DrugcellAdamrBase):
    """Dataset base for AdaMR2 models: SMILES are tokenized with BOS/EOS markers."""

    def _smi_tokens(
        self,
        smiles: typing.Sequence[str],
        seq_len: int = 200,
    ) -> torch.Tensor:
        """Tokenize ``smiles`` after wrapping each string in the tokenizer's BOS and EOS.

        ``seq_len`` is forwarded unchanged to ``_smi_tokenize``.
        """
        tokenizer = self.smi_tokenizer
        wrapped = []
        for smi in smiles:
            wrapped.append(f"{tokenizer.BOS}{smi}{tokenizer.EOS}")
        return self._smi_tokenize(wrapped, seq_len)
class _DrugcellAdamr2Base(_SmiMutBase):
    """Placeholder subclass of ``_SmiMutBase``; defines no members of its own."""


class _DrugcellAdamr2MutSmi(_DrugcellAdamr2Base):
    """Collate (mutation, SMILES, value) samples into tensors."""

    def __call__(
        self,
        muts: typing.Sequence[list],
        smis: typing.Sequence[str],
        vals: typing.Sequence[float],
        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build one batch from parallel sequences of samples.

        Args:
            muts: one mutation list per sample.
            smis: one SMILES string per sample.
            vals: one target value per sample.
            seq_len: forwarded to ``_smi_tokens``.

        Returns:
            ``(mut_x, smi_tgt, out)`` where ``out`` is ``self._out(vals)``
            with an extra trailing dimension.

        Raises:
            AssertionError: if ``muts``, ``smis`` and ``vals`` differ in length.
        """
        # Deliberately an assert: the test-suite expects AssertionError here.
        assert len(smis) == len(vals) and len(muts) == len(vals)
        mut_x = self._mut_tokens(muts)
        smi_tgt = self._smi_tokens(smis, seq_len)
        out = self._out(vals).unsqueeze(-1)
        return mut_x, smi_tgt, out
pass


class _DrugcellAdamr2MutSmis(_DrugcellAdamr2Base):
    """Collate samples that pair one mutation list with several SMILES/values."""

    def __call__(
        self,
        muts: typing.Sequence[list],
        lsmis: typing.Sequence[typing.Sequence[str]],
        lvals: typing.Sequence[typing.Sequence[float]],
        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Build one batch.

        Args:
            muts: [mut1, mut2, ...] mut1: [gene1, gene2, ...]
            lsmis: [[smi11, smi12], [smi21, smi22], ...]
            lvals: [[val11, val12], [val21, val22], ...]
            seq_len: forwarded to ``_smi_tokens``.

        Returns:
            ``(mut_x, smi_tgt, out)``; ``smi_tgt`` is reshaped to
            ``(len(lsmis), len(lsmis[0]), token_len)``.

        Raises:
            AssertionError: if ``muts``, ``lsmis`` and ``lvals`` differ in length.
        """
        # Deliberately an assert: the test-suite expects AssertionError here.
        assert len(lsmis) == len(lvals) and len(muts) == len(lvals)
        mut_x = self._mut_tokens(muts)
        batchlen = len(lsmis)
        # The list length is taken from the first sample; the reshape below
        # relies on every inner list having that same length.
        listlen = len(lsmis[0])
        smiles = [smi for bsmi in lsmis for smi in bsmi]
        smi_tgt = self._smi_tokens(smiles, seq_len)
        smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
        out = self._out(lvals)
        return mut_x, smi_tgt, out
class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base):
    """Placeholder subclass of ``_DrugcellAdamr2Base``; defines no members of its own."""


"""
Expand All @@ -150,19 +113,19 @@ def __call__(
"""


class MutSmiReg(_DrugcellAdamr2MutSmi):
class MutSmiReg(_DrugcellAdamrMutSmi):
pass


class MutSmisPairwise(_DrugcellAdamr2MutSmis):
class MutSmisPairwise(_DrugcellAdamrMutSmis):
def __call__(
self,
muts: typing.Sequence[list],
lsmiles: typing.Sequence[typing.Sequence[str]],
lvalues: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
mut_x, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
return mut_x, smi_tgt, out
return mut_x, smi_src, smi_tgt, out
15 changes: 7 additions & 8 deletions dooc/models.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import torch
from moltx import nets as mnets
from moltx import models as mmodels
from dooc import nets as dnets
from dooc.nets import heads, drugcell
from dooc.nets import heads


"""
Expand All @@ -14,20 +13,20 @@
"""


class MutSmiReg(dnets.DrugcellAdamr2MutSmiXattn):
class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn):

def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(mut_conf, smi_conf)
def __init__(self) -> None:
super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
return self.reg(super().forward(*args, **kwargs)) # [b, 1]


class MutSmisPairwise(dnets.DrugcellAdamr2MutSmisXattn):
class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn):

def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(mut_conf, smi_conf)
def __init__(self) -> None:
super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
Expand Down
96 changes: 8 additions & 88 deletions dooc/nets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import nn
from moltx.models import AdaMR, AdaMR2
from moltx.models import AdaMR
from dooc.nets.drugcell import Drugcell


Expand Down Expand Up @@ -115,97 +115,17 @@ def forward(
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]


class _DrugcellAdamr2(nn.Module):
class DrugcellAdamr2MutSmiAdd(nn.Module):
    """Placeholder ``nn.Module`` subclass; defines no members of its own."""

def __init__(self, mut_conf, smi_conf) -> None:
    """Store both configs and build the mutation and SMILES encoders from them."""
    super().__init__()
    self.mut_conf, self.smi_conf = mut_conf, smi_conf

    # Encoders are created in this order (mutation first, SMILES second).
    self.mut_encoder = Drugcell(mut_conf)
    self.smi_encoder = AdaMR2(smi_conf)

def load_ckpt(self, *ckpt_files: str) -> None:
    """Load this module's state dict from the first path in ``ckpt_files``.

    Any further paths are ignored; calling with no path raises ``IndexError``.
    Tensors are mapped to the CPU while loading.
    """
    # NOTE(review): torch.load deserializes via pickle by default on many
    # torch versions - only pass trusted checkpoint files.
    first_ckpt = ckpt_files[0]
    state = torch.load(first_ckpt, map_location=torch.device("cpu"))
    self.load_state_dict(state)

def load_pretrained_ckpt(
    self,
    mut_ckpt: str,
    smi_ckpt: str,
    freeze_mut: bool = False,
    freeze_smi: bool = False,
) -> None:
    """Load each encoder from its own checkpoint and optionally freeze it.

    ``freeze_mut`` / ``freeze_smi`` turn off gradients for the mutation /
    SMILES encoder respectively after loading.
    """
    self.mut_encoder.load_ckpt(mut_ckpt)
    self.smi_encoder.load_ckpt(smi_ckpt)
    to_freeze = ((freeze_smi, self.smi_encoder), (freeze_mut, self.mut_encoder))
    for should_freeze, encoder in to_freeze:
        if should_freeze:
            encoder.requires_grad_(False)

class DrugcellAdamr2MutSmiXattn(nn.Module):
    """Placeholder ``nn.Module`` subclass; defines no members of its own."""

class DrugcellAdamr2MutSmiAdd(_DrugcellAdamr2):
    """Combine the mutation and SMILES encoder outputs by summing them."""

    def forward(
        self, mut_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        """
        mut_x: [b, mut_seqlen]
        smi_tgt: [b, smi_seqlen]
        out: [b, dmodel]
        """
        mut_feat = self.mut_encoder(mut_x)
        smi_feat = self.smi_encoder.forward_feature(smi_tgt)
        summed = mut_feat + smi_feat
        return summed


class DrugcellAdamr2MutSmiXattn(_DrugcellAdamr2):
    """Combine SMILES and mutation features with a transformer decoder."""

    def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
        """Build the encoders (via the parent) plus the cross-attention decoder."""
        super().__init__(mut_conf, smi_conf)
        decoder_layer = nn.TransformerDecoderLayer(
            self.smi_conf.d_model, nhead, batch_first=True
        )
        self.cross_attn = nn.TransformerDecoder(decoder_layer, num_layers)

    def forward(
        self, mut_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        """
        mut_x: [b, mut_seqlen]
        smi_tgt: [b, smi_seqlen]
        out: [b, dmodel]
        """
        # Both features get a length-1 sequence axis: [b, 1, dmodel].
        mut_out = torch.unsqueeze(self.mut_encoder(mut_x), -2)
        smi_out = torch.unsqueeze(self.smi_encoder.forward_feature(smi_tgt), -2)
        # SMILES features are the decoder target, mutation features the memory.
        attended = self.cross_attn(smi_out, mut_out)
        return torch.squeeze(attended, -2)


class DrugcellAdamr2MutSmisAdd(_DrugcellAdamr2):
    """Sum one mutation feature with the features of several SMILES per sample."""

    def _forward_mut(self, mut_x: torch.Tensor) -> torch.Tensor:
        """
        mut_x: [b, mut_seqlen]
        out: [b, 1, dmodel]
        """
        mut_feat = self.mut_encoder(mut_x)
        return torch.unsqueeze(mut_feat, -2)

    def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor:
        """
        smi_tgt: [b, n, smi_seqlen]
        out: [b, n, dmodel]

        A non-3D input is passed straight to the SMILES encoder.
        """
        if smi_tgt.dim() != 3:
            return self.smi_encoder.forward_feature(smi_tgt)
        n = smi_tgt.shape[1]
        # Fold the per-sample list axis into the batch axis for the encoder,
        # then restore it on the encoder output.
        flat = smi_tgt.reshape(-1, smi_tgt.shape[-1])
        feats = self.smi_encoder.forward_feature(flat)
        return feats.reshape(-1, n, feats.shape[-1])

    def forward(
        self, mut_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        """Return ``smi_out + mut_out``: [b, n, dmodel]."""
        mut_out = self._forward_mut(mut_x)
        smi_out = self._forward_smi(smi_tgt)
        combined = smi_out + mut_out
        return combined
class DrugcellAdamr2MutSmisAdd(nn.Module):
    """Placeholder ``nn.Module`` subclass; defines no members of its own."""


class DrugcellAdamr2MutSmisXattn(DrugcellAdamr2MutSmisAdd):
    """Variant of the parent that fuses features with a transformer decoder."""

    def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
        """Build the parent's encoders plus the cross-attention decoder."""
        super().__init__(mut_conf, smi_conf)
        decoder_layer = nn.TransformerDecoderLayer(
            smi_conf.d_model, nhead, batch_first=True
        )
        self.cross_attn = nn.TransformerDecoder(decoder_layer, num_layers)

    def forward(
        self, mut_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        """Attend from SMILES features (target) to mutation features (memory).

        out: [b, n, dmodel]
        """
        memory = self._forward_mut(mut_x)
        target = self._forward_smi(smi_tgt)
        return self.cross_attn(target, memory)
pass
6 changes: 3 additions & 3 deletions dooc/nets/drugcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def _cal_term_dim(self):
term_size_map is the number of all genes annotated with the term
"""
self.term_dim_map = {}
for term in self.term_size_map.keys():
for term, term_size in self.term_size_map.items():
num_output = self.conf.num_hiddens_genotype

# log the number of hidden variables per each term
Expand Down Expand Up @@ -220,14 +220,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# define forward function for genotype dcell #############################################
term_gene_out_map = {}

for term in self.term_direct_gene_map.keys():
for term, _ in self.term_direct_gene_map.items():
term_gene_out_map[term] = self._modules[term + "_direct_gene_layer"](
gene_input
)

term_nn_out_map = {}

for layer in self.term_layer_list:
for _, layer in enumerate(self.term_layer_list):

for term in layer:

Expand Down
34 changes: 24 additions & 10 deletions dooc/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int = None) -> torc
return out.to(self.device)


class _MutSmi(_MutSmiBase):
class _DrugcellAdamrMutSmi(_MutSmiBase):

def _model_args(self, mut: typing.Sequence[int], smi: str) -> typing.Tuple[torch.Tensor]:
def _model_args(
self, mut: typing.Sequence[int], smi: str
) -> typing.Tuple[torch.Tensor]:
mut_x = torch.tensor(mut, device=self.device)
smi_src = self._tokens2tensor(self.smi_tokenizer(smi))
smi_tgt = self._tokens2tensor(self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS))
return mut_x, smi_tgt
return mut_x, smi_src, smi_tgt

def reg(self, mut: typing.Sequence[int], smi: str) -> float:
return self.model(*self._model_args(mut, smi)).item()
Expand All @@ -50,15 +53,18 @@ def cmp(smi1, smi2):
return cmp


class _MutSmis(_MutSmiBase):
class _DrugcellAdamrMutSmis(_MutSmiBase):

def _smi_args(
self, smis: typing.Sequence[str]
) -> torch.Tensor:
) -> typing.Tuple[torch.Tensor]:
smi_src = [self.smi_tokenizer(smi) for smi in smis]
smi_tgt = [self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS) for smi in smis]
size_src = max(map(len, smi_src))
size_tgt = max(map(len, smi_tgt))
smi_src = torch.concat([self._tokens2tensor(smi, size_src).unsqueeze(0) for smi in smi_src])
smi_tgt = torch.concat([self._tokens2tensor(smi, size_tgt).unsqueeze(0) for smi in smi_tgt])
return smi_tgt
return smi_src, smi_tgt

def cmp_smis_func(self, mut: typing.Sequence[int]) -> typing.Callable:
mut_x = torch.tensor(mut, device=self.device)
Expand All @@ -69,13 +75,21 @@ def cmp(smi1, smi2):
query = '-'.join(smis)
if query in cmped:
return cmped[query]
smi_tgt = self._smi_args(smis)
out = self.model.forward_cmp(mut_x, smi_tgt)
smi_src, smi_tgt = self._smi_args(smis)
out = self.model.forward_cmp(mut_x, smi_src, smi_tgt)
cmped[query] = out
return out
return cmp


class _DrugcellAdamr2MutSmi(_MutSmiBase):
    """Placeholder subclass of ``_MutSmiBase``; defines no members of its own."""


class _DrugcellAdamr2MutSmis(_MutSmiBase):
    """Placeholder subclass of ``_MutSmiBase``; defines no members of its own."""


class _MutSmiReg:

def __call__(self, mut: typing.Sequence[int], smi: str) -> typing.Dict:
Expand All @@ -97,9 +111,9 @@ def __call__(self, mut: typing.Sequence[int], smis: typing.Sequence[str]) -> typ
"""


class MutSmiReg(_MutSmi, _MutSmiReg):
class MutSmiReg(_DrugcellAdamrMutSmi, _MutSmiReg):
pass


class MutSmisRank(_MutSmis, _MutSmisRank):
class MutSmisRank(_DrugcellAdamrMutSmis, _MutSmisRank):
pass
34 changes: 4 additions & 30 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,6 @@ def test_DrugcellAdamrMutSmi(smi_tkz):
assert out.shape == (2, 1)


def test_DrugcellAdamr2MutSmi(smi_tkz):
    """Check length validation and output shapes of ``_DrugcellAdamr2MutSmi``."""
    ds = datasets._DrugcellAdamr2MutSmi(smi_tkz)
    smis = ["CC[N+]CCBr", "Cc1ccc1"]
    vals = [0.88, 0.89]
    # Two random binary mutation vectors of length 52.
    muts = [[random.choice([0, 1]) for _ in range(52)] for _ in range(2)]

    # One value fewer than there are SMILES must be rejected.
    with pytest.raises(AssertionError):
        ds(muts, smis, vals[:1])

    mut_x, smi_tgt, out = ds(muts, smis, vals)
    expected_shapes = (
        (smi_tgt, (2, 200)),
        (mut_x, (2, 52)),
        (out, (2, 1)),
    )
    for tensor, shape in expected_shapes:
        assert tensor.shape == shape

def test_DrugcellAdamrMutSmis(smi_tkz):
ds = datasets._DrugcellAdamrMutSmis(smi_tkz)
lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
Expand All @@ -47,20 +33,6 @@ def test_DrugcellAdamrMutSmis(smi_tkz):
assert out.shape == (2, 2)


def test_DrugcellAdamr2MutSmis(smi_tkz):
    """Check length validation and output shapes of ``_DrugcellAdamr2MutSmis``."""
    ds = datasets._DrugcellAdamr2MutSmis(smi_tkz)
    lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
    lvals = [[0.88, 0.89], [0.82, 0.9]]
    # Two random binary mutation vectors of length 52.
    muts = [[random.choice([0, 1]) for _ in range(52)] for _ in range(2)]

    # One value list fewer than there are SMILES lists must be rejected.
    with pytest.raises(AssertionError):
        ds(muts, lsmis, lvals[:1])

    mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
    expected_shapes = (
        (smi_tgt, (2, 2, 200)),
        (mut_x, (2, 52)),
        (out, (2, 2)),
    )
    for tensor, shape in expected_shapes:
        assert tensor.shape == shape


def test_MutSmiReg(smi_tkz):
ds = datasets.MutSmiReg(smi_tkz)
smis = ["CC[N+]CCBr", "Cc1ccc1"]
Expand All @@ -69,7 +41,8 @@ def test_MutSmiReg(smi_tkz):
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, smis, vals[:1])
mut_x, smi_tgt, out = ds(muts, smis, vals)
mut_x, smi_src, smi_tgt, out = ds(muts, smis, vals)
assert smi_src.shape == (2, 200)
assert smi_tgt.shape == (2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2, 1)
Expand All @@ -83,7 +56,8 @@ def test_MutSmisPairwise(smi_tkz):
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, lsmis, lvals[:1])
mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
mut_x, smi_src, smi_tgt, out = ds(muts, lsmis, lvals)
assert smi_src.shape == (2, 2, 200)
assert smi_tgt.shape == (2, 2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2,)
Loading

0 comments on commit f37c1d4

Please sign in to comment.