Skip to content

Commit

Permalink
feat: add MutSmiFullConnection to datasets-models-pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzseeu committed May 17, 2024
1 parent 01123f6 commit 9400ede
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 37 deletions.
4 changes: 2 additions & 2 deletions dooc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def gen_smi_token(
self,
smiles: typing.Sequence[str],
values: typing.Sequence[float],
seq_len: int = 256,
seq_len: int = 200,
) -> torch.Tensor:
src, tgt, out = self.smi_ds(smiles, values, seq_len)
return src.to(self.device), tgt.to(self.device), out.to(self.device)
Expand All @@ -33,7 +33,7 @@ def __call__(
smiles: typing.Sequence[str],
genes: typing.Sequence[list],
values: typing.Sequence[float],
seq_len: int = 256,
seq_len: int = 200,
) -> typing.Tuple[torch.Tensor]:
smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
gene_src = self.gen_gene_token(genes)
Expand Down
24 changes: 10 additions & 14 deletions dooc/nets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
from torch import nn
from dataclasses import dataclass
from dataclasses import dataclass, field
from dooc.utils import load_gene_mapping, load_ontology


Expand All @@ -9,7 +9,7 @@ class GeneGNNConfig:
gene_dim: int
drug_dim: int
num_hiddens_genotype: int = 6
num_hiddens_drug: list = [100, 50, 6]
num_hiddens_drug: list = field(default_factory=list)
num_hiddens_final: int = 6
gene2ind_path: str = "data/gene2ind.txt"
ont_path: str = "data/drugcell_ont.txt"
Expand Down Expand Up @@ -94,15 +94,15 @@ def _construct_nn_graph(self):
self.term_layer_list = [] # term_layer_list stores the built neural network
self.term_neighbor_map = {}
# term_neighbor_map records all children of each term
for term in self.dG.nodes():
for term in self.dg.nodes():
self.term_neighbor_map[term] = []
for child in self.dG.neighbors(term):
for child in self.dg.neighbors(term):
self.term_neighbor_map[term].append(child)

while True:
leaves = [n for n in self.dG.nodes() if self.dG.out_degree(n) == 0]
# leaves = [n for n,d in self.dG.out_degree().items() if d==0]
# leaves = [n for n,d in self.dG.out_degree() if d==0]
leaves = [n for n in self.dg.nodes() if self.dg.out_degree(n) == 0]
# leaves = [n for n,d in self.dg.out_degree().items() if d==0]
# leaves = [n for n,d in self.dg.out_degree() if d==0]

if len(leaves) == 0:
break
Expand Down Expand Up @@ -130,7 +130,7 @@ def _construct_nn_graph(self):
self.add_module(term + "_aux_linear_layer1", nn.Linear(term_hidden, 1))
self.add_module(term + "_aux_linear_layer2", nn.Linear(1, 1))

self.dG.remove_nodes_from(leaves)
self.dg.remove_nodes_from(leaves)

def _construct_final_layer(self):
"""
Expand Down Expand Up @@ -158,14 +158,10 @@ def _cal_term_dim(self):
"""
self.term_dim_map = {}
for term, term_size in self.term_size_map.items():
num_output = self.num_hiddens_genoty
num_output = self.conf.num_hiddens_genotype

# log the number of hidden variables per each term
num_output = int(num_output)
print(
"term\t%s\tterm_size\t%d\tnum_hiddens\t%d"
% (term, term_size, num_output)
)
self.term_dim_map[term] = num_output

def _get_params(self):
Expand All @@ -186,7 +182,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""

gene_input = x.narrow(1, 0, self.conf.gene_dim)
# drug_input = x.narrow(1, self.gene_dim, self.drug_dim)
# drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim)

# define forward function for genotype dcell #############################################
term_gene_out_map = {}
Expand Down
4 changes: 2 additions & 2 deletions dooc/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _model_args(
self, gene: typing.Sequence[int], smiles: str
) -> typing.Tuple[torch.Tensor]:
smi_src = self.gen_smi_token(smiles)
smi_tgt = self.gen_smi_token(self.smi_tokenizer.BOS + smiles + self.tokenizer.EOS)
smi_tgt = self.gen_smi_token(self.smi_tokenizer.BOS + smiles + self.smi_tokenizer.EOS)
gene_src = self.gen_gene_token(gene)
return smi_src, smi_tgt, gene_src

