From 04125c299cb377c8e314970b5a2494baf409b313 Mon Sep 17 00:00:00 2001
From: Michael Ding
Date: Sun, 2 Jun 2024 23:29:05 +0800
Subject: [PATCH] refactor: split nets into a package (mutations, heads) and
 rework datasets, models, and pipelines

---
 dooc/datasets.py                    | 121 +++++++++++++++++----
 dooc/models.py                      | 119 +++++----------------
 dooc/nets/__init__.py               | 131 +++++++++++++++++++++++
 dooc/nets/heads.py                  |  36 +++++++
 dooc/{nets.py => nets/mutations.py} | 102 ++++++++++++++++--
 dooc/pipelines.py                   | 160 ++++++++++++++++++----------
 dooc/utils.py                       |  80 --------------
 tests/conftest.py                   |   9 ++
 tests/test_datasets.py              |  98 +++++++++--------
 tests/test_models.py                |  25 -----
 tests/test_nets.py                  |  60 ++++++++++-
 tests/test_pipelines.py             |  73 ++++++++-----
 12 files changed, 660 insertions(+), 354 deletions(-)
 create mode 100644 dooc/nets/__init__.py
 create mode 100644 dooc/nets/heads.py
 rename dooc/{nets.py => nets/mutations.py} (73%)
 delete mode 100644 dooc/utils.py
 delete mode 100644 tests/test_models.py

diff --git a/dooc/datasets.py b/dooc/datasets.py
index 2f2f3bb..37789bc 100644
--- a/dooc/datasets.py
+++ b/dooc/datasets.py
@@ -4,45 +4,128 @@
 from moltx import tokenizers, datasets
 
 
-class MutSmi:
+class _SmiMutBase:
+    def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None:
+        self.smi_ds = datasets.Base(smi_tokenizer, device)
+        self.device = device
+
+    def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> torch.Tensor:
+        return self.smi_ds._tokenize(smiles, seq_len)
+
+
+"""
+Mutations(Individual Sample) and Smiles Interaction
+
+{MutationEnc}{SmileEnc}MutSmi: 1 mut with 1 smi
+{MutationEnc}{SmileEnc}MutSmis: 1 mut with n smi
+{MutationEnc}{SmileEnc}MutsSmi: n mut with 1 smi
+"""
+
+
+class _DrugcellAdamrBase(_SmiMutBase):
     """Base datasets, convert smiles and genes to torch.Tensor."""
 
     def __init__(
         self,
         smi_tokenizer: tokenizers.MoltxTokenizer,
-        device: torch.device = torch.device("cpu"),
+        device: torch.device = torch.device("cpu")
     ) -> None:
+        super().__init__(smi_tokenizer, device)
         self.smi_tokenizer = smi_tokenizer
-        self.smi_ds = datasets.AdaMRRegression(self.smi_tokenizer)
-        self.device = device
 
-    def gen_smi_token(
+    def _smi_tokens(
         self,
         smiles: typing.Sequence[str],
-        values: typing.Sequence[float],
         seq_len: int = 200,
     ) -> torch.Tensor:
-        src, tgt, out = self.smi_ds(smiles, values, seq_len)
-        return src.to(self.device), tgt.to(self.device), out.to(self.device)
+        src = self._smi_tokenize(smiles, seq_len)
+        tgt = self._smi_tokenize(
+            [f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)
+        return src, tgt
+
+    def _mut_tokens(self, muts: typing.Sequence[list]) -> torch.Tensor:
+        return torch.tensor(muts, device=self.device)
 
-    def gen_gene_token(self, genes: typing.Sequence[list]) -> torch.Tensor:
-        return torch.tensor(genes, dtype=torch.float).to(self.device)
+    def _out(self, values: typing.Sequence[float]) -> torch.Tensor:
+        return torch.tensor(values, device=self.device)
+
+
+class _DrugcellAdamrMutSmi(_DrugcellAdamrBase):
 
     def __call__(
         self,
-        smiles: typing.Sequence[str],
-        genes: typing.Sequence[list],
-        values: typing.Sequence[float],
-        seq_len: int = 200,
+        muts: typing.Sequence[list],
+        smis: typing.Sequence[str],
+        vals: typing.Sequence[float],
+        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor]:
-        smi_src, smi_tgt, out = self.gen_smi_token(smiles, values, seq_len)
-        gene_src = self.gen_gene_token(genes)
-        return smi_src, smi_tgt, gene_src, out
+        assert len(smis) == len(vals) and len(muts) == len(vals)
+        mut_x = self._mut_tokens(muts)
+        smi_src, smi_tgt = self._smi_tokens(smis, seq_len)
+        out = self._out(vals).unsqueeze(-1)
+        return mut_x, smi_src, smi_tgt, out
 
 
-class MutSmiXAttention(MutSmi):
+class _DrugcellAdamrMutSmis(_DrugcellAdamrBase):
+
+    def __call__(
+        self,
+        muts: typing.Sequence[list],
+        lsmis: typing.Sequence[typing.Sequence[str]],
+        lvals: typing.Sequence[typing.Sequence[float]],
+        seq_len: int = 200
+    ) -> typing.Tuple[torch.Tensor]:
+        """
+        muts: [mut1, mut2, ...] mut1: [gene1, gene2, ...]
+        lsmis: [[smi11, smi12], [smi21, smi22], ...]
+        lvals: [[val11, val12], [val21, val22], ...]
+        """
+        assert len(lsmis) == len(lvals) and len(muts) == len(lvals)
+        mut_x = self._mut_tokens(muts)
+        batchlen = len(lsmis)
+        listlen = len(lsmis[0])
+        smiles = [smi for bsmi in lsmis for smi in bsmi]
+        smi_src, smi_tgt = self._smi_tokens(smiles, seq_len)
+        smi_src = smi_src.reshape(batchlen, listlen, smi_src.size(-1))
+        smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
+        out = self._out(lvals)
+        return mut_x, smi_src, smi_tgt, out
+
+
+class _DrugcellAdamr2Base(_SmiMutBase):
+    pass
+
+
+class _DrugcellAdamr2MutSmi(_DrugcellAdamr2Base):
+    pass
+
+
+class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base):
     pass
 
 
-class MutSmiFullConnection(MutSmi):
+"""
+Mutations(Individual Sample) and Smiles Interaction
+
+MutSmiReg
+MutSmis{Pair/List}
+MutsSmi{Pair/List}
+"""
+
+
+class MutSmiReg(_DrugcellAdamrMutSmi):
     pass
+
+
+class MutSmisPairwise(_DrugcellAdamrMutSmis):
+    def __call__(
+        self,
+        muts: typing.Sequence[list],
+        lsmiles: typing.Sequence[typing.Sequence[str]],
+        lvalues: typing.Sequence[typing.Sequence[float]],
+        seq_len: int = 200
+    ) -> typing.Tuple[torch.Tensor]:
+        mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len)
+        out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device)
+        out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1
+        return mut_x, smi_src, smi_tgt, out.unsqueeze(-1)
diff --git a/dooc/models.py b/dooc/models.py
index 77ba8a9..2e310b8 100644
--- a/dooc/models.py
+++ b/dooc/models.py
@@ -1,105 +1,40 @@
 import torch
-from torch import nn
-from dooc import nets
-from moltx.models import AdaMR
-from moltx.nets import AbsPosEncoderDecoderConfig
+from moltx import models as mmodels
+from dooc import nets as dnets
+from dooc.nets import mutations, heads
 
 
-class MutSmi(nn.Module):
-    """Base"""
+"""
+Mutations(Individual Sample) and Smiles Interaction
 
-    def __init__(
-        self,
-        gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
-        smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
-    ) -> None:
-        super().__init__()
-        self.gene_conf = gene_conf
-        self.smiles_conf = smiles_conf
-        self.smiles_encoder = AdaMR(smiles_conf)
+MutSmiReg
+MutSmis{Pair/List}
+MutsSmi{Pair/List}
+"""
 
-        self.gene_encoder = nets.GeneGNN(gene_conf)
 
-    def load_ckpt(self, *ckpt_files: str) -> None:
-        """load check point model.
+class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn):
 
-        Args:
-            ckpt_files (str): check point file paths.
- """ - self.load_state_dict( - torch.load(ckpt_files[0], map_location=torch.device("cpu")), strict=False - ) - - def load_pretrained_ckpt(self, drugcell_ckpt: str, moltx_ckpt: str, freeze_drugcell: bool = False, freeze_moltx: bool = False) -> None: - self.gene_encoder.load_ckpt(drugcell_ckpt) - self.smiles_encoder.load_ckpt(moltx_ckpt) - if freeze_moltx: - self.smiles_encoder.requires_grad_(False) - if freeze_drugcell: - self.gene_encoder.requires_grad_(False) - - def forward( - self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor - ) -> torch.Tensor: - raise NotImplementedError() - - -class MutSmiXAttention(MutSmi): - """Regression model using transformer cross attention.""" + def __init__(self) -> None: + super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE) + self.reg = heads.RegHead(self.smi_conf.d_model) - def __init__( - self, - nhead: int = 2, - num_layers: int = 2, - gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, - smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, - ) -> None: - super().__init__(gene_conf, smiles_conf) - d_model = self.smiles_conf.d_model - d_hidden = d_model // 2 - layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) - self.cross_att = nn.TransformerDecoder(layer, num_layers) - self.reg = nn.Sequential( - nn.Linear(d_model, d_hidden), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(d_hidden, 1), - ) + def forward(self, *args, **kwargs) -> torch.Tensor: + return self.reg(super().forward(*args, **kwargs)) # [b, 1] - def forward( - self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor - ) -> torch.Tensor: - smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt).unsqueeze(-2) # [b, 1, dmodel] - gene_out = self.gene_encoder(gene_src).unsqueeze(-2) # [b, 1, dmodel] - feat = self.cross_att(smiles_out, gene_out) # [b, 1, dmodel] +class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn): - return self.reg(feat.squeeze(-2)) # [b, 1] + def __init__(self) -> None: + super().__init__(mut_conf=mutations.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE) + self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model) + def forward(self, *args, **kwargs) -> torch.Tensor: + return self.pairwise_rank(super().forward(*args, **kwargs)) # [b, 2] -class MutSmiFullConnection(MutSmi): - """Regression model using fully connection.""" - - def __init__( - self, - gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, - smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, - ) -> None: - super().__init__(gene_conf, smiles_conf) - d_model = self.smiles_conf.d_model - d_hidden = d_model // 2 - self.reg = nn.Sequential( - nn.Linear(d_model, d_hidden), - nn.ReLU(), - nn.Dropout(0.1), - nn.Linear(d_hidden, 1), - ) - - def forward( - self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor - ) -> torch.Tensor: - smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) # [b, dmodel] - gene_out = self.gene_encoder(gene_src) # [b, dmodel] - feat = smiles_out + gene_out - - return self.reg(feat) # [b, 1] + def forward_cmp(self, *args, **kwargs) -> float: + """ + for infer, no batch dim + """ + out = self.forward(*args, **kwargs) + return (out[1] - out[0]).item() diff --git a/dooc/nets/__init__.py b/dooc/nets/__init__.py new file mode 100644 index 0000000..36b3758 --- /dev/null +++ b/dooc/nets/__init__.py @@ -0,0 +1,131 @@ +import torch +from torch import nn +from moltx.models import 
AdaMR +from dooc.nets.mutations import Drugcell + + +""" +Mutations(Individual Sample) and Smiles Interaction + +{MutationEnc}{SmileEnc}MutSmi{Add/Xattn}: 1 mut with 1 smi +{MutationEnc}{SmileEnc}MutSmis{Add/Xattn}: 1 mut with n smi +{MutationEnc}{SmileEnc}MutsSmi{Add/Xattn}: n mut with 1 smi +""" + + +class _DrugcellAdamr(nn.Module): + + def __init__(self, mut_conf, smi_conf) -> None: + super().__init__() + self.mut_conf = mut_conf + self.smi_conf = smi_conf + + self.mut_encoder = Drugcell(mut_conf) + self.smi_encoder = AdaMR(smi_conf) + + def load_ckpt(self, *ckpt_files: str) -> None: + self.load_state_dict( + torch.load(ckpt_files[0], map_location=torch.device("cpu")) + ) + + def load_pretrained_ckpt(self, mut_ckpt: str, smi_ckpt: str, freeze_mut: bool = False, freeze_smi: bool = False) -> None: + self.mut_encoder.load_ckpt(mut_ckpt) + self.smi_encoder.load_ckpt(smi_ckpt) + if freeze_smi: + self.smi_encoder.requires_grad_(False) + if freeze_mut: + self.mut_encoder.requires_grad_(False) + + +class DrugcellAdamrMutSmiAdd(_DrugcellAdamr): + + def forward( + self, mut_x: torch.Tensor, smi_src: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + smi_src, smi_tgt: [b, smi_seqlen] + """ + mut_out = self.mut_encoder(mut_x) + smi_out = self.smi_encoder.forward_feature(smi_src, smi_tgt) + return mut_out + smi_out # [b, dmodel] + + +class DrugcellAdamrMutSmiXattn(_DrugcellAdamr): + + def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(mut_conf, smi_conf) + d_model = self.smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def forward( + self, mut_x: torch.Tensor, smi_src: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + smi_src, smi_tgt: [b, smi_seqlen] + """ + mut_out = self.mut_encoder(mut_x).unsqueeze(-2) # [b, 1, dmodel] + smi_out = self.smi_encoder.forward_feature(smi_src, smi_tgt).unsqueeze(-2) # [b, 1, dmodel] + return self.cross_attn(smi_out, mut_out).squeeze(-2) # [b, dmodel] + + +class DrugcellAdamrMutSmisAdd(_DrugcellAdamr): + + def _forward_mut(self, mut_x: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + out: [b, 1, dmodel] + """ + return self.mut_encoder(mut_x).unsqueeze(-2) + + def _forward_smi(self, smi_src: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + smi_src: [b, n, smi_seqlen] + smi_tgt: [b, n, smi_seqlen] + out: [b, n, dmodel] + """ + batched = smi_src.dim() == 3 + if batched: + n = smi_src.shape[1] + smi_src = smi_src.reshape(-1, smi_src.shape[-1]) + smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1]) + out = self.smi_encoder.forward_feature(smi_src, smi_tgt) + return out.reshape(-1, n, out.shape[-1]) + return self.smi_encoder.forward_feature(smi_src, smi_tgt) + + def forward( + self, mut_x: torch.Tensor, smi_src: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + mut_out = self._forward_mut(mut_x) + smi_out = self._forward_smi(smi_src, smi_tgt) + return smi_out + mut_out # [b, n, dmodel] + + +class DrugcellAdamrMutSmisXattn(DrugcellAdamrMutSmisAdd): + + def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(mut_conf, smi_conf) + d_model = smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def forward( + self, mut_x: torch.Tensor, smi_src: torch.Tensor, smi_tgt: 
torch.Tensor) -> torch.Tensor:
+        mut_out = self._forward_mut(mut_x)
+        smi_out = self._forward_smi(smi_src, smi_tgt)
+        return self.cross_attn(smi_out, mut_out)  # [b, n, dmodel]
+
+
+class DrugcellAdamr2MutSmiAdd(nn.Module):
+    pass
+
+
+class DrugcellAdamr2MutSmiXattn(nn.Module):
+    pass
+
+
+class DrugcellAdamr2MutSmisAdd(nn.Module):
+    pass
+
+
+class DrugcellAdamr2MutSmisXattn(DrugcellAdamr2MutSmisAdd):
+    pass
diff --git a/dooc/nets/heads.py b/dooc/nets/heads.py
new file mode 100644
index 0000000..cce640a
--- /dev/null
+++ b/dooc/nets/heads.py
@@ -0,0 +1,36 @@
+import torch
+from torch import nn
+
+
+class RegHead(nn.Module):
+    def __init__(self, in_features: int):
+        super().__init__()
+        hidden_dim = in_features // 2
+        self.mlp = nn.Sequential(
+            nn.Linear(in_features, hidden_dim),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(hidden_dim, 1)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.mlp(x)
+
+
+class PairwiseRankHead(nn.Module):
+    def __init__(self, d_features: int):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            nn.Flatten(-2),
+            nn.Linear(d_features * 2, d_features),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(d_features, 2)
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: [b, 2, d_features]
+        """
+        assert x.size(-2) == 2
+        return self.mlp(x)  # [b, 2] logits; class 1: x1 > x2, class 0: x1 <= x2
diff --git a/dooc/nets.py b/dooc/nets/mutations.py
similarity index 73%
rename from dooc/nets.py
rename to dooc/nets/mutations.py
index e40f15c..3ca37fa 100644
--- a/dooc/nets.py
+++ b/dooc/nets/mutations.py
@@ -1,12 +1,91 @@
 import os
+import typing
+from dataclasses import dataclass
+from collections import defaultdict
 import torch
 from torch import nn
-from dataclasses import dataclass
-from dooc.utils import load_gene_mapping, load_ontology
+import networkx as nx
+import networkx.algorithms.components.connected as nxacc
+import networkx.algorithms.dag as nxadag
+
+
+def _load_gene_mapping(file_path: str) -> dict:
+    res = {}
+
+    with open(file_path) as f:
+        for line in f:
+            line = line.rstrip().split()
+            res[line[1]] = int(line[0])
+
+    return res
+
+
+def _load_ontology(file_name: str, gene2id_mapping: dict) -> typing.Tuple:
+    dg = nx.DiGraph()
+    term_direct_gene_map = defaultdict(set)
+
+    term_size_map, gene_set = {}, set()
+
+    file_handle = open(file_name)
+    for line in file_handle:
+        line = line.rstrip().split()
+        if line[2] == "default":
+            dg.add_edge(line[0], line[1])
+            continue
+
+        if line[1] not in gene2id_mapping:
+            continue
+        if line[0] not in term_direct_gene_map:
+            term_direct_gene_map[line[0]] = set()
+
+        term_direct_gene_map[line[0]].add(gene2id_mapping[line[1]])
+        gene_set.add(line[1])
+    file_handle.close()
+
+    print("There are", len(gene_set), "genes")
+
+    leaves = []
+    for term in dg.nodes():
+        term_gene_set = set()
+        if term in term_direct_gene_map:
+            term_gene_set = term_direct_gene_map[term]
+
+        deslist = nxadag.descendants(dg, term)
+
+        for child in deslist:
+            if child in term_direct_gene_map:
+                term_gene_set = term_gene_set | term_direct_gene_map[child]
+
+        if len(term_gene_set) == 0:
+            raise ValueError(f"There is an empty term; please delete term: {term}")
+
+        term_size_map[term] = len(term_gene_set)
+
+        if dg.in_degree(term) == 0:
+            leaves.append(term)
+
+    ug = dg.to_undirected()
+    connected_subg_list = list(nxacc.connected_components(ug))
+
+    print("There are", len(leaves), "roots:", leaves[0])
+    print("There are", len(dg.nodes()), "terms")
+    print("There are", len(connected_subg_list), "connected components")
+
+    if len(leaves) > 1:
+        raise ValueError(
+            "There is more than one root in the ontology. Please use only one root."
+        )
+
+    if len(connected_subg_list) > 1:
+        raise ValueError(
+            "There is more than one connected component. Please connect them."
+        )
+
+    return dg, leaves[0], term_size_map, term_direct_gene_map
 
 
 @dataclass
-class GeneGNNConfig:
+class DrugcellConfig:
     d_model: int
     gene_dim: int
     drug_dim: int
@@ -17,25 +96,25 @@ class GeneGNNConfig:
     ont_path: str
 
 
-class GeneGNN(nn.Module):
+class Drugcell(nn.Module):
     """GNN for mutations embeddings.
 
     reference: https://github.com/idekerlab/DrugCell/
 
     """
 
-    DEFAULT_CONFIG = GeneGNNConfig(
+    DEFAULT_CONFIG = DrugcellConfig(
         d_model=768,
         gene_dim=3008,
         drug_dim=2048,
         num_hiddens_genotype=6,
         num_hiddens_drug=[100, 50, 6],
         num_hiddens_final=6,
-        gene2ind_path=os.path.join(os.path.dirname(__file__), "data", "gene2ind.txt"),
-        ont_path=os.path.join(os.path.dirname(__file__), "data", "drugcell_ont.txt"),
+        gene2ind_path=os.path.join(os.path.dirname(__file__), "../data", "gene2ind.txt"),
+        ont_path=os.path.join(os.path.dirname(__file__), "../data", "drugcell_ont.txt"),
     )
 
-    def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
+    def __init__(self, conf: DrugcellConfig = DEFAULT_CONFIG) -> None:
         super().__init__()
         self.conf = conf
         d_model = self.conf.d_model
@@ -171,8 +250,8 @@ def _cal_term_dim(self):
             self.term_dim_map[term] = num_output
 
     def _get_params(self):
-        gene2id_mapping = load_gene_mapping(self.conf.gene2ind_path)
-        dg, dg_root, term_size_map, term_direct_gene_map = load_ontology(
+        gene2id_mapping = _load_gene_mapping(self.conf.gene2ind_path)
+        dg, dg_root, term_size_map, term_direct_gene_map = _load_ontology(
             self.conf.ont_path, gene2id_mapping
         )
         return dg, dg_root, term_size_map, term_direct_gene_map
@@ -184,8 +263,9 @@ def load_ckpt(self, *ckpt_files: str) -> None:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        removed drug layer
+        omits the drug layer, compared with the original DrugCell
         """
+        x = x.float()
         x_dim = x.dim()
         x = x.unsqueeze(0) if x_dim == 1 else x
         gene_input = x.narrow(1, 0, self.conf.gene_dim)
diff --git a/dooc/pipelines.py b/dooc/pipelines.py
index 23b1e4a..59ca2bc 100644
--- a/dooc/pipelines.py
+++ b/dooc/pipelines.py
@@ -1,69 +1,119 @@
 import typing
+from functools import cmp_to_key
 import torch
+import torch.nn as nn
 from moltx import tokenizers
-from dooc import models
 
 
-class MutSmi:
-    def __init__(
-        self,
-        smi_tokenizer: tokenizers.MoltxTokenizer,
-        model: models.MutSmi,
-        device: torch.device = torch.device("cpu"),
-    ) -> None:
+class _MutSmiBase:
+    def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, model: nn.Module, device: torch.device = torch.device("cpu")) -> None:
         self.smi_tokenizer = smi_tokenizer
-
+        self.device = device
         model = model.to(device)
         model.eval()
         model.requires_grad_(False)
         self.model = model
-        self.device = device
 
+    def _tokens2tensor(self, tokens: typing.Sequence[int], size: int = None) -> torch.Tensor:
+        if size is None:
+            size = len(tokens)
+        if len(tokens) > size:
+            raise IndexError('the length of tokens is greater than size!')
+        out = torch.zeros(size, dtype=torch.int)
+        for i, tk in enumerate(tokens):
+            out[i] = tk
+        return out.to(self.device)
+
+
+class _DrugcellAdamrMutSmi(_MutSmiBase):
     def _model_args(
-        self, gene: typing.Sequence[int], smiles: str
+        self, mut: typing.Sequence[int], smi: str
     ) -> typing.Tuple[torch.Tensor]:
-        smi_src = self.gen_smi_token(smiles)
-        smi_tgt = self.gen_smi_token(self.smi_tokenizer.BOS + smiles + self.smi_tokenizer.EOS)
-        gene_src = self.gen_gene_token(gene)
-        return smi_src, smi_tgt, gene_src
-
-    def gen_smi_token(self, smiles: str) -> torch.Tensor:
-        tokens = self.smi_tokenizer(smiles)
-        res = torch.zeros(len(tokens), dtype=torch.int)
-        for i, tk in enumerate(tokens):
-            res[i] = tk
-        return res.to(self.device)
-
-    def gen_gene_token(self, gene: typing.Sequence[float]) -> torch.Tensor:
-        return torch.tensor(gene, dtype=torch.float).to(self.device)
-
-
-class MutSmiXAttention(MutSmi):
-    def __init__(
-        self,
-        smi_tokenizer: tokenizers.MoltxTokenizer,
-        model: models.MutSmiXAttention,
-        device: torch.device = torch.device("cpu"),
-    ) -> None:
-        super().__init__(smi_tokenizer, model, device)
-
-    def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
-        smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
-        pred = self.model(smi_src, smi_tgt, gene_src)
-        return pred.item()
-
-
-class MutSmiFullConnection(MutSmi):
-    def __init__(
-        self,
-        smi_tokenizer: tokenizers.MoltxTokenizer,
-        model: models.MutSmiFullConnection,
-        device: torch.device = torch.device("cpu"),
-    ) -> None:
-        super().__init__(smi_tokenizer, model, device)
-
-    def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
-        smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
-        pred = self.model(smi_src, smi_tgt, gene_src)
-        return pred.item()
+        mut_x = torch.tensor(mut, device=self.device)
+        smi_src = self._tokens2tensor(self.smi_tokenizer(smi))
+        smi_tgt = self._tokens2tensor(self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS))
+        return mut_x, smi_src, smi_tgt
+
+    def reg(self, mut: typing.Sequence[int], smi: str) -> float:
+        return self.model(*self._model_args(mut, smi)).item()
+
+    def cmp_smis_func(self, mut: typing.Sequence[int]) -> typing.Callable:
+        cmped = {}
+
+        def cmp(smi1, smi2):
+            query = '-'.join([smi1, smi2])
+            if query in cmped:
+                return cmped[query]
+            out1 = self.reg(mut, smi1)
+            out2 = self.reg(mut, smi2)
+            out = out1 - out2
+            cmped[query] = out
+            return out
+        return cmp
+
+
+class _DrugcellAdamrMutSmis(_MutSmiBase):
+
+    def _smi_args(
+        self, smis: typing.Sequence[str]
+    ) -> typing.Tuple[torch.Tensor]:
+        smi_src = [self.smi_tokenizer(smi) for smi in smis]
+        smi_tgt = [self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS) for smi in smis]
+        size_src = max(map(len, smi_src))
+        size_tgt = max(map(len, smi_tgt))
+        smi_src = torch.concat([self._tokens2tensor(smi, size_src).unsqueeze(0) for smi in smi_src])
+        smi_tgt = torch.concat([self._tokens2tensor(smi, size_tgt).unsqueeze(0) for smi in smi_tgt])
+        return smi_src, smi_tgt
+
+    def cmp_smis_func(self, mut: typing.Sequence[int]) -> typing.Callable:
+        mut_x = torch.tensor(mut, device=self.device)
+        cmped = {}
+
+        def cmp(smi1, smi2):
+            smis = [smi1, smi2]
+            query = '-'.join(smis)
+            if query in cmped:
+                return cmped[query]
+            smi_src, smi_tgt = self._smi_args(smis)
+            out = self.model.forward_cmp(mut_x, smi_src, smi_tgt)
+            cmped[query] = out
+            return out
+        return cmp
+
+
+class _DrugcellAdamr2MutSmi(_MutSmiBase):
+    pass
+
+
+class _DrugcellAdamr2MutSmis(_MutSmiBase):
+    pass
+
+
+class _MutSmiReg:
+
+    def __call__(self, mut: typing.Sequence[int], smi: str) -> float:
+        return self.reg(mut, smi)
+
+
+class _MutSmisRank:
+
+    def __call__(self, mut: typing.Sequence[int], smis: typing.Sequence[str]) -> typing.Sequence[str]:
+        return sorted(smis, key=cmp_to_key(self.cmp_smis_func(mut)))
+
+
+"""
+Mutations(Individual Sample) and Smiles Interaction
+
+MutSmiReg
+MutSmisRank
+MutsSmiRank
+"""
+
+
+class MutSmiReg(_DrugcellAdamrMutSmi, _MutSmiReg):
+    pass
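+
+# Usage sketch (illustrative; the checkpoint path is hypothetical, and
+# len(mut) must equal the mutation encoder's gene_dim, 3008 by default):
+#
+#   tokenizer = tokenizers.MoltxTokenizer.from_pretrain(conf=AdaMRTokenizerConfig.Prediction)
+#   model = models.MutSmiReg()
+#   model.load_ckpt("/path/to/mutsmireg.ckpt")
+#   pipeline = MutSmiReg(smi_tokenizer=tokenizer, model=model)
+#   score = pipeline(mut, "CC[N+](C)(C)Cc1ccccc1Br")  # -> float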
+ + +class MutSmisRank(_DrugcellAdamrMutSmis, _MutSmisRank): + pass diff --git a/dooc/utils.py b/dooc/utils.py deleted file mode 100644 index 8ce035a..0000000 --- a/dooc/utils.py +++ /dev/null @@ -1,80 +0,0 @@ -import typing -import networkx as nx -from collections import defaultdict -import networkx.algorithms.components.connected as nxacc -import networkx.algorithms.dag as nxadag - - -def load_gene_mapping(file_path: str) -> dict: - res = {} - - with open(file_path) as f: - for line in f: - line = line.rstrip().split() - res[line[1]] = int(line[0]) - - return res - - -def load_ontology(file_name: str, gene2id_mapping: dict) -> typing.Sequence: - dg = nx.DiGraph() - term_direct_gene_map = defaultdict(set) - - term_size_map, gene_set = {}, set() - - file_handle = open(file_name) - for line in file_handle: - line = line.rstrip().split() - if line[2] == "default": - dg.add_edge(line[0], line[1]) - continue - - if line[1] not in gene2id_mapping: - continue - if line[0] not in term_direct_gene_map: - term_direct_gene_map[line[0]] = set() - - term_direct_gene_map[line[0]].add(gene2id_mapping[line[1]]) - gene_set.add(line[1]) - file_handle.close() - - print("There are", len(gene_set), "genes") - - leaves = [] - for term in dg.nodes(): - term_gene_set = set() - if term in term_direct_gene_map: - term_gene_set = term_direct_gene_map[term] - - deslist = nxadag.descendants(dg, term) - - for child in deslist: - if child in term_direct_gene_map: - term_gene_set = term_gene_set | term_direct_gene_map[child] - - if len(term_gene_set) == 0: - raise ValueError(f"There is empty terms, please delete term: {term}") - - term_size_map[term] = len(term_gene_set) - - if dg.in_degree(term) == 0: - leaves.append(term) - - ug = dg.to_undirected() - connected_subg_list = list(nxacc.connected_components(ug)) - - print("There are", len(leaves), "roots:", leaves[0]) - print("There are", len(dg.nodes()), "terms") - print("There are", len(connected_subg_list), "connected componenets") - - if len(leaves) > 1: - raise ValueError( - "There are more than 1 root of ontology. Please use only one root." - ) - - if len(connected_subg_list) > 1: - raise ValueError( - "There are more than connected components. Please connect them." 
- ) - - return dg, leaves[0], term_size_map, term_direct_gene_map diff --git a/tests/conftest.py b/tests/conftest.py index 259e1bc..a502795 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,16 @@ import os.path import pytest +from moltx import tokenizers as tkz +from moltx.models import AdaMRTokenizerConfig @pytest.fixture def datadir(): return os.path.join(os.path.dirname(__file__), '../dooc/data') + + +@pytest.fixture +def smi_tkz(): + return tkz.MoltxTokenizer.from_pretrain( + conf=AdaMRTokenizerConfig.Prediction + ) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index be3a539..32afd63 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,49 +1,63 @@ import pytest from dooc import datasets import random -from moltx import tokenizers as tkz -from moltx.models import AdaMRTokenizerConfig - - -def test_MutSmi(): - return - - -def test_MutSmiXAttention(): - tokenizer = tkz.MoltxTokenizer.from_pretrain( - conf=AdaMRTokenizerConfig.Prediction - ) - ds = datasets.MutSmiXAttention(tokenizer) - smiles = ["CC[N+]CCBr", "Cc1ccc1"] - values = [0.88, 0.89] - mutations = [[random.choice([0, 1]) for _ in range(3008)], - [random.choice([0, 1]) for _ in range(3008)]] - with pytest.raises(RuntimeError): - ds(smiles, mutations, values[:1]) - smiles_src, smiles_tgt, mutations_src, out = ds(smiles, - mutations, - values) - assert smiles_src.shape == (2, 200) - assert smiles_tgt.shape == (2, 200) - assert mutations_src.shape == (2, 3008) + + +def test_DrugcellAdamrMutSmi(smi_tkz): + ds = datasets._DrugcellAdamrMutSmi(smi_tkz) + smis = ["CC[N+]CCBr", "Cc1ccc1"] + vals = [0.88, 0.89] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, smis, vals[:1]) + mut_x, smi_src, smi_tgt, out = ds(muts, smis, vals) + assert smi_src.shape == (2, 200) + assert smi_tgt.shape == (2, 200) + assert mut_x.shape == (2, 52) + assert out.shape == (2, 1) + + +def test_DrugcellAdamrMutSmis(smi_tkz): + ds = datasets._DrugcellAdamrMutSmis(smi_tkz) + lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] + lvals = [[0.88, 0.89], [0.82, 0.9]] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, lsmis, lvals[:1]) + mut_x, smi_src, smi_tgt, out = ds(muts, lsmis, lvals) + assert smi_src.shape == (2, 2, 200) + assert smi_tgt.shape == (2, 2, 200) + assert mut_x.shape == (2, 52) + assert out.shape == (2, 2) + + +def test_MutSmiReg(smi_tkz): + ds = datasets.MutSmiReg(smi_tkz) + smis = ["CC[N+]CCBr", "Cc1ccc1"] + vals = [0.88, 0.89] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, smis, vals[:1]) + mut_x, smi_src, smi_tgt, out = ds(muts, smis, vals) + assert smi_src.shape == (2, 200) + assert smi_tgt.shape == (2, 200) + assert mut_x.shape == (2, 52) assert out.shape == (2, 1) -def test_MutSmiFullConnection(): - tokenizer = tkz.MoltxTokenizer.from_pretrain( - conf=AdaMRTokenizerConfig.Prediction - ) - ds = datasets.MutSmiFullConnection(tokenizer) - smiles = ["CC[N+]CCBr", "Cc1ccc1"] - values = [0.88, 0.89] - mutations = [[random.choice([0, 1]) for _ in range(3008)], - [random.choice([0, 1]) for _ in range(3008)]] - with pytest.raises(RuntimeError): - ds(smiles, mutations, values[:1]) - smiles_src, smiles_tgt, mutations_src, out = ds(smiles, - mutations, - values) - assert smiles_src.shape == (2, 200) - 
assert smiles_tgt.shape == (2, 200) - assert mutations_src.shape == (2, 3008) +def test_MutSmisPairwise(smi_tkz): + ds = datasets.MutSmisPairwise(smi_tkz) + lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] + lvals = [[0.88, 0.89], [0.82, 0.9]] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, lsmis, lvals[:1]) + mut_x, smi_src, smi_tgt, out = ds(muts, lsmis, lvals) + assert smi_src.shape == (2, 2, 200) + assert smi_tgt.shape == (2, 2, 200) + assert mut_x.shape == (2, 52) assert out.shape == (2, 1) diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index db9103b..0000000 --- a/tests/test_models.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -from dooc import models -import random - - -def test_MutSmiXAttention(): - smiles_src = torch.randint(0, 64, [2, 200]) - smiles_tgt = torch.randint(0, 64, [2, 200]) - mutations = [[random.choice([0, 1]) for _ in range(3008)], - [random.choice([0, 1]) for _ in range(3008)]] - mutations_src = torch.tensor(mutations, dtype=torch.float) - model = models.MutSmiXAttention() - out = model(smiles_src, smiles_tgt, mutations_src) - assert out.shape == (2, 1) - - -def test_MutSmiFullConnection(): - smiles_src = torch.randint(0, 64, [2, 200]) - smiles_tgt = torch.randint(0, 64, [2, 200]) - mutations = [[random.choice([0, 1]) for _ in range(3008)], - [random.choice([0, 1]) for _ in range(3008)]] - mutations_src = torch.tensor(mutations, dtype=torch.float) - model = models.MutSmiFullConnection() - out = model(smiles_src, smiles_tgt, mutations_src) - assert out.shape == (2, 1) diff --git a/tests/test_nets.py b/tests/test_nets.py index c603407..edbd683 100644 --- a/tests/test_nets.py +++ b/tests/test_nets.py @@ -1,6 +1,60 @@ -from dooc import nets +import pytest import random +from moltx import models as mmodel +from dooc.nets import mutations +from dooc import nets, datasets -def test_GNN(): - return +@pytest.fixture +def adamr_conf(): + return mmodel.AdaMR.CONFIG_BASE + +@pytest.fixture +def drugcell_conf(): + return mutations.Drugcell.DEFAULT_CONFIG + +@pytest.fixture +def drugcell_adamr_mut_smi_ds(smi_tkz): + smis = ["CC[N+]CCBr", "Cc1ccc1"] + vals = [0.88, 0.89] + muts = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + ds = datasets._DrugcellAdamrMutSmi(smi_tkz) + return ds(muts, smis, vals) + +@pytest.fixture +def drugcell_adamr_mut_smis_ds(smi_tkz): + lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] + lvals = [[0.88, 0.89], [0.82, 0.9]] + muts = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + ds = datasets._DrugcellAdamrMutSmis(smi_tkz) + return ds(muts, lsmis, lvals) + + +def test_DrugcellAdamrMutSmi(adamr_conf, drugcell_conf, drugcell_adamr_mut_smi_ds): + label = drugcell_adamr_mut_smi_ds[-1] + + model = nets.DrugcellAdamrMutSmiAdd(drugcell_conf, adamr_conf) + out = model(*drugcell_adamr_mut_smi_ds[:-1]) + assert out.dim() == 2 + assert out.size(0) == label.size(0) + + model = nets.DrugcellAdamrMutSmiXattn(drugcell_conf, adamr_conf) + out = model(*drugcell_adamr_mut_smi_ds[:-1]) + assert out.dim() == 2 + assert out.size(0) == label.size(0) + + +def test_DrugcellAdamrMutSmis(adamr_conf, drugcell_conf, drugcell_adamr_mut_smis_ds): + label = drugcell_adamr_mut_smis_ds[-1] + + model = nets.DrugcellAdamrMutSmisAdd(drugcell_conf, adamr_conf) + out = model(*drugcell_adamr_mut_smis_ds[:-1]) + assert out.dim() == 
3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) + + model = nets.DrugcellAdamrMutSmisXattn(drugcell_conf, adamr_conf) + out = model(*drugcell_adamr_mut_smis_ds[:-1]) + assert out.dim() == 3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index c95efa9..58f6850 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1,32 +1,51 @@ import random from dooc import pipelines, models -from moltx import tokenizers as tkz -from moltx.models import AdaMRTokenizerConfig - - -def test_MutSmiXAttention(): - tokenizer = tkz.MoltxTokenizer.from_pretrain( - conf=AdaMRTokenizerConfig.Prediction - ) - model = models.MutSmiXAttention() - # model.load_ckpt('/path/to/mutsmixattention.ckpt') - pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, - model=model) + + +def test_DrugcellAdamrMutSmiReg(smi_tkz): + mutation = [random.choice([1, 0]) for _ in range(3008)] smiles = "CC[N+](C)(C)Cc1ccccc1Br" - predict = pipeline(mutation, smiles) - assert isinstance(predict, float) - - -def test_MutSmiFullConnection(): - tokenizer = tkz.MoltxTokenizer.from_pretrain( - conf=AdaMRTokenizerConfig.Prediction - ) - model = models.MutSmiFullConnection() - # model.load_ckpt('/path/to/mutsmifullconnection.ckpt') - pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, - model=model) + + class Reg(pipelines._DrugcellAdamrMutSmi, pipelines._MutSmiReg): + pass + + model = models.MutSmiReg() + pipeline = Reg(smi_tokenizer=smi_tkz, + model=model) + out = pipeline(mutation, smiles) + assert isinstance(out, float) + + +def test_DrugcellAdamr2MutSmiReg(smi_tkz): + pass + + +def test_DrugcellAdamrMutSmisRank(smi_tkz): + mutation = [random.choice([1, 0]) for _ in range(3008)] - smiles = "CC[N+](C)(C)Cc1ccccc1Br" - predict = pipeline(mutation, smiles) - assert isinstance(predict, float) + smiles = ["CC[N+](C)(C)Cc1ccccc1Br", "CC[N+](C)(C)Cc1ccccc1Br", "c1cccc1c"] + + class Pointwise(pipelines._DrugcellAdamrMutSmi, pipelines._MutSmisRank): + pass + + model = models.MutSmiReg() + pipeline = Pointwise(smi_tokenizer=smi_tkz, model=model) + out = pipeline(mutation, smiles) + assert isinstance(out, list) + assert len(out) == 3 + assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" + + class PairListRank(pipelines._DrugcellAdamrMutSmis, pipelines._MutSmisRank): + pass + + model = models.MutSmisPairwise() + pipeline = PairListRank(smi_tokenizer=smi_tkz, model=model) + out = pipeline(mutation, smiles) + assert isinstance(out, list) + assert len(out) == 3 + assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" + + +def test_DrugcellAdamr2MutSmisRank(smi_tkz): + pass
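
For reference, a minimal end-to-end sketch of the refactored API (illustrative only: the model below is untrained and the values are placeholders; every class used is introduced by this patch):

    import random
    from moltx import tokenizers as tkz
    from moltx.models import AdaMRTokenizerConfig
    from dooc import datasets, models, pipelines

    smi_tkz = tkz.MoltxTokenizer.from_pretrain(conf=AdaMRTokenizerConfig.Prediction)
    mut = [random.choice([0, 1]) for _ in range(3008)]  # length == Drugcell gene_dim

    # Datasets turn raw mutations/SMILES/values into batched tensors.
    ds = datasets.MutSmiReg(smi_tkz)
    mut_x, smi_src, smi_tgt, out = ds([mut], ["CC[N+]CCBr"], [0.88])

    # Models consume those tensors; MutSmiReg yields a [b, 1] regression output.
    model = models.MutSmiReg()
    pred = model(mut_x, smi_src, smi_tgt)

    # Pipelines wrap a (trained) model for single-sample inference.
    pipeline = pipelines.MutSmiReg(smi_tokenizer=smi_tkz, model=model)
    score = pipeline(mut, "CC[N+]CCBr")  # float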