Skip to content

Commit

Permalink
Merge pull request js-ish#5 from zzzseeu/feat-MutSmiXAttention-models
Browse files Browse the repository at this point in the history
feat: MutSmiXAttention models-pipelines-datasets-test
  • Loading branch information
yandy authored May 23, 2024
2 parents 06b5ad0 + e4733ae commit 8337dd6
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 86 deletions.
39 changes: 22 additions & 17 deletions dooc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,35 +49,47 @@ class MutSmiXAttention(MutSmi):

def __init__(
self,
d_model: int,
nhead: int = 2,
num_layers: int = 2,
gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
) -> None:
super().__init__(gene_conf, smiles_conf)
self.d_model = d_model
d_model = self.smiles_conf.d_model
d_hidden = d_model // 2
layer = nn.TransformerDecoderLayer(d_model, nhead)
self.cross_att = nn.TransformerDecoder(layer, num_layers)
self.reg = nn.Sequential(
nn.Linear(d_model, d_hidden),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, 1),
)

def forward(
self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
) -> torch.Tensor:
pass
assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
gene_out = self.gene_encoder(gene_src)
feat = self.cross_att(smiles_out, gene_out)

return self.reg(feat)


class MutSmiFullConnection(MutSmi):
"""Regression model using fully connection."""

def __init__(
self,
d_model: int,
gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
) -> None:
super().__init__(gene_conf, smiles_conf)
self.d_model = d_model