Expand Down Expand Up @@ -64,4 +64,4 @@ def __init__(
def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
pred = self.model(smi_src, smi_tgt, gene_src)
return pred
return pred.item()
10 changes: 8 additions & 2 deletions dooc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def load_ontology(file_name: str, gene2id_mapping: dict) -> typing.Sequence:

if line[1] not in gene2id_mapping:
continue
if line[0] not in term_direct_gene_map:
term_direct_gene_map[line[0]] = set()

term_direct_gene_map[line[0]].add(gene2id_mapping[line[1]])
gene_set.add(line[1])
file_handle.close()
Expand All @@ -39,12 +42,15 @@ def load_ontology(file_name: str, gene2id_mapping: dict) -> typing.Sequence:

leaves = []
for term in dg.nodes():
term_gene_set = term_direct_gene_map[term]
term_gene_set = set()
if term in term_direct_gene_map:
term_gene_set = term_direct_gene_map[term]

deslist = nxadag.descendants(dg, term)

for child in deslist:
term_gene_set = term_gene_set | term_direct_gene_map[child]
if child in term_direct_gene_map:
term_gene_set = term_gene_set | term_direct_gene_map[child]

if len(term_gene_set) == 0:
raise ValueError(f"There is empty terms, please delete term: {term}")
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
moltx~=1.0.0
networkx
moltx~=1.0.2
networkx~=3.1
scikit-learn~=1.3.0
4 changes: 2 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_MutSmiFullConnection():
smiles_src, smiles_tgt, mutations_src, out = ds(smiles,
mutations,
values)
assert smiles_src.shape == (2, 4)
assert smiles_tgt.shape == (2, 6)
assert smiles_src.shape == (2, 200)
assert smiles_tgt.shape == (2, 200)
assert mutations_src.shape == (2, 3008)
assert out.shape == (2, 1)
26 changes: 19 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from dooc import models
from dooc import models, nets
import random


# def test_MutSmiXAttention():
Expand All @@ -12,11 +13,22 @@
# assert out.shape == (1,)


def test_MutSmiFullConnection():
smiles_src = torch.randint(0, 64, [2, 4])
smiles_tgt = torch.randint(0, 64, [2, 6])
mutations_src = torch.randn(2, 3008, dtype=torch.float32)
def test_MutSmiFullConnection(datadir):
smiles_src = torch.randint(0, 64, [2, 200])
smiles_tgt = torch.randint(0, 64, [2, 200])
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
d_model = 768
model = models.MutSmiFullConnection(d_model)
gene_conf = nets.GeneGNNConfig(
gene_dim=3008,
drug_dim=2048,
num_hiddens_genotype=6,
num_hiddens_drug=[100, 50, 6],
num_hiddens_final=6,
gene2ind_path=f"{datadir}/gene2ind.txt",
ont_path=f"{datadir}/drugcell_ont.txt",
)
model = models.MutSmiFullConnection(d_model, gene_conf=gene_conf)
out = model(smiles_src, smiles_tgt, mutations_src)
assert out.shape == (1,)
assert out.shape == (2, 1)
21 changes: 15 additions & 6 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import random
from dooc import pipelines, models
from dooc import pipelines, models, nets
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig

Expand All @@ -18,16 +18,25 @@
# assert isinstance(predict, float)


def test_MutSmiFullConnection():
def test_MutSmiFullConnection(datadir):
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
d_model = 768
model = models.MutSmiFullConnection(d_model)
gene_conf = nets.GeneGNNConfig(
gene_dim=3008,
drug_dim=2048,
num_hiddens_genotype=6,
num_hiddens_drug=[100, 50, 6],
num_hiddens_final=6,
gene2ind_path=f"{datadir}/gene2ind.txt",
ont_path=f"{datadir}/drugcell_ont.txt",
)
model = models.MutSmiFullConnection(d_model, gene_conf)
# model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
pipeline = pipelines.MutSmiFullConnection(tokenizer,
model)
mutation = [random.choice([1, 0]) for _ in range(3008)]
pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer,
model=model)
mutation = [[random.choice([1, 0]) for _ in range(3008)]]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutation, smiles)
assert isinstance(predict, float)

0 comments on commit 9400ede

Please sign in to comment.