feat: adamr2 #12

Merged · 3 commits · Jun 11, 2024

55 changes: 46 additions & 9 deletions dooc/datasets.py
@@ -92,16 +92,53 @@ def __call__(
return mut_x, smi_src, smi_tgt, out


class _DrugcellAdamr2Base(_SmiMutBase):
pass
class _DrugcellAdamr2Base(_DrugcellAdamrBase):
def _smi_tokens(
self,
smiles: typing.Sequence[str],
seq_len: int = 200,
) -> torch.Tensor:
return self._smi_tokenize(
[f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)


class _DrugcellAdamr2MutSmi(_DrugcellAdamr2Base):
pass
def __call__(
self,
muts: typing.Sequence[list],
smis: typing.Sequence[str],
vals: typing.Sequence[float],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
assert len(smis) == len(vals) and len(muts) == len(vals)
mut_x = self._mut_tokens(muts)
smi_tgt = self._smi_tokens(smis, seq_len)
out = self._out(vals).unsqueeze(-1)
return mut_x, smi_tgt, out


class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base):
pass
class _DrugcellAdamr2MutSmis(_DrugcellAdamr2Base):
def __call__(
self,
muts: typing.Sequence[list],
lsmis: typing.Sequence[typing.Sequence[str]],
lvals: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
"""
muts: [mut1, mut2, ...], each mut: [gene1, gene2, ...]
lsmis: [[smi11, smi12], [smi21, smi22], ...]
lvals: [[val11, val12], [val21, val22], ...]
"""
assert len(lsmis) == len(lvals) and len(muts) == len(lvals)
mut_x = self._mut_tokens(muts)
batchlen = len(lsmis)
listlen = len(lsmis[0])
smiles = [smi for bsmi in lsmis for smi in bsmi]
smi_tgt = self._smi_tokens(smiles, seq_len)
smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
out = self._out(lvals)
return mut_x, smi_tgt, out


"""
@@ -113,19 +150,19 @@ class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base):
"""


class MutSmiReg(_DrugcellAdamrMutSmi):
class MutSmiReg(_DrugcellAdamr2MutSmi):
pass


class MutSmisPairwise(_DrugcellAdamrMutSmis):
class MutSmisPairwise(_DrugcellAdamr2MutSmis):
def __call__(
self,
muts: typing.Sequence[list],
lsmiles: typing.Sequence[typing.Sequence[str]],
lvalues: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
mut_x, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
return mut_x, smi_src, smi_tgt, out
return mut_x, smi_tgt, out
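
A minimal usage sketch of the updated collators, mirroring the new tests below (smi_tkz is assumed to be a moltx SMILES tokenizer; its construction is not part of this diff):

import random
from dooc import datasets

muts = [[random.choice([0, 1]) for _ in range(52)] for _ in range(2)]  # 52 genes as in the test fixture

reg_ds = datasets.MutSmiReg(smi_tkz)
mut_x, smi_tgt, out = reg_ds(muts, ["CC[N+]CCBr", "Cc1ccc1"], [0.88, 0.89])
# mut_x: (2, 52), smi_tgt: (2, 200), out: (2, 1) -- the smi_src tensor is gone

rank_ds = datasets.MutSmisPairwise(smi_tkz)
mut_x, smi_tgt, out = rank_ds(
    muts,
    [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]],
    [[0.88, 0.89], [0.82, 0.9]],
)
# mut_x: (2, 52), smi_tgt: (2, 2, 200), out: (2,) with 1 where the first SMILES scores higher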
15 changes: 8 additions & 7 deletions dooc/models.py
@@ -1,7 +1,8 @@
import torch
from moltx import nets as mnets
from moltx import models as mmodels
from dooc import nets as dnets
from dooc.nets import heads
from dooc.nets import heads, drugcell


"""
@@ -13,20 +14,20 @@
"""


class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn):
class MutSmiReg(dnets.DrugcellAdamr2MutSmiXattn):

def __init__(self) -> None:
super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(mut_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
return self.reg(super().forward(*args, **kwargs)) # [b, 1]


class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn):
class MutSmisPairwise(dnets.DrugcellAdamr2MutSmisXattn):

def __init__(self) -> None:
super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(mut_conf, smi_conf)
self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)

def forward(self, *args, **kwargs) -> torch.Tensor:
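
With the constructors now taking explicit configs, the models can be built with the AdaMR2 defaults or with a custom config pair; a minimal sketch (checkpoint file names are placeholders):

from dooc import models, nets as dnets
from moltx import models as mmodels

# defaults: Drugcell.DEFAULT_CONFIG + AdaMR2.CONFIG_LARGE
model = models.MutSmisPairwise()
# or pass the configs explicitly
model = models.MutSmisPairwise(
    mut_conf=dnets.Drugcell.DEFAULT_CONFIG,
    smi_conf=mmodels.AdaMR2.CONFIG_LARGE,
)
# load and optionally freeze the pretrained encoders (placeholder paths)
model.load_pretrained_ckpt("drugcell.ckpt", "adamr2.ckpt", freeze_smi=True)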
96 changes: 88 additions & 8 deletions dooc/nets/__init__.py
@@ -1,6 +1,6 @@
import torch
from torch import nn
from moltx.models import AdaMR
from moltx.models import AdaMR, AdaMR2
from dooc.nets.drugcell import Drugcell


@@ -115,17 +115,97 @@ def forward(
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]


class DrugcellAdamr2MutSmiAdd(nn.Module):
pass
class _DrugcellAdamr2(nn.Module):

def __init__(self, mut_conf, smi_conf) -> None:
super().__init__()
self.mut_conf = mut_conf
self.smi_conf = smi_conf

self.mut_encoder = Drugcell(mut_conf)
self.smi_encoder = AdaMR2(smi_conf)

def load_ckpt(self, *ckpt_files: str) -> None:
self.load_state_dict(
torch.load(ckpt_files[0], map_location=torch.device("cpu"))
)

def load_pretrained_ckpt(self, mut_ckpt: str, smi_ckpt: str, freeze_mut: bool = False, freeze_smi: bool = False) -> None:
self.mut_encoder.load_ckpt(mut_ckpt)
self.smi_encoder.load_ckpt(smi_ckpt)
if freeze_smi:
self.smi_encoder.requires_grad_(False)
if freeze_mut:
self.mut_encoder.requires_grad_(False)

class DrugcellAdamr2MutSmiXattn(nn.Module):
pass

class DrugcellAdamr2MutSmiAdd(_DrugcellAdamr2):
def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
"""
mut_x: [b, mut_seqlen]
smi_tgt: [b, smi_seqlen]
"""
mut_out = self.mut_encoder(mut_x)
smi_out = self.smi_encoder.forward_feature(smi_tgt)
return mut_out + smi_out # [b, dmodel]

class DrugcellAdamr2MutSmisAdd(nn.Module):
pass

class DrugcellAdamr2MutSmiXattn(_DrugcellAdamr2):
def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(mut_conf, smi_conf)
d_model = self.smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
"""
mut_x: [b, mut_seqlen]
smi_tgt: [b, smi_seqlen]
"""
mut_out = self.mut_encoder(mut_x).unsqueeze(-2) # [b, 1, dmodel]
smi_out = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2) # [b, 1, dmodel]
return self.cross_attn(smi_out, mut_out).squeeze(-2) # [b, dmodel]


class DrugcellAdamr2MutSmisAdd(_DrugcellAdamr2):
def _forward_mut(self, mut_x: torch.Tensor) -> torch.Tensor:
"""
mut_x: [b, mut_seqlen]
out: [b, 1, dmodel]
"""
return self.mut_encoder(mut_x).unsqueeze(-2)

def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor:
"""
smi_tgt: [b, n, smi_seqlen]
out: [b, n, dmodel]
"""
batched = smi_tgt.dim() == 3
if batched:
n = smi_tgt.shape[1]
smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1])
out = self.smi_encoder.forward_feature(smi_tgt)
return out.reshape(-1, n, out.shape[-1])
return self.smi_encoder.forward_feature(smi_tgt)

def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
mut_out = self._forward_mut(mut_x)
smi_out = self._forward_smi(smi_tgt)
return smi_out + mut_out # [b, n, dmodel]


class DrugcellAdamr2MutSmisXattn(DrugcellAdamr2MutSmisAdd):
pass
def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(mut_conf, smi_conf)
d_model = smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
mut_out = self._forward_mut(mut_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]
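
Shape-level sketch of the new list-wise path: the dataset emits a single smi_tgt of shape [b, n, smi_seqlen], and the Xattn variant produces one feature per candidate SMILES. The tensors below are dummies; the mutation length must match the gene count of the Drugcell config, and the 0/1 float encoding is an assumption about the dataset side:

import torch
from dooc import nets as dnets
from moltx import models as mmodels

net = dnets.DrugcellAdamr2MutSmisXattn(dnets.Drugcell.DEFAULT_CONFIG, mmodels.AdaMR2.CONFIG_LARGE)

b, n, smi_seqlen, mut_seqlen = 2, 2, 200, 3008          # mut_seqlen is illustrative only
mut_x = torch.randint(0, 2, (b, mut_seqlen)).float()    # 0/1 mutation indicators, dtype assumed
smi_tgt = torch.randint(1, 10, (b, n, smi_seqlen))      # dummy token ids
feat = net(mut_x, smi_tgt)                              # [b, n, d_model]

single = dnets.DrugcellAdamr2MutSmiXattn(dnets.Drugcell.DEFAULT_CONFIG, mmodels.AdaMR2.CONFIG_LARGE)
feat1 = single(mut_x, smi_tgt[:, 0])                    # [b, d_model]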
6 changes: 3 additions & 3 deletions dooc/nets/drugcell.py
@@ -189,7 +189,7 @@ def _cal_term_dim(self):
term_size_map is the number of all genes annotated with the term
"""
self.term_dim_map = {}
for term, term_size in self.term_size_map.items():
for term in self.term_size_map.keys():
num_output = self.conf.num_hiddens_genotype

# log the number of hidden variables per each term
@@ -220,14 +220,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
# define forward function for genotype dcell #############################################
term_gene_out_map = {}

for term, _ in self.term_direct_gene_map.items():
for term in self.term_direct_gene_map.keys():
term_gene_out_map[term] = self._modules[term + "_direct_gene_layer"](
gene_input
)

term_nn_out_map = {}

for _, layer in enumerate(self.term_layer_list):
for layer in self.term_layer_list:

for term in layer:

34 changes: 10 additions & 24 deletions dooc/pipelines.py
@@ -25,15 +25,12 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int = None) -> torch.Tensor:
return out.to(self.device)


class _DrugcellAdamrMutSmi(_MutSmiBase):
class _MutSmi(_MutSmiBase):

def _model_args(
self, mut: typing.Sequence[int], smi: str
) -> typing.Tuple[torch.Tensor]:
def _model_args(self, mut: typing.Sequence[int], smi: str) -> typing.Tuple[torch.Tensor]:
mut_x = torch.tensor(mut, device=self.device)
smi_src = self._tokens2tensor(self.smi_tokenizer(smi))
smi_tgt = self._tokens2tensor(self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS))
return mut_x, smi_src, smi_tgt
return mut_x, smi_tgt

def reg(self, mut: typing.Sequence[int], smi: str) -> float:
return self.model(*self._model_args(mut, smi)).item()
@@ -53,18 +50,15 @@ def cmp(smi1, smi2):
return cmp


class _DrugcellAdamrMutSmis(_MutSmiBase):
class _MutSmis(_MutSmiBase):

def _smi_args(
self, smis: typing.Sequence[str]
) -> typing.Tuple[torch.Tensor]:
smi_src = [self.smi_tokenizer(smi) for smi in smis]
) -> torch.Tensor:
smi_tgt = [self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS) for smi in smis]
size_src = max(map(len, smi_src))
size_tgt = max(map(len, smi_tgt))
smi_src = torch.concat([self._tokens2tensor(smi, size_src).unsqueeze(0) for smi in smi_src])
smi_tgt = torch.concat([self._tokens2tensor(smi, size_tgt).unsqueeze(0) for smi in smi_tgt])
return smi_src, smi_tgt
return smi_tgt

def cmp_smis_func(self, mut: typing.Sequence[int]) -> typing.Callable:
mut_x = torch.tensor(mut, device=self.device)
@@ -75,21 +69,13 @@ def cmp(smi1, smi2):
query = '-'.join(smis)
if query in cmped:
return cmped[query]
smi_src, smi_tgt = self._smi_args(smis)
out = self.model.forward_cmp(mut_x, smi_src, smi_tgt)
smi_tgt = self._smi_args(smis)
out = self.model.forward_cmp(mut_x, smi_tgt)
cmped[query] = out
return out
return cmp


class _DrugcellAdamr2MutSmi(_MutSmiBase):
pass


class _DrugcellAdamr2MutSmis(_MutSmiBase):
pass


class _MutSmiReg:

def __call__(self, mut: typing.Sequence[int], smi: str) -> typing.Dict:
@@ -111,9 +97,9 @@ def __call__(self, mut: typing.Sequence[int], smis: typing.Sequence[str]) -> typing.Dict:
"""


class MutSmiReg(_DrugcellAdamrMutSmi, _MutSmiReg):
class MutSmiReg(_MutSmi, _MutSmiReg):
pass


class MutSmisRank(_DrugcellAdamrMutSmis, _MutSmisRank):
class MutSmisRank(_MutSmis, _MutSmisRank):
pass
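
A minimal sketch of the simplified pipelines; the _MutSmiBase constructor is not shown in this diff, so the tokenizer/model arguments and their order below are assumptions, and the checkpoint path is a placeholder:

from dooc import models, pipelines

model = models.MutSmisPairwise()
model.load_ckpt("mutsmis_pairwise.ckpt")        # placeholder checkpoint
rank = pipelines.MutSmisRank(smi_tkz, model)    # argument order assumed from _MutSmiBase
mut = [0, 1, 0, 1]                              # illustrative mutation vector
result = rank(mut, ["CC[N+]CCBr", "Cc1ccc1", "CCC[N+]CCBr"])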
34 changes: 30 additions & 4 deletions tests/test_datasets.py
@@ -18,6 +18,20 @@ def test_DrugcellAdamrMutSmi(smi_tkz):
assert out.shape == (2, 1)


def test_DrugcellAdamr2MutSmi(smi_tkz):
ds = datasets._DrugcellAdamr2MutSmi(smi_tkz)
smis = ["CC[N+]CCBr", "Cc1ccc1"]
vals = [0.88, 0.89]
muts = [[random.choice([0, 1]) for _ in range(52)],
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, smis, vals[:1])
mut_x, smi_tgt, out = ds(muts, smis, vals)
assert smi_tgt.shape == (2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2, 1)


def test_DrugcellAdamrMutSmis(smi_tkz):
ds = datasets._DrugcellAdamrMutSmis(smi_tkz)
lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
@@ -33,6 +47,20 @@ def test_DrugcellAdamrMutSmis(smi_tkz):
assert out.shape == (2, 2)


def test_DrugcellAdamr2MutSmis(smi_tkz):
ds = datasets._DrugcellAdamr2MutSmis(smi_tkz)
lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]]
lvals = [[0.88, 0.89], [0.82, 0.9]]
muts = [[random.choice([0, 1]) for _ in range(52)],
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, lsmis, lvals[:1])
mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
assert smi_tgt.shape == (2, 2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2, 2)


def test_MutSmiReg(smi_tkz):
ds = datasets.MutSmiReg(smi_tkz)
smis = ["CC[N+]CCBr", "Cc1ccc1"]
@@ -41,8 +69,7 @@ def test_MutSmiReg(smi_tkz):
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, smis, vals[:1])
mut_x, smi_src, smi_tgt, out = ds(muts, smis, vals)
assert smi_src.shape == (2, 200)
mut_x, smi_tgt, out = ds(muts, smis, vals)
assert smi_tgt.shape == (2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2, 1)
@@ -56,8 +83,7 @@ def test_MutSmisPairwise(smi_tkz):
[random.choice([0, 1]) for _ in range(52)]]
with pytest.raises(AssertionError):
ds(muts, lsmis, lvals[:1])
mut_x, smi_src, smi_tgt, out = ds(muts, lsmis, lvals)
assert smi_src.shape == (2, 2, 200)
mut_x, smi_tgt, out = ds(muts, lsmis, lvals)
assert smi_tgt.shape == (2, 2, 200)
assert mut_x.shape == (2, 52)
assert out.shape == (2,)
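
The new AdaMR2 collator tests can also be run in isolation; a minimal sketch using pytest's programmatic entry point:

import pytest

# run only the AdaMR2 dataset tests from this file
pytest.main(["tests/test_datasets.py", "-k", "Adamr2"])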