diff --git a/dooc/datasets.py b/dooc/datasets.py index 59517da..941ae25 100644 --- a/dooc/datasets.py +++ b/dooc/datasets.py @@ -92,16 +92,53 @@ def __call__( return mut_x, smi_src, smi_tgt, out -class _DrugcellAdamr2Base(_SmiMutBase): - pass +class _DrugcellAdamr2Base(_DrugcellAdamrBase): + def _smi_tokens( + self, + smiles: typing.Sequence[str], + seq_len: int = 200, + ) -> torch.Tensor: + return self._smi_tokenize( + [f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len) class _DrugcellAdamr2MutSmi(_DrugcellAdamr2Base): - pass + def __call__( + self, + muts: typing.Sequence[list], + smis: typing.Sequence[str], + vals: typing.Sequence[float], + seq_len: int = 200 + ) -> typing.Tuple[torch.Tensor]: + assert len(smis) == len(vals) and len(muts) == len(vals) + mut_x = self._mut_tokens(muts) + smi_tgt = self._smi_tokens(smis, seq_len) + out = self._out(vals).unsqueeze(-1) + return mut_x, smi_tgt, out -class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base): - pass +class _DrugcellAdamr2MutSmis(_DrugcellAdamr2Base): + def __call__( + self, + muts: typing.Sequence[list], + lsmis: typing.Sequence[typing.Sequence[str]], + lvals: typing.Sequence[typing.Sequence[float]], + seq_len: int = 200 + ) -> typing.Tuple[torch.Tensor]: + """ + muts: [mut1, mut2, ...] mut1: [gene1, gene2, ...] + bsmiles: [[smi11, smi12], [smi21, smi22], ...] + bvlaues: [[val11, val12], [val21, val22], ...] + """ + assert len(lsmis) == len(lvals) and len(muts) == len(lvals) + mut_x = self._mut_tokens(muts) + batchlen = len(lsmis) + listlen = len(lsmis[0]) + smiles = [smi for bsmi in lsmis for smi in bsmi] + smi_tgt = self._smi_tokens(smiles, seq_len) + smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1)) + out = self._out(lvals) + return mut_x, smi_tgt, out """ @@ -113,11 +150,11 @@ class _DrugcellAdamr2MutSmisPairwiseRank(_DrugcellAdamr2Base): """ -class MutSmiReg(_DrugcellAdamrMutSmi): +class MutSmiReg(_DrugcellAdamr2MutSmi): pass -class MutSmisPairwise(_DrugcellAdamrMutSmis): +class MutSmisPairwise(_DrugcellAdamr2MutSmis): def __call__( self, muts: typing.Sequence[list], @@ -125,7 +162,7 @@ def __call__( lvalues: typing.Sequence[typing.Sequence[float]], seq_len: int = 200 ) -> typing.Tuple[torch.Tensor]: - mut_x, smi_src, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len) + mut_x, smi_tgt, rout = super().__call__(muts, lsmiles, lvalues, seq_len) out = torch.zeros(rout.size(0), dtype=torch.long, device=self.device) out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1 - return mut_x, smi_src, smi_tgt, out + return mut_x, smi_tgt, out diff --git a/dooc/models.py b/dooc/models.py index 87dd599..c5559ba 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -1,7 +1,8 @@ import torch +from moltx import nets as mnets from moltx import models as mmodels from dooc import nets as dnets -from dooc.nets import heads +from dooc.nets import heads, drugcell """ @@ -13,20 +14,20 @@ """ -class MutSmiReg(dnets.DrugcellAdamrMutSmiXattn): +class MutSmiReg(dnets.DrugcellAdamr2MutSmiXattn): - def __init__(self) -> None: - super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE) + def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None: + super().__init__(mut_conf, smi_conf) self.reg = heads.RegHead(self.smi_conf.d_model) def forward(self, *args, **kwargs) -> torch.Tensor: return self.reg(super().forward(*args, **kwargs)) # [b, 1] -class MutSmisPairwise(dnets.DrugcellAdamrMutSmisXattn): +class MutSmisPairwise(dnets.DrugcellAdamr2MutSmisXattn): - def __init__(self) -> None: - super().__init__(mut_conf=dnets.Drugcell.DEFAULT_CONFIG, smi_conf=mmodels.AdaMR.CONFIG_BASE) + def __init__(self, mut_conf: drugcell.DrugcellConfig = dnets.Drugcell.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None: + super().__init__(mut_conf, smi_conf) self.pairwise_rank = heads.PairwiseRankHead(self.smi_conf.d_model) def forward(self, *args, **kwargs) -> torch.Tensor: diff --git a/dooc/nets/__init__.py b/dooc/nets/__init__.py index 66f5be8..13bb44a 100644 --- a/dooc/nets/__init__.py +++ b/dooc/nets/__init__.py @@ -1,6 +1,6 @@ import torch from torch import nn -from moltx.models import AdaMR +from moltx.models import AdaMR, AdaMR2 from dooc.nets.drugcell import Drugcell @@ -115,17 +115,97 @@ def forward( return self.cross_attn(smi_out, mut_out) # [b, n, dmodel] -class DrugcellAdamr2MutSmiAdd(nn.Module): - pass +class _DrugcellAdamr2(nn.Module): + def __init__(self, mut_conf, smi_conf) -> None: + super().__init__() + self.mut_conf = mut_conf + self.smi_conf = smi_conf + + self.mut_encoder = Drugcell(mut_conf) + self.smi_encoder = AdaMR2(smi_conf) + + def load_ckpt(self, *ckpt_files: str) -> None: + self.load_state_dict( + torch.load(ckpt_files[0], map_location=torch.device("cpu")) + ) + + def load_pretrained_ckpt(self, mut_ckpt: str, smi_ckpt: str, freeze_mut: bool = False, freeze_smi: bool = False) -> None: + self.mut_encoder.load_ckpt(mut_ckpt) + self.smi_encoder.load_ckpt(smi_ckpt) + if freeze_smi: + self.smi_encoder.requires_grad_(False) + if freeze_mut: + self.mut_encoder.requires_grad_(False) -class DrugcellAdamr2MutSmiXattn(nn.Module): - pass +class DrugcellAdamr2MutSmiAdd(_DrugcellAdamr2): + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + smi_tgt: [b, smi_seqlen] + """ + mut_out = self.mut_encoder(mut_x) + smi_out = self.smi_encoder.forward_feature(smi_tgt) + return mut_out + smi_out # [b, dmodel] -class DrugcellAdamr2MutSmisAdd(nn.Module): - pass + +class DrugcellAdamr2MutSmiXattn(_DrugcellAdamr2): + def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(mut_conf, smi_conf) + d_model = self.smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + smi_tgt: [b, smi_seqlen] + """ + mut_out = self.mut_encoder(mut_x).unsqueeze(-2) # [b, 1, dmodel] + smi_out = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2) # [b, 1, dmodel] + return self.cross_attn(smi_out, mut_out).squeeze(-2) # [b, dmodel] + + +class DrugcellAdamr2MutSmisAdd(_DrugcellAdamr2): + def _forward_mut(self, mut_x: torch.Tensor) -> torch.Tensor: + """ + mut_x: [b, mut_seqlen] + out: [b, 1, dmodel] + """ + return self.mut_encoder(mut_x).unsqueeze(-2) + + def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor: + """ + smi_tgt: [b, n, smi_seqlen] + out: [b, n, dmodel] + """ + batched = smi_tgt.dim() == 3 + if batched: + n = smi_tgt.shape[1] + smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1]) + out = self.smi_encoder.forward_feature(smi_tgt) + return out.reshape(-1, n, out.shape[-1]) + return self.smi_encoder.forward_feature(smi_tgt) + + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + mut_out = self._forward_mut(mut_x) + smi_out = self._forward_smi(smi_tgt) + return smi_out + mut_out # [b, n, dmodel] class DrugcellAdamr2MutSmisXattn(DrugcellAdamr2MutSmisAdd): - pass + def __init__(self, mut_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None: + super().__init__(mut_conf, smi_conf) + d_model = smi_conf.d_model + layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True) + self.cross_attn = nn.TransformerDecoder(layer, num_layers) + + def forward( + self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor: + mut_out = self._forward_mut(mut_x) + smi_out = self._forward_smi(smi_tgt) + return self.cross_attn(smi_out, mut_out) # [b, n, dmodel] diff --git a/dooc/nets/drugcell.py b/dooc/nets/drugcell.py index 44b02d2..d0c8ee4 100644 --- a/dooc/nets/drugcell.py +++ b/dooc/nets/drugcell.py @@ -189,7 +189,7 @@ def _cal_term_dim(self): term_size_map is the number of all genes annotated with the term """ self.term_dim_map = {} - for term, term_size in self.term_size_map.items(): + for term in self.term_size_map.keys(): num_output = self.conf.num_hiddens_genotype # log the number of hidden variables per each term @@ -220,14 +220,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # define forward function for genotype dcell ############################################# term_gene_out_map = {} - for term, _ in self.term_direct_gene_map.items(): + for term in self.term_direct_gene_map.keys(): term_gene_out_map[term] = self._modules[term + "_direct_gene_layer"]( gene_input ) term_nn_out_map = {} - for _, layer in enumerate(self.term_layer_list): + for layer in self.term_layer_list: for term in layer: diff --git a/dooc/pipelines.py b/dooc/pipelines.py index 59ca2bc..5d533aa 100644 --- a/dooc/pipelines.py +++ b/dooc/pipelines.py @@ -25,15 +25,12 @@ def _tokens2tensor(self, tokens: typing.Sequence[int], size: int = None) -> torc return out.to(self.device) -class _DrugcellAdamrMutSmi(_MutSmiBase): +class _MutSmi(_MutSmiBase): - def _model_args( - self, mut: typing.Sequence[int], smi: str - ) -> typing.Tuple[torch.Tensor]: + def _model_args(self, mut: typing.Sequence[int], smi: str) -> typing.Tuple[torch.Tensor]: mut_x = torch.tensor(mut, device=self.device) - smi_src = self._tokens2tensor(self.smi_tokenizer(smi)) smi_tgt = self._tokens2tensor(self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS)) - return mut_x, smi_src, smi_tgt + return mut_x, smi_tgt def reg(self, mut: typing.Sequence[int], smi: str) -> float: return self.model(*self._model_args(mut, smi)).item() @@ -53,18 +50,15 @@ def cmp(smi1, smi2): return cmp -class _DrugcellAdamrMutSmis(_MutSmiBase): +class _MutSmis(_MutSmiBase): def _smi_args( self, smis: typing.Sequence[str] - ) -> typing.Tuple[torch.Tensor]: - smi_src = [self.smi_tokenizer(smi) for smi in smis] + ) -> torch.Tensor: smi_tgt = [self.smi_tokenizer(self.smi_tokenizer.BOS + smi + self.smi_tokenizer.EOS) for smi in smis] - size_src = max(map(len, smi_src)) size_tgt = max(map(len, smi_tgt)) - smi_src = torch.concat([self._tokens2tensor(smi, size_src).unsqueeze(0) for smi in smi_src]) smi_tgt = torch.concat([self._tokens2tensor(smi, size_tgt).unsqueeze(0) for smi in smi_tgt]) - return smi_src, smi_tgt + return smi_tgt def cmp_smis_func(self, mut: typing.Sequence[int]) -> typing.Callable: mut_x = torch.tensor(mut, device=self.device) @@ -75,21 +69,13 @@ def cmp(smi1, smi2): query = '-'.join(smis) if query in cmped: return cmped[query] - smi_src, smi_tgt = self._smi_args(smis) - out = self.model.forward_cmp(mut_x, smi_src, smi_tgt) + smi_tgt = self._smi_args(smis) + out = self.model.forward_cmp(mut_x, smi_tgt) cmped[query] = out return out return cmp -class _DrugcellAdamr2MutSmi(_MutSmiBase): - pass - - -class _DrugcellAdamr2MutSmis(_MutSmiBase): - pass - - class _MutSmiReg: def __call__(self, mut: typing.Sequence[int], smi: str) -> typing.Dict: @@ -111,9 +97,9 @@ def __call__(self, mut: typing.Sequence[int], smis: typing.Sequence[str]) -> typ """ -class MutSmiReg(_DrugcellAdamrMutSmi, _MutSmiReg): +class MutSmiReg(_MutSmi, _MutSmiReg): pass -class MutSmisRank(_DrugcellAdamrMutSmis, _MutSmisRank): +class MutSmisRank(_MutSmis, _MutSmisRank): pass diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f24088f..cb85f6a 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -18,6 +18,20 @@ def test_DrugcellAdamrMutSmi(smi_tkz): assert out.shape == (2, 1) +def test_DrugcellAdamr2MutSmi(smi_tkz): + ds = datasets._DrugcellAdamr2MutSmi(smi_tkz) + smis = ["CC[N+]CCBr", "Cc1ccc1"] + vals = [0.88, 0.89] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, smis, vals[:1]) + mut_x, smi_tgt, out = ds(muts, smis, vals) + assert smi_tgt.shape == (2, 200) + assert mut_x.shape == (2, 52) + assert out.shape == (2, 1) + + def test_DrugcellAdamrMutSmis(smi_tkz): ds = datasets._DrugcellAdamrMutSmis(smi_tkz) lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] @@ -33,6 +47,20 @@ def test_DrugcellAdamrMutSmis(smi_tkz): assert out.shape == (2, 2) +def test_DrugcellAdamr2MutSmis(smi_tkz): + ds = datasets._DrugcellAdamr2MutSmis(smi_tkz) + lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] + lvals = [[0.88, 0.89], [0.82, 0.9]] + muts = [[random.choice([0, 1]) for _ in range(52)], + [random.choice([0, 1]) for _ in range(52)]] + with pytest.raises(AssertionError): + ds(muts, lsmis, lvals[:1]) + mut_x, smi_tgt, out = ds(muts, lsmis, lvals) + assert smi_tgt.shape == (2, 2, 200) + assert mut_x.shape == (2, 52) + assert out.shape == (2, 2) + + def test_MutSmiReg(smi_tkz): ds = datasets.MutSmiReg(smi_tkz) smis = ["CC[N+]CCBr", "Cc1ccc1"] @@ -41,8 +69,7 @@ def test_MutSmiReg(smi_tkz): [random.choice([0, 1]) for _ in range(52)]] with pytest.raises(AssertionError): ds(muts, smis, vals[:1]) - mut_x, smi_src, smi_tgt, out = ds(muts, smis, vals) - assert smi_src.shape == (2, 200) + mut_x, smi_tgt, out = ds(muts, smis, vals) assert smi_tgt.shape == (2, 200) assert mut_x.shape == (2, 52) assert out.shape == (2, 1) @@ -56,8 +83,7 @@ def test_MutSmisPairwise(smi_tkz): [random.choice([0, 1]) for _ in range(52)]] with pytest.raises(AssertionError): ds(muts, lsmis, lvals[:1]) - mut_x, smi_src, smi_tgt, out = ds(muts, lsmis, lvals) - assert smi_src.shape == (2, 2, 200) + mut_x, smi_tgt, out = ds(muts, lsmis, lvals) assert smi_tgt.shape == (2, 2, 200) assert mut_x.shape == (2, 52) assert out.shape == (2,) diff --git a/tests/test_nets.py b/tests/test_nets.py index 2de3437..c523bec 100644 --- a/tests/test_nets.py +++ b/tests/test_nets.py @@ -8,6 +8,10 @@ def adamr_conf(): return mmodel.AdaMR.CONFIG_BASE +@pytest.fixture +def adamr2_conf(): + return mmodel.AdaMR2.CONFIG_LARGE + @pytest.fixture def drugcell_conf(): return nets.Drugcell.DEFAULT_CONFIG @@ -21,6 +25,15 @@ def drugcell_adamr_mut_smi_ds(smi_tkz): ds = datasets._DrugcellAdamrMutSmi(smi_tkz) return ds(muts, smis, vals) +@pytest.fixture +def drugcell_adamr2_mut_smi_ds(smi_tkz): + smis = ["CC[N+]CCBr", "Cc1ccc1"] + vals = [0.88, 0.89] + muts = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + ds = datasets._DrugcellAdamr2MutSmi(smi_tkz) + return ds(muts, smis, vals) + @pytest.fixture def drugcell_adamr_mut_smis_ds(smi_tkz): lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] @@ -30,6 +43,14 @@ def drugcell_adamr_mut_smis_ds(smi_tkz): ds = datasets._DrugcellAdamrMutSmis(smi_tkz) return ds(muts, lsmis, lvals) +@pytest.fixture +def drugcell_adamr2_mut_smis_ds(smi_tkz): + lsmis = [["CC[N+]CCBr", "Cc1ccc1"], ["CCC[N+]CCBr", "CCc1ccc1"]] + lvals = [[0.88, 0.89], [0.82, 0.9]] + muts = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + ds = datasets._DrugcellAdamr2MutSmis(smi_tkz) + return ds(muts, lsmis, lvals) def test_DrugcellAdamrMutSmi(adamr_conf, drugcell_conf, drugcell_adamr_mut_smi_ds): label = drugcell_adamr_mut_smi_ds[-1] @@ -45,6 +66,20 @@ def test_DrugcellAdamrMutSmi(adamr_conf, drugcell_conf, drugcell_adamr_mut_smi_d assert out.size(0) == label.size(0) +def test_DrugcellAdamr2MutSmi(adamr2_conf, drugcell_conf, drugcell_adamr2_mut_smi_ds): + label = drugcell_adamr2_mut_smi_ds[-1] + + model = nets.DrugcellAdamr2MutSmiAdd(drugcell_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smi_ds[:-1]) + assert out.dim() == 2 + assert out.size(0) == label.size(0) + + model = nets.DrugcellAdamr2MutSmiXattn(drugcell_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smi_ds[:-1]) + assert out.dim() == 2 + assert out.size(0) == label.size(0) + + def test_DrugcellAdamrMutSmis(adamr_conf, drugcell_conf, drugcell_adamr_mut_smis_ds): label = drugcell_adamr_mut_smis_ds[-1] @@ -57,3 +92,17 @@ def test_DrugcellAdamrMutSmis(adamr_conf, drugcell_conf, drugcell_adamr_mut_smis out = model(*drugcell_adamr_mut_smis_ds[:-1]) assert out.dim() == 3 assert out.size(0) == label.size(0) and out.size(1) == label.size(1) + + +def test_DrugcellAdamr2MutSmis(adamr2_conf, drugcell_conf, drugcell_adamr2_mut_smis_ds): + label = drugcell_adamr2_mut_smis_ds[-1] + + model = nets.DrugcellAdamr2MutSmisAdd(drugcell_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smis_ds[:-1]) + assert out.dim() == 3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) + + model = nets.DrugcellAdamr2MutSmisXattn(drugcell_conf, adamr2_conf) + out = model(*drugcell_adamr2_mut_smis_ds[:-1]) + assert out.dim() == 3 + assert out.size(0) == label.size(0) and out.size(1) == label.size(1) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 58f6850..538df15 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -2,31 +2,22 @@ from dooc import pipelines, models -def test_DrugcellAdamrMutSmiReg(smi_tkz): - +def test_MutSmiReg(smi_tkz): mutation = [random.choice([1, 0]) for _ in range(3008)] smiles = "CC[N+](C)(C)Cc1ccccc1Br" - class Reg(pipelines._DrugcellAdamrMutSmi, pipelines._MutSmiReg): - pass - model = models.MutSmiReg() - pipeline = Reg(smi_tokenizer=smi_tkz, + pipeline = pipelines.MutSmiReg(smi_tokenizer=smi_tkz, model=model) out = pipeline(mutation, smiles) assert isinstance(out, float) -def test_DrugcellAdamr2MutSmiReg(smi_tkz): - pass - - -def test_DrugcellAdamrMutSmisRank(smi_tkz): - +def test_MutSmisRank(smi_tkz): mutation = [random.choice([1, 0]) for _ in range(3008)] smiles = ["CC[N+](C)(C)Cc1ccccc1Br", "CC[N+](C)(C)Cc1ccccc1Br", "c1cccc1c"] - class Pointwise(pipelines._DrugcellAdamrMutSmi, pipelines._MutSmisRank): + class Pointwise(pipelines._MutSmi, pipelines._MutSmisRank): pass model = models.MutSmiReg() @@ -36,16 +27,9 @@ class Pointwise(pipelines._DrugcellAdamrMutSmi, pipelines._MutSmisRank): assert len(out) == 3 assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" - class PairListRank(pipelines._DrugcellAdamrMutSmis, pipelines._MutSmisRank): - pass - model = models.MutSmisPairwise() - pipeline = PairListRank(smi_tokenizer=smi_tkz, model=model) + pipeline = pipelines.MutSmisRank(smi_tokenizer=smi_tkz, model=model) out = pipeline(mutation, smiles) assert isinstance(out, list) assert len(out) == 3 assert out[1] == "CC[N+](C)(C)Cc1ccccc1Br" - - -def test_DrugcellAdamr2MutSmisRank(smi_tkz): - pass