diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..cdb4802
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,9 @@
+[run]
+omit =
+    */_version.py
+relative_files = True
+
+[report]
+exclude_lines =
+    ^VERSION =
+    raise NotImplementedError
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..d90c601
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,12 @@
+[flake8]
+# https://github.com/ambv/black/blob/master/.flake8
+# E203 whitespace before ':'
+# E266 Too many leading '#' for block comment
+# E501 Line too long
+# W503 Line break occurred before a binary operator
+ignore = E203, E266, E501, W503, W504, B950, B008
+max-line-length = 120
+max-complexity = 18
+exclude =
+    .git,
+    tests,
diff --git a/dooc/datasets.py b/dooc/datasets.py
index 4aadf4a..2f2f3bb 100644
--- a/dooc/datasets.py
+++ b/dooc/datasets.py
@@ -1,41 +1,48 @@
 import typing
 import torch
-from moltx.datasets import Base
-
-
-class MutSmi(Base):
-    """Base datasets, convert smiles and mutations to torch.Tensor.
-
-    """
-    pass
+from moltx import tokenizers, datasets
+
+
+class MutSmi:
+    """Base dataset: converts SMILES and gene vectors to torch.Tensor."""
+
+    def __init__(
+        self,
+        smi_tokenizer: tokenizers.MoltxTokenizer,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        self.smi_tokenizer = smi_tokenizer
+        self.smi_ds = datasets.AdaMRRegression(self.smi_tokenizer)
+        self.device = device
+
+    def gen_smi_token(
+        self,
+        smiles: typing.Sequence[str],
+        values: typing.Sequence[float],
+        seq_len: int = 200,
+    ) -> torch.Tensor:
+        src, tgt, out = self.smi_ds(smiles, values, seq_len)
+        return src.to(self.device), tgt.to(self.device), out.to(self.device)
+
+    def gen_gene_token(self, genes: typing.Sequence[list]) -> torch.Tensor:
+        return torch.tensor(genes, dtype=torch.float).to(self.device)
+
+    def __call__(
+        self,
+        smiles: typing.Sequence[str],
+        genes: typing.Sequence[list],
+        values: typing.Sequence[float],
+        seq_len: int = 200,
+    ) -> typing.Tuple[torch.Tensor]:
+        smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
+        gene_src = self.gen_gene_token(genes)
+        return smi_src, smi_tgt, gene_src, out
 
 
 class MutSmiXAttention(MutSmi):
-    """Regression task datasets.
-
-    """
-    def __call__(self,
-                 smiles: typing.Sequence[str],
-                 mutations: typing.Sequence[list],
-                 values: typing.Sequence[float]
-                 ) -> typing.Tuple[torch.Tensor]:
-        """_summary_
-
-        Args:
-            smiles (typing.Sequence[str]): molecule smiles.
-            mutations (typing.Sequence[list]): mutations one-hot list.
-            values (typing.Sequence[float]): actual inhibitation rate.
-
-        Returns:
-            smiles_src, smiles_tgt, mutations_src, out
-        """
-        pass
-
+    pass
 
 
-class MutSmiFullConnection(MutSmiXAttention):
-    """_summary_
-
-    Args:
-        MutSmi (_type_): _description_
-    """
+class MutSmiFullConnection(MutSmi):
     pass
diff --git a/dooc/models.py b/dooc/models.py
index a2b6867..9f459fd 100644
--- a/dooc/models.py
+++ b/dooc/models.py
@@ -1,55 +1,100 @@
 import torch
 from torch import nn
+from dooc import nets
+from moltx.models import AdaMR
+from moltx.nets import AbsPosEncoderDecoderConfig
 
 
 class MutSmi(nn.Module):
-    """Base
+    """Base"""
 
-    """
-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
+        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
+    ) -> None:
+        super().__init__()
+        self.gene_conf = gene_conf
+        self.smiles_conf = smiles_conf
+        self.smiles_encoder = AdaMR(smiles_conf)
+
+        self.gene_encoder = nets.GeneGNN(gene_conf)
 
-    def load_ckpt(self, ckpt_file: str) -> None:
+    def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
 
         Args:
-            ckpt_file (str): check point file path.
+            ckpt_files (str): checkpoint file paths.
         """
-        return
+        self.load_state_dict(
+            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
+        )
 
-    def load_pretrained_ckpt(self,
-                             drugcell_ckpt: str,
-                             moltx_ckpt: str) -> None:
-        """load drugcell checkpoint and moltx checkpoint.
+    def load_pretrained_ckpt(self, drugcell_ckpt: str, moltx_ckpt: str, freeze_drugcell: bool = False, freeze_moltx: bool = False) -> None:
+        self.gene_encoder.load_ckpt(drugcell_ckpt)
+        self.smiles_encoder.load_ckpt(moltx_ckpt)
+        if freeze_moltx:
+            self.smiles_encoder.requires_grad_(False)
+        if freeze_drugcell:
+            self.gene_encoder.requires_grad_(False)
 
-        Args:
-            drugcell_ckpt (str): ckpt file path.
-            moltx_ckpt (str): ckpt file path.
-        """
-        return
+    def forward(
+        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
+    ) -> torch.Tensor:
+        raise NotImplementedError()
 
 
 class MutSmiXAttention(MutSmi):
-    """Regression model using transformer cross attention.
-    """
-    def __init__(self):
-        pass
+    """Regression model using transformer cross attention."""
+
+    def __init__(
+        self,
+        d_model: int,
+        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
+        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
+    ) -> None:
+        super().__init__(gene_conf, smiles_conf)
+        self.d_model = d_model
 
-    def forward(self,
-                smiles_src: torch.Tensor,
-                smiles_tgt: torch.Tensor,
-                mutations_src: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
+    ) -> torch.Tensor:
         pass
 
 
 class MutSmiFullConnection(MutSmi):
-    """Regression model using fully connection.
- """ - def __init__(self): - pass + """Regression model using fully connection.""" - def forward(self, - smiles_src: torch.Tensor, - smiles_tgt: torch.Tensor, - mutations_src: torch.Tensor) -> torch.Tensor: - pass + def __init__( + self, + d_model: int, + gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, + smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, + ) -> None: + super().__init__(gene_conf, smiles_conf) + self.d_model = d_model + + d_hidden = d_model // 2 + self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model) + self.reg = nn.Sequential( + nn.Linear(self.d_model, d_hidden), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(d_hidden, 1), + ) + + def forward( + self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor + ) -> torch.Tensor: + is_batched = smiles_src.dim() == 2 + + smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) + gene_out = self.gene_encoder(gene_src).unsqueeze(1) + + feat = None + if is_batched: + feat = smiles_out + self.gene_fc(gene_out)[:, 0] + else: + feat = smiles_out[0] + self.gene_fc(gene_out)[0] + + return self.reg(feat) diff --git a/dooc/nets.py b/dooc/nets.py index 758ec2f..9a0150e 100644 --- a/dooc/nets.py +++ b/dooc/nets.py @@ -1,13 +1,226 @@ +import os import torch from torch import nn +from dataclasses import dataclass +from dooc.utils import load_gene_mapping, load_ontology -class GNN(nn.Module): +@dataclass +class GeneGNNConfig: + gene_dim: int + drug_dim: int + num_hiddens_genotype: int + num_hiddens_drug: list + num_hiddens_final: int + gene2ind_path: str + ont_path: str + + +class GeneGNN(nn.Module): """GNN for mutations embeddings. + reference: https://github.com/idekerlab/DrugCell/ + """ - def __init__(self): - pass - def forward(self, x) -> torch.Tensor: - pass + DEFAULT_CONFIG = GeneGNNConfig( + gene_dim=3008, + drug_dim=2048, + num_hiddens_genotype=6, + num_hiddens_drug=[100, 50, 6], + num_hiddens_final=6, + gene2ind_path=os.path.join(os.path.dirname(__file__), "data", "gene2ind.txt"), + ont_path=os.path.join(os.path.dirname(__file__), "data", "drugcell_ont.txt"), + ) + + def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None: + super().__init__() + self.conf = conf + + dg, dg_root, term_size_map, term_direct_gene_map = self._get_params() + self.dg, self.dg_root = dg, dg_root + self.term_size_map, self.term_direct_gene_map = ( + term_size_map, + term_direct_gene_map, + ) + + self._cal_term_dim() + self._contruct_direct_gene_layer() + self._construct_nn_graph() + self._construct_nn_drug() + self._construct_final_layer() + + def _contruct_direct_gene_layer(self): + """ + build a layer for forwarding gene that are directly annotated with the term + """ + for term, gene_set in self.term_direct_gene_map.items(): + if len(gene_set) == 0: + raise ValueError(f"There are no directed asscoiated genes for {term}") + + # if there are some genes directly annotated with the term, add a layer taking in all genes and forwarding out only those genes + self.add_module( + term + "_direct_gene_layer", + nn.Linear(self.conf.gene_dim, len(gene_set)), + ) + + def _construct_nn_drug(self): + """ + add modules for fully connected neural networks for drug processing + """ + input_size = self.conf.drug_dim + + for i in range(len(self.conf.num_hiddens_drug)): + self.add_module( + "drug_linear_layer_" + str(i + 1), + nn.Linear(input_size, self.conf.num_hiddens_drug[i]), + ) + self.add_module( + "drug_batchnorm_layer_" + str(i + 1), + 
+                nn.BatchNorm1d(self.conf.num_hiddens_drug[i]),
+            )
+            self.add_module(
+                "drug_aux_linear_layer1_" + str(i + 1),
+                nn.Linear(self.conf.num_hiddens_drug[i], 1),
+            )
+            self.add_module("drug_aux_linear_layer2_" + str(i + 1), nn.Linear(1, 1))
+
+            input_size = self.conf.num_hiddens_drug[i]
+
+    def _construct_nn_graph(self):
+        """
+        start from bottom (leaves), and start building a neural network using the given ontology
+        adding modules --- the modules are not connected yet
+        """
+        self.term_layer_list = []  # term_layer_list stores the built neural network
+        self.term_neighbor_map = {}
+        # term_neighbor_map records all children of each term
+        for term in self.dg.nodes():
+            self.term_neighbor_map[term] = []
+            for child in self.dg.neighbors(term):
+                self.term_neighbor_map[term].append(child)
+
+        while True:
+            leaves = [n for n in self.dg.nodes() if self.dg.out_degree(n) == 0]
+            # leaves = [n for n,d in self.dg.out_degree().items() if d==0]
+            # leaves = [n for n,d in self.dg.out_degree() if d==0]
+
+            if len(leaves) == 0:
+                break
+
+            self.term_layer_list.append(leaves)
+
+            for term in leaves:
+
+                # input size will be #children + #genes directly annotated by the term
+                input_size = 0
+
+                for child in self.term_neighbor_map[term]:
+                    input_size += self.term_dim_map[child]
+
+                if term in self.term_direct_gene_map:
+                    input_size += len(self.term_direct_gene_map[term])
+
+                # term_hidden is the number of the hidden variables in each state
+                term_hidden = self.term_dim_map[term]
+
+                self.add_module(
+                    term + "_linear_layer", nn.Linear(input_size, term_hidden)
+                )
+                self.add_module(term + "_batchnorm_layer", nn.BatchNorm1d(term_hidden))
+                self.add_module(term + "_aux_linear_layer1", nn.Linear(term_hidden, 1))
+                self.add_module(term + "_aux_linear_layer2", nn.Linear(1, 1))
+
+            self.dg.remove_nodes_from(leaves)
+
+    def _construct_final_layer(self):
+        """
+        add modules for final layer
+        """
+        final_input_size = (
+            self.conf.num_hiddens_genotype + self.conf.num_hiddens_drug[-1]
+        )
+        self.add_module(
+            "final_linear_layer",
+            nn.Linear(final_input_size, self.conf.num_hiddens_final),
+        )
+        self.add_module(
+            "final_batchnorm_layer", nn.BatchNorm1d(self.conf.num_hiddens_final)
+        )
+        self.add_module(
+            "final_aux_linear_layer", nn.Linear(self.conf.num_hiddens_final, 1)
+        )
+        self.add_module("final_linear_layer_output", nn.Linear(1, 1))
+
+    def _cal_term_dim(self):
+        """
+        calculate the number of values in a state (term)
+        term_size_map is the number of all genes annotated with the term
+        """
+        self.term_dim_map = {}
+        for term, term_size in self.term_size_map.items():
+            num_output = self.conf.num_hiddens_genotype
+
+            # log the number of hidden variables per each term
+            num_output = int(num_output)
+            self.term_dim_map[term] = num_output
+
+    def _get_params(self):
+        gene2id_mapping = load_gene_mapping(self.conf.gene2ind_path)
+        dg, dg_root, term_size_map, term_direct_gene_map = load_ontology(
+            self.conf.ont_path, gene2id_mapping
+        )
+        return dg, dg_root, term_size_map, term_direct_gene_map
+
+    def load_ckpt(self, *ckpt_files: str) -> None:
+        self.load_state_dict(
+            torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        removed drug layer
+        """
+
+        gene_input = x.narrow(1, 0, self.conf.gene_dim)
+        # drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim)
+
+        # define forward function for genotype dcell #############################################
+        term_gene_out_map = {}
+
+        for term, _ in self.term_direct_gene_map.items():
+            term_gene_out_map[term] = self._modules[term + "_direct_gene_layer"](
+                gene_input
+            )
+
+        term_nn_out_map = {}
+        aux_out_map = {}
+
+        for _, layer in enumerate(self.term_layer_list):
+
+            for term in layer:
+
+                child_input_list = []
+
+                for child in self.term_neighbor_map[term]:
+                    child_input_list.append(term_nn_out_map[child])
+
+                if term in self.term_direct_gene_map:
+                    child_input_list.append(term_gene_out_map[term])
+
+                child_input = torch.cat(child_input_list, 1)
+
+                term_nn_out = self._modules[term + "_linear_layer"](child_input)
+
+                tanh_out = torch.tanh(term_nn_out)
+                term_nn_out_map[term] = self._modules[term + "_batchnorm_layer"](
+                    tanh_out
+                )
+                aux_layer1_out = torch.tanh(
+                    self._modules[term + "_aux_linear_layer1"](term_nn_out_map[term])
+                )
+                aux_out_map[term] = self._modules[term + "_aux_linear_layer2"](
+                    aux_layer1_out
+                )
+
+        return term_nn_out_map[self.dg_root]
diff --git a/dooc/pipelines.py b/dooc/pipelines.py
index ffc35d1..448f477 100644
--- a/dooc/pipelines.py
+++ b/dooc/pipelines.py
@@ -1,27 +1,67 @@
 import typing
 import torch
+from moltx import tokenizers
+from dooc import models
 
 
 class MutSmi:
-    def __init__(self):
-        pass
+    def __init__(
+        self,
+        smi_tokenizer: tokenizers.MoltxTokenizer,
+        model: models.MutSmi,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        self.smi_tokenizer = smi_tokenizer
+
+        model = model.to(device)
+        model.eval()
+        model.requires_grad_(False)
+        self.model = model
+
+        self.device = device
+
+    def _model_args(
+        self, gene: typing.Sequence[int], smiles: str
+    ) -> typing.Tuple[torch.Tensor]:
+        smi_src = self.gen_smi_token(smiles)
+        smi_tgt = self.gen_smi_token(self.smi_tokenizer.BOS + smiles + self.smi_tokenizer.EOS)
+        gene_src = self.gen_gene_token(gene)
+        return smi_src, smi_tgt, gene_src
+
+    def gen_smi_token(self, smiles: str) -> torch.Tensor:
+        tokens = self.smi_tokenizer(smiles)
+        res = torch.zeros(len(tokens), dtype=torch.int)
+        for i, tk in enumerate(tokens):
+            res[i] = tk
+        return res.to(self.device)
+
+    def gen_gene_token(self, gene: typing.Sequence[float]) -> torch.Tensor:
+        return torch.tensor(gene, dtype=torch.float).to(self.device)
 
 
 class MutSmiXAttention(MutSmi):
-    def _model_args(self,
-                    mutation: typing.Sequence[int],
-                    smiles: str) -> typing.Tuple[torch.Tensor]:
-        return
+    def __init__(
+        self,
+        smi_tokenizer: tokenizers.MoltxTokenizer,
+        model: models.MutSmiXAttention,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        super().__init__(smi_tokenizer, model, device)
 
-    def __call__(self, mutation: typing.Sequence[int], smiles: str) -> float:
+    def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
         return
 
 
 class MutSmiFullConnection(MutSmi):
-    def _model_args(self,
-                    mutation: typing.Sequence[int],
-                    smiles: str) -> typing.Tuple[torch.Tensor]:
-        return
+    def __init__(
+        self,
+        smi_tokenizer: tokenizers.MoltxTokenizer,
+        model: models.MutSmiFullConnection,
+        device: torch.device = torch.device("cpu"),
+    ) -> None:
+        super().__init__(smi_tokenizer, model, device)
 
-    def __call__(self, mutation: typing.Sequence[int], smiles: str) -> float:
-        return
+    def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
+        smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
+        pred = self.model(smi_src, smi_tgt, gene_src)
+        return pred.item()
diff --git a/dooc/utils.py b/dooc/utils.py
new file mode 100644
index 0000000..8ce035a
--- /dev/null
+++ b/dooc/utils.py
@@ -0,0 +1,80 @@
+import typing
+import networkx as nx
+from collections import defaultdict
+import networkx.algorithms.components.connected as nxacc
+import networkx.algorithms.dag as nxadag
+
+
+def load_gene_mapping(file_path: str) -> dict:
+    res = {}
+
+    with open(file_path) as f:
+        for line in f:
+            line = line.rstrip().split()
+            res[line[1]] = int(line[0])
+
+    return res
+
+
+def load_ontology(file_name: str, gene2id_mapping: dict) -> typing.Sequence:
+    dg = nx.DiGraph()
+    term_direct_gene_map = defaultdict(set)
+
+    term_size_map, gene_set = {}, set()
+
+    file_handle = open(file_name)
+    for line in file_handle:
+        line = line.rstrip().split()
+        if line[2] == "default":
+            dg.add_edge(line[0], line[1])
+            continue
+
+        if line[1] not in gene2id_mapping:
+            continue
+        if line[0] not in term_direct_gene_map:
+            term_direct_gene_map[line[0]] = set()
+
+        term_direct_gene_map[line[0]].add(gene2id_mapping[line[1]])
+        gene_set.add(line[1])
+    file_handle.close()
+
+    print("There are", len(gene_set), "genes")
+
+    leaves = []
+    for term in dg.nodes():
+        term_gene_set = set()
+        if term in term_direct_gene_map:
+            term_gene_set = term_direct_gene_map[term]
+
+        deslist = nxadag.descendants(dg, term)
+
+        for child in deslist:
+            if child in term_direct_gene_map:
+                term_gene_set = term_gene_set | term_direct_gene_map[child]
+
+        if len(term_gene_set) == 0:
+            raise ValueError(f"There is an empty term, please delete term: {term}")
+
+        term_size_map[term] = len(term_gene_set)
+
+        if dg.in_degree(term) == 0:
+            leaves.append(term)
+
+    ug = dg.to_undirected()
+    connected_subg_list = list(nxacc.connected_components(ug))
+
+    print("There are", len(leaves), "roots:", leaves[0])
+    print("There are", len(dg.nodes()), "terms")
+    print("There are", len(connected_subg_list), "connected components")
+
+    if len(leaves) > 1:
+        raise ValueError(
+            "There is more than one root in the ontology. Please use only one root."
+        )
+
+    if len(connected_subg_list) > 1:
+        raise ValueError(
+            "There is more than one connected component. Please connect them."
+        )
+
+    return dg, leaves[0], term_size_map, term_direct_gene_map
diff --git a/requirements.txt b/requirements.txt
index 84103ec..1003af0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1,3 @@
-moltx~=1.0.0
+moltx~=1.0.2
+networkx~=3.1
+scikit-learn~=1.3.0
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 8b224be..70e741e 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -43,7 +43,7 @@ def test_MutSmiFullConnection():
 
     smiles_src, smiles_tgt, mutations_src, out = ds(smiles, mutations, values)
 
-    assert smiles_src.shape == (2, 4)
-    assert smiles_tgt.shape == (2, 6)
+    assert smiles_src.shape == (2, 200)
+    assert smiles_tgt.shape == (2, 200)
     assert mutations_src.shape == (2, 3008)
     assert out.shape == (2, 1)
diff --git a/tests/test_models.py b/tests/test_models.py
index 48aea7d..227d344 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,5 +1,6 @@
 import torch
-from dooc import models
+from dooc import models, nets
+import random
 
 
 # def test_MutSmiXAttention():
@@ -12,11 +13,22 @@
 #     assert out.shape == (1,)
 
 
-def test_MutSmiFullConnection():
-    smiles_src = torch.randint(0, 64, [2, 4])
-    smiles_tgt = torch.randint(0, 64, [2, 6])
-    mutations_src = torch.randn(2, 3008, dtype=torch.float32)
+def test_MutSmiFullConnection(datadir):
+    smiles_src = torch.randint(0, 64, [2, 200])
+    smiles_tgt = torch.randint(0, 64, [2, 200])
+    mutations = [[random.choice([0, 1]) for _ in range(3008)],
+                 [random.choice([0, 1]) for _ in range(3008)]]
+    mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
     d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    gene_conf = nets.GeneGNNConfig(
+        gene_dim=3008,
+        drug_dim=2048,
+        num_hiddens_genotype=6,
+        num_hiddens_drug=[100, 50, 6],
+        num_hiddens_final=6,
+        gene2ind_path=f"{datadir}/gene2ind.txt",
+        ont_path=f"{datadir}/drugcell_ont.txt",
+    )
+    model = models.MutSmiFullConnection(d_model, gene_conf=gene_conf)
     out = model(smiles_src, smiles_tgt, mutations_src)
-    assert out.shape == (1,)
+    assert out.shape == (2, 1)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 4fc7f73..570094a 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -1,5 +1,5 @@
 import random
-from dooc import pipelines, models
+from dooc import pipelines, models, nets
 from moltx import tokenizers as tkz
 from moltx.models import AdaMRTokenizerConfig
 
@@ -18,16 +18,25 @@
 #     assert isinstance(predict, float)
 
 
-def test_MutSmiFullConnection():
+def test_MutSmiFullConnection(datadir):
     tokenizer = tkz.MoltxTokenizer.from_pretrain(
         conf=AdaMRTokenizerConfig.Prediction
     )
     d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    gene_conf = nets.GeneGNNConfig(
+        gene_dim=3008,
+        drug_dim=2048,
+        num_hiddens_genotype=6,
+        num_hiddens_drug=[100, 50, 6],
+        num_hiddens_final=6,
+        gene2ind_path=f"{datadir}/gene2ind.txt",
+        ont_path=f"{datadir}/drugcell_ont.txt",
+    )
+    model = models.MutSmiFullConnection(d_model, gene_conf)
     # model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
-    pipeline = pipelines.MutSmiFullConnection(tokenizer,
-                                              model)
-    mutation = [random.choice([1, 0]) for _ in range(3008)]
+    pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer,
+                                              model=model)
+    mutation = [[random.choice([1, 0]) for _ in range(3008)]]
    smiles = "CC[N+](C)(C)Cc1ccccc1Br"
     predict = pipeline(mutation, smiles)
     assert isinstance(predict, float)
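
Not part of the diff: a minimal training-side sketch of the API introduced above, mirroring tests/test_datasets.py and tests/test_models.py. The tokenizer config choice, the gene2ind.txt / drugcell_ont.txt paths, the example SMILES and the target values are illustrative assumptions; adjust them to your local layout.

import torch
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import datasets, models, nets

# Assumed locations of the DrugCell ontology files (the tests load them from a datadir fixture).
gene_conf = nets.GeneGNNConfig(
    gene_dim=3008,
    drug_dim=2048,
    num_hiddens_genotype=6,
    num_hiddens_drug=[100, 50, 6],
    num_hiddens_final=6,
    gene2ind_path="data/gene2ind.txt",
    ont_path="data/drugcell_ont.txt",
)

tokenizer = tkz.MoltxTokenizer.from_pretrain(conf=AdaMRTokenizerConfig.Prediction)
ds = datasets.MutSmiFullConnection(tokenizer)

smiles = ["CC[N+](C)(C)Cc1ccccc1Br", "CCO"]   # example molecules
mutations = [[0] * 3008, [1] * 3008]          # one-hot mutation vectors, gene_dim entries each
values = [0.85, 0.15]                         # example regression targets

# The dataset converts raw inputs into padded token / feature tensors (seq_len defaults to 200).
smi_src, smi_tgt, gene_src, out = ds(smiles, mutations, values)

model = models.MutSmiFullConnection(d_model=768, gene_conf=gene_conf)
pred = model(smi_src, smi_tgt, gene_src)      # shape (2, 1), comparable against out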
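
Likewise for inference, a sketch following tests/test_pipelines.py: wrap a (possibly checkpoint-loaded) model in the new pipeline and score one mutation profile against one SMILES string. The checkpoint path is a placeholder, and the default GeneGNN config assumes the ontology files ship under dooc/data/.

import random
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import models, pipelines

tokenizer = tkz.MoltxTokenizer.from_pretrain(conf=AdaMRTokenizerConfig.Prediction)
model = models.MutSmiFullConnection(d_model=768)   # falls back to nets.GeneGNN.DEFAULT_CONFIG
# model.load_ckpt("/path/to/mutsmifullconnection.ckpt")   # placeholder path

pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model)

mutation = [[random.choice([0, 1]) for _ in range(3008)]]  # one 3008-dim mutation profile
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutation, smiles)                       # predicted value as a Python float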