Merge pull request #8 from js-ish/refactor
refactor
yandy authored Jun 6, 2024
2 parents cbe3e8d + 04125c2 commit 1951b89
Showing 12 changed files with 660 additions and 354 deletions.
121 changes: 102 additions & 19 deletions dooc/datasets.py
@@ -4,45 +4,128 @@
 from moltx import tokenizers, datasets
 
 
-class MutSmi:
+class _SmiMutBase:
+    def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None:
+        self.smi_ds = datasets.Base(smi_tokenizer, device)
+        self.device = device
+
+    def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> torch.Tensor:
+        return self.smi_ds._tokenize(smiles, seq_len)
+
+
+"""
+Mutations(Individual Sample) and Smiles Interaction
+{MutationEnc}{SmileEnc}MutSmi: 1 mut with 1 smi
+{MutationEnc}{SmileEnc}MutSmis: 1 mut with n smi
+{MutationEnc}{SmileEnc}MutsSmi: n mut with 1 smi
+"""


+class _DrugcellAdamrBase(_SmiMutBase):
     """Base datasets, convert smiles and genes to torch.Tensor."""
 
     def __init__(
         self,
         smi_tokenizer: tokenizers.MoltxTokenizer,
-        device: torch.device = torch.device("cpu"),
+        device: torch.device = torch.device("cpu")
     ) -> None:
+        super().__init__(smi_tokenizer, device)
         self.smi_tokenizer = smi_tokenizer
-        self.smi_ds = datasets.AdaMRRegression(self.smi_tokenizer)
-        self.device = device
 
-    def gen_smi_token(
+    def _smi_tokens(
         self,
         smiles: typing.Sequence[str],
-        values: typing.Sequence[float],
         seq_len: int = 200,
     ) -> torch.Tensor:
-        src, tgt, out = self.smi_ds(smiles, values, seq_len)
-        return src.to(self.device), tgt.to(self.device), out.to(self.device)
+        src = self._smi_tokenize(smiles, seq_len)
+        tgt = self._smi_tokenize(
+            [f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)
+        return src, tgt
+
+    def _mut_tokens(self, muts: typing.Sequence[list]) -> torch.Tensor:
+        return torch.tensor(muts, device=self.device)
 
-    def gen_gene_token(self, genes: typing.Sequence[list]) -> torch.Tensor:
-        return torch.tensor(genes, dtype=torch.float).to(self.device)
+    def _out(self, values: typing.Sequence[float]) -> torch.Tensor:
+        return torch.tensor(values, device=self.device)
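The refactored base tokenizes each SMILES twice: `_smi_tokens` builds an encoder source from the raw string and a decoder target wrapped in BOS/EOS markers. A minimal usage sketch, assuming `tk` is an already-constructed moltx `MoltxTokenizer` (its construction is not shown in this diff) and that the underlying `datasets.Base._tokenize` pads to a common length:

    # Sketch only: `tk` is an assumed moltx MoltxTokenizer instance.
    base = _DrugcellAdamrBase(tk)
    src, tgt = base._smi_tokens(["c1ccccc1", "CCO"], seq_len=200)
    # src encodes the raw SMILES; tgt encodes the "<bos>...<eos>" variants,
    # giving the decoder explicit start/end markers. Both: [2, seq] token ids.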


+class _DrugcellAdamrMutSmi(_DrugcellAdamrBase):
 
     def __call__(
         self,
-        smiles: typing.Sequence[str],
-        genes: typing.Sequence[list],
-        values: typing.Sequence[float],
-        seq_len: int = 200,
+        muts: typing.Sequence[list],
+        smis: typing.Sequence[str],
+        vals: typing.Sequence[float],
+        seq_len: int = 200
     ) -> typing.Tuple[torch.Tensor]:
-        smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
-        gene_src = self.gen_gene_token(genes)
-        return smi_src, smi_tgt, gene_src, out
+        assert len(smis) == len(vals) and len(muts) == len(vals)
+        mut_x = self._mut_tokens(muts)
+        smi_src, smi_tgt = self._smi_tokens(smis, seq_len)
+        out = self._out(vals).unsqueeze(-1)
+        return mut_x, smi_src, smi_tgt, out
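Callers now pass mutations, SMILES, and targets as parallel sequences and get back `(mut_x, smi_src, smi_tgt, out)`. A hypothetical single-sample call (the 4-gene mutation vector is illustrative only; real DrugCell mutation vectors are far longer):

    ds = MutSmiReg(tk)           # defined below: public alias of _DrugcellAdamrMutSmi
    mut_x, smi_src, smi_tgt, out = ds(
        muts=[[0, 1, 0, 1]],     # per-sample 0/1 gene-mutation flags
        smis=["CCO"],            # one SMILES per sample
        vals=[0.85],             # one regression target per sample
    )
    # mut_x: [1, 4]; out: [1, 1]; smi_src/smi_tgt: [1, seq] token ids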


-class MutSmiXAttention(MutSmi):
+class _DrugcellAdamrMutSmis(_DrugcellAdamrBase):
+
+    def __call__(
+        self,
+        muts: typing.Sequence[list],
+        lsmis: typing.Sequence[typing.Sequence[str]],
+        lvals: typing.Sequence[typing.Sequence[float]],
+        seq_len: int = 200
+    ) -> typing.Tuple[torch.Tensor]:
+        """
+        muts: [mut1, mut2, ...], where mut1: [gene1, gene2, ...]
+        lsmis: [[smi11, smi12], [smi21, smi22], ...]
+        lvals: [[val11, val12], [val21, val22], ...]
+        """
+        assert len(lsmis) == len(lvals) and len(muts) == len(lvals)
+        mut_x = self._mut_tokens(muts)
+        batchlen = len(lsmis)
+        listlen = len(lsmis[0])
+        smiles = [smi for bsmi in lsmis for smi in bsmi]
+        smi_src, smi_tgt = self._smi_tokens(smiles, seq_len)
+        smi_src = smi_src.reshape(batchlen, listlen, smi_src.size(-1))
+        smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
+        out = self._out(lvals)
+        return mut_x, smi_src, smi_tgt, out
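The list-wise variant flattens every per-sample candidate list into a single tokenizer pass, then reshapes back to a [batch, list, seq] layout. Illustrative shapes for two samples with two candidate SMILES each:

    lsmis = [["CCO", "CCN"], ["c1ccccc1", "CC=O"]]   # 2 samples x 2 candidates
    lvals = [[0.1, 0.9], [0.7, 0.2]]
    # After __call__ (with muts of matching batch size):
    #   mut_x           -> [2, n_genes]
    #   smi_src/smi_tgt -> [2, 2, seq]
    #   out             -> [2, 2]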


+class _DrugcellAdamr2Base(_SmiMutBase):
+    pass
+
+
+class _DrugcellAdamr2MutSmi(_DrugcellAdamr2Base):
+    pass
+
+
+class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base):
+    pass


-class MutSmiFullConnection(MutSmi):
+"""
+Mutations(Individual Sample) and Smiles Interaction
+MutSmiReg
+MutSmis{Pair/List}
+MutsSmi{Pair/List}
+"""
+
+
+class MutSmiReg(_DrugcellAdamrMutSmi):
+    pass


+class MutSmisPairwise(_DrugcellAdamrMutSmis):
+    def __call__(
+        self,
+        muts: typing.Sequence[list],
+        lsmiles: typing.Sequence[typing.Sequence[str]],
+        lvalues: typing.Sequence[typing.Sequence[float]],
+        seq_len: int = 200
+    ) -> typing.Tuple[torch.Tensor]:
+        mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
+        out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
+        out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
+        return mut_x, smi_src, smi_tgt, out.unsqueeze(-1)
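MutSmisPairwise keeps the list-wise tensors but converts each value pair into a binary rank label: 1 when the first SMILES outscores the second, else 0. The label construction in isolation, runnable as-is:

    import torch

    rout = torch.tensor([[0.2, 0.9], [0.7, 0.1]])  # raw value pairs
    out = torch.zeros(rout.size(0), dtype=torch.long)
    out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
    # out == tensor([0, 1]); returned as [2, 1] after unsqueeze(-1)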
119 changes: 27 additions & 92 deletions dooc/models.py
@@ -1,105 +1,40 @@
 import torch
-from torch import nn
-from dooc import nets
-from moltx.models import AdaMR
-from moltx.nets import AbsPosEncoderDecoderConfig
+from moltx import models as mmodels
+from dooc import nets as dnets
+from dooc.nets import mutations, heads


-class MutSmi(nn.Module):
-    """Base"""
-
-    def __init__(
-        self,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__()
-        self.gene_conf = gene_conf
-        self.smiles_conf = smiles_conf
-        self.smiles_encoder = AdaMR(smiles_conf)
-        self.gene_encoder = nets.GeneGNN(gene_conf)
-
-    def load_ckpt(self, *ckpt_files: str) -> None:
-        """load check point model.
-
-        Args:
-            ckpt_files (str): check point file paths.
-        """
-        self.load_state_dict(
-            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
-        )
-
-    def load_pretrained_ckpt(self, drugcell_ckpt: str, moltx_ckpt: str, freeze_drugcell: bool = False, freeze_moltx: bool = False) -> None:
-        self.gene_encoder.load_ckpt(drugcell_ckpt)
-        self.smiles_encoder.load_ckpt(moltx_ckpt)
-        if freeze_moltx:
-            self.smiles_encoder.requires_grad_(False)
-        if freeze_drugcell:
-            self.gene_encoder.requires_grad_(False)
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        raise NotImplementedError()
+"""
+Mutations(Individual Sample) and Smiles Interaction
+MutSmiReg
+MutSmis{Pair/List}
+MutsSmi{Pair/List}
+"""


-class MutSmiXAttention(MutSmi):
-    """Regression model using transformer cross attention."""
-
-    def __init__(
-        self,
-        nhead: int = 2,
-        num_layers: int = 2,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__(gene_conf, smiles_conf)
-        d_model = self.smiles_conf.d_model
-        d_hidden = d_model // 2
-        layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
-        self.cross_att = nn.TransformerDecoder(layer, num_layers)
-        self.reg = nn.Sequential(
-            nn.Linear(d_model, d_hidden),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(d_hidden, 1),
-        )
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt).unsqueeze(-2)  # [b, 1, dmodel]
-        gene_out = self.gene_encoder(gene_src).unsqueeze(-2)  # [b, 1, dmodel]
-
-        feat = self.cross_att(smiles_out, gene_out)  # [b, 1, dmodel]
-
-        return self.reg(feat.squeeze(-2))  # [b, 1]
+class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn):
+
+    def __init__(self) -> None:
+        super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
+        self.reg = heads.RegHead(self.smi_conf.d_model)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        return self.reg(super().forward(*args, **kwargs))  # [b, 1]
+
+
+class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn):
+
+    def __init__(self) -> None:
+        super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
+        self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        return self.pairwise_rank(super().forward(*args, **kwargs))  # [b, 2]
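After the refactor each model is a thin wrapper: the dnets backbone fuses mutation and SMILES features, and a small head maps the fused vector to one regression value ([b, 1]) or two rank logits ([b, 2]). A hypothetical training-step sketch, assuming the backbone's forward accepts the three input tensors produced by the matching dataset class (its signature is not shown in this diff) and reusing the `ds` sketch above:

    model = MutSmiReg()                              # models.MutSmiReg, above
    mut_x, smi_src, smi_tgt, out = ds(muts=muts, smis=smis, vals=vals)
    pred = model(mut_x, smi_src, smi_tgt)            # [b, 1]
    loss = torch.nn.functional.mse_loss(pred, out)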

-class MutSmiFullConnection(MutSmi):
-    """Regression model using fully connection."""
-
-    def __init__(
-        self,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__(gene_conf, smiles_conf)
-        d_model = self.smiles_conf.d_model
-        d_hidden = d_model // 2
-        self.reg = nn.Sequential(
-            nn.Linear(d_model, d_hidden),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(d_hidden, 1),
-        )
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)  # [b, dmodel]
-        gene_out = self.gene_encoder(gene_src)  # [b, dmodel]
-        feat = smiles_out + gene_out
-
-        return self.reg(feat)  # [b, 1]
+    def forward_cmp(self, *args, **kwargs) -> float:
+        """
+        For inference on a single sample, no batch dim.
+        """
+        out = self.forward(*args, **kwargs)
+        return (out[1] - out[0]).item()
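forward_cmp compares one unbatched pair. Assuming the two logits follow the dataset's label convention (class 1 means the first SMILES outscores the second), a positive return favors the first SMILES. Toy arithmetic on the head's two logits:

    out = torch.tensor([0.3, 1.1])     # logits for [class 0, class 1]
    cmp = (out[1] - out[0]).item()     # 0.8 > 0: first SMILES ranks higher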
