Commit
Merge pull request #8 from js-ish/refactor
refactor
Showing 12 changed files with 660 additions and 354 deletions.
@@ -1,105 +1,40 @@
-import torch
-from torch import nn
-from dooc import nets
-from moltx.models import AdaMR
-from moltx.nets import AbsPosEncoderDecoderConfig
-
-
-class MutSmi(nn.Module):
-    """Base"""
-
-    def __init__(
-        self,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__()
-        self.gene_conf = gene_conf
-        self.smiles_conf = smiles_conf
-        self.smiles_encoder = AdaMR(smiles_conf)
-        self.gene_encoder = nets.GeneGNN(gene_conf)
-
-    def load_ckpt(self, *ckpt_files: str) -> None:
-        """load check point model.
-
-        Args:
-            ckpt_files (str): check point file paths.
-        """
-        self.load_state_dict(
-            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
-        )
-
-    def load_pretrained_ckpt(self, drugcell_ckpt: str, moltx_ckpt: str, freeze_drugcell: bool = False, freeze_moltx: bool = False) -> None:
-        self.gene_encoder.load_ckpt(drugcell_ckpt)
-        self.smiles_encoder.load_ckpt(moltx_ckpt)
-        if freeze_moltx:
-            self.smiles_encoder.requires_grad_(False)
-        if freeze_drugcell:
-            self.gene_encoder.requires_grad_(False)
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        raise NotImplementedError()
-
-
-class MutSmiXAttention(MutSmi):
-    """Regression model using transformer cross attention."""
-
-    def __init__(
-        self,
-        nhead: int = 2,
-        num_layers: int = 2,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__(gene_conf, smiles_conf)
-        d_model = self.smiles_conf.d_model
-        d_hidden = d_model // 2
-        layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
-        self.cross_att = nn.TransformerDecoder(layer, num_layers)
-        self.reg = nn.Sequential(
-            nn.Linear(d_model, d_hidden),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(d_hidden, 1),
-        )
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt).unsqueeze(-2)  # [b, 1, dmodel]
-        gene_out = self.gene_encoder(gene_src).unsqueeze(-2)  # [b, 1, dmodel]
-
-        feat = self.cross_att(smiles_out, gene_out)  # [b, 1, dmodel]
-
-        return self.reg(feat.squeeze(-2))  # [b, 1]
-
-
-class MutSmiFullConnection(MutSmi):
-    """Regression model using fully connection."""
-
-    def __init__(
-        self,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__(gene_conf, smiles_conf)
-        d_model = self.smiles_conf.d_model
-        d_hidden = d_model // 2
-        self.reg = nn.Sequential(
-            nn.Linear(d_model, d_hidden),
-            nn.ReLU(),
-            nn.Dropout(0.1),
-            nn.Linear(d_hidden, 1),
-        )
-
-    def forward(
-        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
-    ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)  # [b, dmodel]
-        gene_out = self.gene_encoder(gene_src)  # [b, dmodel]
-        feat = smiles_out + gene_out
-
-        return self.reg(feat)  # [b, 1]
+import torch
+from moltx import models as mmodels
+from dooc import nets as dnets
+from dooc.nets import mutations, heads
+
+
+"""
+Mutations(Individual Sample) and Smiles Interaction
+
+MutSmiReg
+MutSmis{Pair/List}
+MutsSmi{Pair/List}
+"""
+
+
+class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn):
+
+    def __init__(self) -> None:
+        super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
+        self.reg = heads.RegHead(self.smi_conf.d_model)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        return self.reg(super().forward(*args, **kwargs))  # [b, 1]
+
+
+class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn):
+
+    def __init__(self) -> None:
+        super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE)
+        self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model)
+
+    def forward(self, *args, **kwargs) -> torch.Tensor:
+        return self.pairwise_rank(super().forward(*args, **kwargs))  # [b, 2]
+
+    def forward_cmp(self, *args, **kwargs) -> float:
+        """
+        for infer, no batch dim
+        """
+        out = self.forward(*args, **kwargs)
+        return (out[1] - out[0]).item()
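For orientation, here is a minimal usage sketch of the refactored regression model. It is an illustration only: the `dooc.models` import path, the argument order, and the tensor shapes are assumptions, mirroring the pre-refactor `(smiles_src, smiles_tgt, gene_src)` calling style; the new `forward` simply delegates `*args, **kwargs` to `dnets.DrugcellAdamrMutSmiXattn`, whose signature is not shown in this diff.

```python
import torch

from dooc.models import MutSmiReg  # import path assumed, not shown in this diff

model = MutSmiReg()

# Hypothetical batched inputs; real feature widths and argument order depend
# on the mutations.Drugcell and mmodels.AdaMR configs referenced above.
mut_x = torch.randn(2, 3008)              # assumed mutation feature width
smi_src = torch.randint(0, 512, (2, 24))  # assumed SMILES token ids
smi_tgt = torch.randint(0, 512, (2, 24))

out = model(mut_x, smi_src, smi_tgt)      # RegHead output, expected shape [b, 1]
print(out.shape)
```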
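The pairwise class's contract is worth spelling out: `forward` returns two scores per sample (`[b, 2]`), and `forward_cmp` is an unbatched inference helper that returns `out[1] - out[0]`, so a positive value means the second candidate SMILES is predicted to rank higher than the first. A sketch under the same hypothetical input conventions as above:

```python
import torch

from dooc.models import MutSmisPairwise  # import path assumed

ranker = MutSmisPairwise()

# Hypothetical unbatched inputs: one mutation profile scored against a
# pair of candidate SMILES (leading dim 2 is the pair, not a batch).
mut_x = torch.randn(3008)
smi_src = torch.randint(0, 512, (2, 24))
smi_tgt = torch.randint(0, 512, (2, 24))

delta = ranker.forward_cmp(mut_x, smi_src, smi_tgt)  # float: score[1] - score[0]
better = 1 if delta > 0 else 0  # index of the predicted higher-ranked candidate
```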