Merge pull request #3 from zzzseeu/feat-MutSmiFullConnection
Feat MutSmiFullConnection
Showing 11 changed files with 528 additions and 99 deletions.
Coverage configuration (new file, `@@ -0,0 +1,9 @@`) — the `[run]`/`[report]` sections follow the standard `.coveragerc` format:

```ini
[run]
omit =
    */_version.py
relative_files = True

[report]
exclude_lines =
    ^VERSION =
    raise NotImplementedError
```
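For illustration, here is the kind of code those exclusions target (a hypothetical module, not from this PR): version constants and deliberate stubs are not reported as uncovered lines.

```python
# A module matching */_version.py is omitted from measurement entirely.
VERSION = "0.1.0"  # a line like this matches the ^VERSION = pattern


class Encoder:
    def forward(self):
        # Stub bodies that raise NotImplementedError are excluded
        # from the coverage report by the exclude_lines rule.
        raise NotImplementedError
```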
flake8 configuration (new file, `@@ -0,0 +1,12 @@`):

```ini
[flake8]
# https://github.com/ambv/black/blob/master/.flake8
# E203 whitespace before ':'
# E266 Too many leading '#' for block comment
# E501 Line too long
# W503 Line break occurred before a binary operator
ignore = E203, E266, E501, W503, W504, B950, B008
max-line-length = 120
max-complexity = 18
exclude =
    .git,
    tests,
```
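These ignores are the usual black-compatibility set; as an illustrative snippet (not from this PR), black itself produces code that would otherwise trip E203 and W503:

```python
# E203: black adds a space before ':' in slices with complex bounds.
chunk = items[offset + 1 :]

# W503: black breaks lines *before* binary operators.
total = (
    first_term
    + second_term
)
```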
Datasets module (`@@ -1,41 +1,48 @@`) — the earlier stub classes (`MutSmi(Base)` and subclasses with `pass` bodies and placeholder docstrings) are replaced with a working implementation that tokenizes SMILES via moltx and converts gene vectors to tensors:

```python
import typing
import torch
from moltx import tokenizers, datasets


class MutSmi:
    """Base datasets, convert smiles and genes to torch.Tensor."""

    def __init__(
        self,
        smi_tokenizer: tokenizers.MoltxTokenizer,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        self.smi_tokenizer = smi_tokenizer
        self.smi_ds = datasets.AdaMRRegression(self.smi_tokenizer)
        self.device = device

    def gen_smi_token(
        self,
        smiles: typing.Sequence[str],
        values: typing.Sequence[float],
        seq_len: int = 200,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        src, tgt, out = self.smi_ds(smiles, values, seq_len)
        return src.to(self.device), tgt.to(self.device), out.to(self.device)

    def gen_gene_token(self, genes: typing.Sequence[list]) -> torch.Tensor:
        return torch.tensor(genes, dtype=torch.float).to(self.device)

    def __call__(
        self,
        smiles: typing.Sequence[str],
        genes: typing.Sequence[list],
        values: typing.Sequence[float],
        seq_len: int = 200,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
        gene_src = self.gen_gene_token(genes)
        return smi_src, smi_tgt, gene_src, out


class MutSmiXAttention(MutSmi):
    pass


class MutSmiFullConnection(MutSmi):
    pass
```
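A minimal usage sketch of the new dataset path (not part of the PR; the tokenizer constructor is elided because the exact moltx factory API isn't shown in this diff, and the gene-vector length is a toy value — it assumes the module lives at `dooc/datasets.py`):

```python
import torch
from moltx import tokenizers
from dooc import datasets  # assumes the module above is dooc/datasets.py

# The tokenizer comes from moltx; its exact constructor/factory is not
# shown in this diff, so it is left as a placeholder here.
smi_tokenizer: tokenizers.MoltxTokenizer = ...

ds = datasets.MutSmiFullConnection(smi_tokenizer, device=torch.device("cpu"))

smiles = ["c1ccccc1", "CCO"]                 # two molecules
genes = [[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]]   # toy one-hot mutation vectors
values = [0.72, 0.13]                        # measured inhibition rates

smi_src, smi_tgt, gene_src, out = ds(smiles, genes, values, seq_len=200)
# smi_src/smi_tgt are token tensors for the AdaMR encoder-decoder,
# gene_src is a float tensor for GeneGNN, and out is the regression target.
```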
Models module (`@@ -1,55 +1,100 @@`) — the previous stubs gain real implementations: the base `MutSmi` wires up an AdaMR SMILES encoder and a GeneGNN gene encoder with checkpoint loading and optional freezing, and `MutSmiFullConnection` adds the fully connected regression head this PR is named for (`MutSmiXAttention.forward` remains a stub):

```python
import torch
from torch import nn
from dooc import nets
from moltx.models import AdaMR
from moltx.nets import AbsPosEncoderDecoderConfig


class MutSmi(nn.Module):
    """Base"""

    def __init__(
        self,
        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
    ) -> None:
        super().__init__()
        self.gene_conf = gene_conf
        self.smiles_conf = smiles_conf
        self.smiles_encoder = AdaMR(smiles_conf)
        self.gene_encoder = nets.GeneGNN(gene_conf)

    def load_ckpt(self, *ckpt_files: str) -> None:
        """Load checkpoint weights.

        Args:
            ckpt_files (str): checkpoint file paths.
        """
        self.load_state_dict(
            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
        )

    def load_pretrained_ckpt(
        self,
        drugcell_ckpt: str,
        moltx_ckpt: str,
        freeze_drugcell: bool = False,
        freeze_moltx: bool = False,
    ) -> None:
        """Load the drugcell and moltx checkpoints, optionally freezing either encoder."""
        self.gene_encoder.load_ckpt(drugcell_ckpt)
        self.smiles_encoder.load_ckpt(moltx_ckpt)
        if freeze_moltx:
            self.smiles_encoder.requires_grad_(False)
        if freeze_drugcell:
            self.gene_encoder.requires_grad_(False)

    def forward(
        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
    ) -> torch.Tensor:
        raise NotImplementedError()


class MutSmiXAttention(MutSmi):
    """Regression model using transformer cross attention."""

    def __init__(
        self,
        d_model: int,
        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
    ) -> None:
        super().__init__(gene_conf, smiles_conf)
        self.d_model = d_model

    def forward(
        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
    ) -> torch.Tensor:
        pass


class MutSmiFullConnection(MutSmi):
    """Regression model using fully connected layers."""

    def __init__(
        self,
        d_model: int,
        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
    ) -> None:
        super().__init__(gene_conf, smiles_conf)
        self.d_model = d_model

        d_hidden = d_model // 2
        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
        self.reg = nn.Sequential(
            nn.Linear(self.d_model, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, 1),
        )

    def forward(
        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
    ) -> torch.Tensor:
        is_batched = smiles_src.dim() == 2

        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
        gene_out = self.gene_encoder(gene_src).unsqueeze(1)

        if is_batched:
            feat = smiles_out + self.gene_fc(gene_out)[:, 0]
        else:
            feat = smiles_out[0] + self.gene_fc(gene_out)[0]

        return self.reg(feat)
```
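And a hedged end-to-end sketch of the full-connection model (not from the PR; the checkpoint paths, `d_model`, and gene-vector width are all hypothetical — `d_model` must match AdaMR's feature width and `gene_src` must match what the GeneGNN config expects):

```python
import torch
from dooc import models  # assumes the module above is dooc/models.py

model = models.MutSmiFullConnection(d_model=768)  # hypothetical width

# Hypothetical checkpoint paths; freezing both encoders trains only
# the gene_fc/reg head added in this PR.
model.load_pretrained_ckpt(
    "ckpts/drugcell.pt",
    "ckpts/moltx.pt",
    freeze_drugcell=True,
    freeze_moltx=True,
)

smiles_src = torch.zeros(2, 200, dtype=torch.long)  # toy token ids
smiles_tgt = torch.zeros(2, 200, dtype=torch.long)
gene_src = torch.zeros(2, 3008)  # hypothetical gene-vector width

pred = model(smiles_src, smiles_tgt, gene_src)  # batched path -> shape (2, 1)
```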