Merge pull request #3 from zzzseeu/feat-MutSmiFullConnection
Feat MutSmiFullConnection
yandy authored May 17, 2024
2 parents 625133e + 7106d54 commit d71e2e8
Showing 11 changed files with 528 additions and 99 deletions.
9 changes: 9 additions & 0 deletions .coveragerc
@@ -0,0 +1,9 @@
+[run]
+omit =
+    */_version.py
+relative_files = True
+
+[report]
+exclude_lines =
+    ^VERSION =
+    raise NotImplementedError
12 changes: 12 additions & 0 deletions .flake8
@@ -0,0 +1,12 @@
+[flake8]
+# https://github.com/ambv/black/blob/master/.flake8
+# E203 whitespace before ':'
+# E266 Too many leading '#' for block comment
+# E501 Line too long
+# W503 Line break occurred before a binary operator
+ignore = E203, E266, E501, W503, W504, B950, B008
+max-line-length = 120
+max-complexity = 18
+exclude =
+    .git,
+    tests,
71 changes: 39 additions & 32 deletions dooc/datasets.py
@@ -1,41 +1,48 @@
 import typing
 import torch
-from moltx.datasets import Base
-
-
-class MutSmi(Base):
-    """Base datasets, convert smiles and mutations to torch.Tensor.
-    """
-    pass
+from moltx import tokenizers, datasets
+
+
+class MutSmi:
+    """Base datasets, convert smiles and genes to torch.Tensor."""
+
+    def __init__(
+        self,
+        smi_tokenizer: tokenizers.MoltxTokenizer,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        self.smi_tokenizer = smi_tokenizer
+        self.smi_ds = datasets.AdaMRRegression(self.smi_tokenizer)
+        self.device = device
+
+    def gen_smi_token(
+        self,
+        smiles: typing.Sequence[str],
+        values: typing.Sequence[float],
+        seq_len: int = 200,
+    ) -> torch.Tensor:
+        src, tgt, out = self.smi_ds(smiles, values, seq_len)
+        return src.to(self.device), tgt.to(self.device), out.to(self.device)
+
+    def gen_gene_token(self, genes: typing.Sequence[list]) -> torch.Tensor:
+        return torch.tensor(genes, dtype=torch.float).to(self.device)
+
+    def __call__(
+        self,
+        smiles: typing.Sequence[str],
+        genes: typing.Sequence[list],
+        values: typing.Sequence[float],
+        seq_len: int = 200,
+    ) -> typing.Tuple[torch.Tensor]:
+        smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
+        gene_src = self.gen_gene_token(genes)
+        return smi_src, smi_tgt, gene_src, out
 
 
 class MutSmiXAttention(MutSmi):
-    """Regression task datasets.
-    """
-    def __call__(self,
-                 smiles: typing.Sequence[str],
-                 mutations: typing.Sequence[list],
-                 values: typing.Sequence[float]
-                 ) -> typing.Tuple[torch.Tensor]:
-        """_summary_
-        Args:
-            smiles (typing.Sequence[str]): molecule smiles.
-            mutations (typing.Sequence[list]): mutations one-hot list.
-            values (typing.Sequence[float]): actual inhibitation rate.
-        Returns:
-            smiles_src, smiles_tgt, mutations_src, out
-        """
-        pass
+    pass
 
 
-class MutSmiFullConnection(MutSmiXAttention):
-    """_summary_
-
-    Args:
-        MutSmi (_type_): _description_
-    """
+class MutSmiFullConnection(MutSmi):
     pass
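
For orientation, here is a minimal usage sketch of the new dataset pipeline. This is not code from the commit: the tokenizer construction is an assumption (the exact moltx loader may differ by version), and the SMILES string, one-hot gene vector, and inhibition value are toy data.

import torch
from moltx import tokenizers
from dooc.datasets import MutSmiFullConnection

# Assumption: the real MoltxTokenizer constructor/loader may differ;
# replace the placeholder with whatever your moltx release provides.
smi_tokenizer = tokenizers.MoltxTokenizer(...)  # hypothetical construction

ds = MutSmiFullConnection(smi_tokenizer, device=torch.device("cpu"))

smiles = ["CC(=O)Oc1ccccc1C(=O)O"]  # toy molecule (aspirin)
genes = [[0.0, 1.0, 0.0, 1.0]]      # toy one-hot mutation vector
values = [0.42]                     # toy inhibition rate

# Per the new __call__, this returns (smi_src, smi_tgt, gene_src, out),
# all moved to the dataset's device.
smi_src, smi_tgt, gene_src, out = ds(smiles, genes, values, seq_len=200)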
111 changes: 78 additions & 33 deletions dooc/models.py
@@ -1,55 +1,100 @@
 import torch
 from torch import nn
+from dooc import nets
+from moltx.models import AdaMR
+from moltx.nets import AbsPosEncoderDecoderConfig
 
 
 class MutSmi(nn.Module):
-    """Base
-
-    """
-    def __init__(self):
-        pass
+    """Base"""
+
+    def __init__(
+        self,
+        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
+        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
+    ) -> None:
+        super().__init__()
+        self.gene_conf = gene_conf
+        self.smiles_conf = smiles_conf
+        self.smiles_encoder = AdaMR(smiles_conf)
+
+        self.gene_encoder = nets.GeneGNN(gene_conf)
 
-    def load_ckpt(self, ckpt_file: str) -> None:
+    def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
         Args:
-            ckpt_file (str): check point file path.
+            ckpt_files (str): check point file paths.
         """
-        return
+        self.load_state_dict(
+            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
+        )
 
-    def load_pretrained_ckpt(self,
-                             drugcell_ckpt: str,
-                             moltx_ckpt: str) -> None:
-        """load drugcell checkpoint and moltx checkpoint.
-
-        Args:
-            drugcell_ckpt (str): ckpt file path.
-            moltx_ckpt (str): ckpt file path.
-        """
-        return
+    def load_pretrained_ckpt(self, drugcell_ckpt: str, moltx_ckpt: str, freeze_drugcell: bool = False, freeze_moltx: bool = False) -> None:
+        self.gene_encoder.load_ckpt(drugcell_ckpt)
+        self.smiles_encoder.load_ckpt(moltx_ckpt)
+        if freeze_moltx:
+            self.smiles_encoder.requires_grad_(False)
+        if freeze_drugcell:
+            self.gene_encoder.requires_grad_(False)
+
+    def forward(
+        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
+    ) -> torch.Tensor:
+        raise NotImplementedError()
 
 
 class MutSmiXAttention(MutSmi):
-    """Regression model using transformer cross attention.
-    """
-    def __init__(self):
-        pass
+    """Regression model using transformer cross attention."""
+
+    def __init__(
+        self,
+        d_model: int,
+        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
+        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
+    ) -> None:
+        super().__init__(gene_conf, smiles_conf)
+        self.d_model = d_model
 
-    def forward(self,
-                smiles_src: torch.Tensor,
-                smiles_tgt: torch.Tensor,
-                mutations_src: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
+    ) -> torch.Tensor:
         pass
 
 
 class MutSmiFullConnection(MutSmi):
-    """Regression model using fully connection.
-    """
-    def __init__(self):
-        pass
+    """Regression model using fully connection."""
 
-    def forward(self,
-                smiles_src: torch.Tensor,
-                smiles_tgt: torch.Tensor,
-                mutations_src: torch.Tensor) -> torch.Tensor:
-        pass
+    def __init__(
+        self,
+        d_model: int,
+        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
+        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
+    ) -> None:
+        super().__init__(gene_conf, smiles_conf)
+        self.d_model = d_model
+
+        d_hidden = d_model // 2
+        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
+        self.reg = nn.Sequential(
+            nn.Linear(self.d_model, d_hidden),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(d_hidden, 1),
+        )
+
+    def forward(
+        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
+    ) -> torch.Tensor:
+        is_batched = smiles_src.dim() == 2
+
+        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
+        gene_out = self.gene_encoder(gene_src).unsqueeze(1)
+
+        feat = None
+        if is_batched:
+            feat = smiles_out + self.gene_fc(gene_out)[:, 0]
+        else:
+            feat = smiles_out[0] + self.gene_fc(gene_out)[0]
+
+        return self.reg(feat)
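
To tie the two files together, a minimal end-to-end sketch under stated assumptions: the d_model value (768) is a placeholder that must match the AdaMR feature width, the checkpoint paths are placeholders, and the input tensors come from the dataset sketch above.

import torch
from dooc.models import MutSmiFullConnection

# Assumption: 768 is a placeholder; d_model must equal the SMILES
# encoder's output width, since gene features are projected to d_model
# and summed with the SMILES features before regression.
model = MutSmiFullConnection(d_model=768)

# Placeholder paths; these would come from pretrained DrugCell and MolTx runs.
model.load_pretrained_ckpt("drugcell.ckpt", "moltx.ckpt", freeze_moltx=True)

# smi_src, smi_tgt, gene_src, out are the tensors produced by the dataset sketch.
pred = model(smi_src, smi_tgt, gene_src)        # shape: (batch, 1)
loss = torch.nn.functional.mse_loss(pred, out)  # toy regression objective
loss.backward()

The additive fusion in MutSmiFullConnection.forward is why the gene projection (gene_fc) targets d_model: both modalities must share a width before the fully connected regression head.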