d_model = self.smiles_conf.d_model
d_hidden = d_model // 2
self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
self.reg = nn.Sequential(
nn.Linear(self.d_model, d_hidden),
nn.Linear(d_model, d_hidden),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(d_hidden, 1),
Expand All @@ -86,15 +98,8 @@ def __init__(
def forward(
self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
) -> torch.Tensor:
is_batched = smiles_src.dim() == 2

smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
gene_out = self.gene_encoder(gene_src).unsqueeze(1)

feat = None
if is_batched:
feat = smiles_out + self.gene_fc(gene_out)[:, 0]
else:
feat = smiles_out[0] + self.gene_fc(gene_out)[0]
gene_out = self.gene_encoder(gene_src)
feat = smiles_out + gene_out

return self.reg(feat)
12 changes: 10 additions & 2 deletions dooc/nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

@dataclass
class GeneGNNConfig:
d_model: int
gene_dim: int
drug_dim: int
num_hiddens_genotype: int
Expand All @@ -24,6 +25,7 @@ class GeneGNN(nn.Module):
"""

DEFAULT_CONFIG = GeneGNNConfig(
d_model=768,
gene_dim=3008,
drug_dim=2048,
num_hiddens_genotype=6,
Expand All @@ -36,6 +38,7 @@ class GeneGNN(nn.Module):
def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
super().__init__()
self.conf = conf
d_model = self.conf.d_model

dg, dg_root, term_size_map, term_direct_gene_map = self._get_params()
self.dg, self.dg_root = dg, dg_root
Expand All @@ -49,6 +52,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
self._construct_nn_graph()
self._construct_nn_drug()
self._construct_final_layer()
self.out_fc = nn.Linear(self.conf.num_hiddens_genotype,
d_model)

def _contruct_direct_gene_layer(self):
"""
Expand Down Expand Up @@ -181,7 +186,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
removed drug layer
"""

x_dim = x.dim()
x = x.unsqueeze(0) if x_dim == 1 else x
gene_input = x.narrow(1, 0, self.conf.gene_dim)
# drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim)

Expand Down Expand Up @@ -223,4 +229,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
aux_layer1_out
)

return term_nn_out_map[self.dg_root]
out = term_nn_out_map[self.dg_root]
out = self.out_fc(out)
return out
6 changes: 5 additions & 1 deletion dooc/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def __init__(
super().__init__(smi_tokenizer, model, device)

def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
return
smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
smi_src = smi_src.unsqueeze(0)
smi_tgt = smi_tgt.unsqueeze(0)
pred = self.model(smi_src, smi_tgt, gene_src)
return pred.item()


class MutSmiFullConnection(MutSmi):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
moltx~=1.0.2
moltx~=1.0.4
networkx~=3.1
scikit-learn~=1.3.0
36 changes: 18 additions & 18 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,24 @@ def test_MutSmi():
return


# def test_MutSmiXAttention():
# tokenizer = tkz.MoltxTokenizer.from_pretrain(
# conf=AdaMRTokenizerConfig.Prediction
# )
# ds = datasets.MutSmiXAttention(tokenizer)
# smiles = ["CC[N+]CCBr", "Cc1ccc1"]
# values = [0.88, 0.89]
# mutations = [[random.choice([0, 1]) for _ in range(3008)],
# [random.choice([0, 1]) for _ in range(3008)]]
# with pytest.raises(RuntimeError):
# ds(smiles, mutations, values[:1])
# smiles_src, smiles_tgt, mutations_src, out = ds(smiles,
# mutations,
# values)
# assert smiles_src.shape == (2, 4)
# assert smiles_tgt.shape == (2, 6)
# assert mutations_src.shape == (2, 3008)
# assert out.shape == (2, 1)
def test_MutSmiXAttention():
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
ds = datasets.MutSmiXAttention(tokenizer)
smiles = ["CC[N+]CCBr", "Cc1ccc1"]
values = [0.88, 0.89]
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
with pytest.raises(RuntimeError):
ds(smiles, mutations, values[:1])
smiles_src, smiles_tgt, mutations_src, out = ds(smiles,
mutations,
values)
assert smiles_src.shape == (2, 200)
assert smiles_tgt.shape == (2, 200)
assert mutations_src.shape == (2, 3008)
assert out.shape == (2, 1)


def test_MutSmiFullConnection():
Expand Down
33 changes: 12 additions & 21 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,25 @@
import torch
from dooc import models, nets
from dooc import models
import random


# def test_MutSmiXAttention():
# smiles_src = torch.randint(0, 64, [2, 4])
# smiles_tgt = torch.randint(0, 64, [2, 6])
# mutations_src = torch.randn(2, 3008, dtype=torch.float32)
# d_model = 768
# model = models.MutSmiXAttention(d_model)
# out = model(smiles_src, smiles_tgt, mutations_src)
# assert out.shape == (1,)
def test_MutSmiXAttention():
smiles_src = torch.randint(0, 64, [2, 200])
smiles_tgt = torch.randint(0, 64, [2, 200])
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
model = models.MutSmiXAttention()
out = model(smiles_src, smiles_tgt, mutations_src)
assert out.shape == (2, 1)


def test_MutSmiFullConnection(datadir):
def test_MutSmiFullConnection():
smiles_src = torch.randint(0, 64, [2, 200])
smiles_tgt = torch.randint(0, 64, [2, 200])
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
d_model = 768
gene_conf = nets.GeneGNNConfig(
gene_dim=3008,
drug_dim=2048,
num_hiddens_genotype=6,
num_hiddens_drug=[100, 50, 6],
num_hiddens_final=6,
gene2ind_path=f"{datadir}/gene2ind.txt",
ont_path=f"{datadir}/drugcell_ont.txt",
)
model = models.MutSmiFullConnection(d_model, gene_conf=gene_conf)
model = models.MutSmiFullConnection()
out = model(smiles_src, smiles_tgt, mutations_src)
assert out.shape == (2, 1)
42 changes: 16 additions & 26 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,42 +1,32 @@
import random
from dooc import pipelines, models, nets
from dooc import pipelines, models
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig


# def test_MutSmiXAttention():
# tokenizer = tkz.MoltxTokenizer.from_pretrain(
# conf=AdaMRTokenizerConfig.Prediction
# )
# d_model = 768
# model = models.MutSmiXAttention(d_model)
# model.load_ckpt('/path/to/mutsmixattention.ckpt')
# pipeline = pipelines.MutSmiXAttention(tokenizer, model)
# mutation = [random.choice([1, 0]) for _ in range(3008)]
# smiles = "CC[N+](C)(C)Cc1ccccc1Br"
# predict = pipeline(mutation, smiles)
# assert isinstance(predict, float)
def test_MutSmiXAttention():
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmiXAttention()
# model.load_ckpt('/path/to/mutsmixattention.ckpt')
pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer,
model=model)
mutation = [random.choice([1, 0]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutation, smiles)
assert isinstance(predict, float)


def test_MutSmiFullConnection(datadir):
def test_MutSmiFullConnection():
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
d_model = 768
gene_conf = nets.GeneGNNConfig(
gene_dim=3008,
drug_dim=2048,
num_hiddens_genotype=6,
num_hiddens_drug=[100, 50, 6],
num_hiddens_final=6,
gene2ind_path=f"{datadir}/gene2ind.txt",
ont_path=f"{datadir}/drugcell_ont.txt",
)
model = models.MutSmiFullConnection(d_model, gene_conf)
model = models.MutSmiFullConnection()
# model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer,
model=model)
mutation = [[random.choice([1, 0]) for _ in range(3008)]]
mutation = [random.choice([1, 0]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutation, smiles)
assert isinstance(predict, float)

0 comments on commit 8337dd6

Please sign in to comment.