From e165fd9a1b5eea1fffd2d831e5bc81dcf3d0cf39 Mon Sep 17 00:00:00 2001 From: zzzseeu Date: Wed, 22 May 2024 03:41:10 +0000 Subject: [PATCH 1/4] feat: MutSmiXAttention models-pipelines-datasets-test --- dooc/models.py | 28 +++++++++++++++++++++++++++- dooc/pipelines.py | 4 +++- tests/test_datasets.py | 36 ++++++++++++++++++------------------ tests/test_models.py | 33 +++++++++++++-------------------- tests/test_pipelines.py | 41 +++++++++++++++++------------------------ 5 files changed, 78 insertions(+), 64 deletions(-) diff --git a/dooc/models.py b/dooc/models.py index 9f459fd..673ba05 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -50,16 +50,42 @@ class MutSmiXAttention(MutSmi): def __init__( self, d_model: int, + nhead: int = 2, + num_layers: int = 2, gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) self.d_model = d_model + d_hidden = d_model // 2 + self.nhead = nhead + self.num_layers = num_layers + self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model) + decoder_layer = nn.TransformerDecoderLayer(self.d_model, + self.nhead) + self.decoder = nn.TransformerDecoder(decoder_layer, + num_layers=self.num_layers) + self.reg = nn.Sequential( + nn.Linear(self.d_model, d_hidden), + nn.ReLU(), + nn.Dropout(0.1), + nn.Linear(d_hidden, 1), + ) def forward( self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor ) -> torch.Tensor: - pass + is_batched = smiles_src.dim() == 2 + + smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) + gene_out = self.gene_fc(self.gene_encoder(gene_src)) + feat = None + if is_batched: + feat = self.decoder(smiles_out, gene_out) + else: + feat = self.decoder(smiles_out.unsqueeze(0), gene_out) + + return self.reg(feat) class MutSmiFullConnection(MutSmi): diff --git a/dooc/pipelines.py b/dooc/pipelines.py index 448f477..23b1e4a 100644 --- a/dooc/pipelines.py +++ b/dooc/pipelines.py @@ -49,7 +49,9 @@ def __init__( super().__init__(smi_tokenizer, model, device) def __call__(self, gene: typing.Sequence[int], smiles: str) -> float: - return + smi_src, smi_tgt, gene_src = self._model_args(gene, smiles) + pred = self.model(smi_src, smi_tgt, gene_src) + return pred.item() class MutSmiFullConnection(MutSmi): diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 70e741e..be3a539 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -9,24 +9,24 @@ def test_MutSmi(): return -# def test_MutSmiXAttention(): -# tokenizer = tkz.MoltxTokenizer.from_pretrain( -# conf=AdaMRTokenizerConfig.Prediction -# ) -# ds = datasets.MutSmiXAttention(tokenizer) -# smiles = ["CC[N+]CCBr", "Cc1ccc1"] -# values = [0.88, 0.89] -# mutations = [[random.choice([0, 1]) for _ in range(3008)], -# [random.choice([0, 1]) for _ in range(3008)]] -# with pytest.raises(RuntimeError): -# ds(smiles, mutations, values[:1]) -# smiles_src, smiles_tgt, mutations_src, out = ds(smiles, -# mutations, -# values) -# assert smiles_src.shape == (2, 4) -# assert smiles_tgt.shape == (2, 6) -# assert mutations_src.shape == (2, 3008) -# assert out.shape == (2, 1) +def test_MutSmiXAttention(): + tokenizer = tkz.MoltxTokenizer.from_pretrain( + conf=AdaMRTokenizerConfig.Prediction + ) + ds = datasets.MutSmiXAttention(tokenizer) + smiles = ["CC[N+]CCBr", "Cc1ccc1"] + values = [0.88, 0.89] + mutations = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + with 
pytest.raises(RuntimeError): + ds(smiles, mutations, values[:1]) + smiles_src, smiles_tgt, mutations_src, out = ds(smiles, + mutations, + values) + assert smiles_src.shape == (2, 200) + assert smiles_tgt.shape == (2, 200) + assert mutations_src.shape == (2, 3008) + assert out.shape == (2, 1) def test_MutSmiFullConnection(): diff --git a/tests/test_models.py b/tests/test_models.py index 227d344..5c8ede1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,34 +1,27 @@ import torch -from dooc import models, nets +from dooc import models import random -# def test_MutSmiXAttention(): -# smiles_src = torch.randint(0, 64, [2, 4]) -# smiles_tgt = torch.randint(0, 64, [2, 6]) -# mutations_src = torch.randn(2, 3008, dtype=torch.float32) -# d_model = 768 -# model = models.MutSmiXAttention(d_model) -# out = model(smiles_src, smiles_tgt, mutations_src) -# assert out.shape == (1,) +def test_MutSmiXAttention(): + smiles_src = torch.randint(0, 64, [2, 200]) + smiles_tgt = torch.randint(0, 64, [2, 200]) + mutations = [[random.choice([0, 1]) for _ in range(3008)], + [random.choice([0, 1]) for _ in range(3008)]] + mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu") + d_model = 768 + model = models.MutSmiXAttention(d_model) + out = model(smiles_src, smiles_tgt, mutations_src) + assert out.shape == (2, 1) -def test_MutSmiFullConnection(datadir): +def test_MutSmiFullConnection(): smiles_src = torch.randint(0, 64, [2, 200]) smiles_tgt = torch.randint(0, 64, [2, 200]) mutations = [[random.choice([0, 1]) for _ in range(3008)], [random.choice([0, 1]) for _ in range(3008)]] mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu") d_model = 768 - gene_conf = nets.GeneGNNConfig( - gene_dim=3008, - drug_dim=2048, - num_hiddens_genotype=6, - num_hiddens_drug=[100, 50, 6], - num_hiddens_final=6, - gene2ind_path=f"{datadir}/gene2ind.txt", - ont_path=f"{datadir}/drugcell_ont.txt", - ) - model = models.MutSmiFullConnection(d_model, gene_conf=gene_conf) + model = models.MutSmiFullConnection(d_model) out = model(smiles_src, smiles_tgt, mutations_src) assert out.shape == (2, 1) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 570094a..e7fa354 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -1,38 +1,31 @@ import random -from dooc import pipelines, models, nets +from dooc import pipelines, models from moltx import tokenizers as tkz from moltx.models import AdaMRTokenizerConfig -# def test_MutSmiXAttention(): -# tokenizer = tkz.MoltxTokenizer.from_pretrain( -# conf=AdaMRTokenizerConfig.Prediction -# ) -# d_model = 768 -# model = models.MutSmiXAttention(d_model) -# model.load_ckpt('/path/to/mutsmixattention.ckpt') -# pipeline = pipelines.MutSmiXAttention(tokenizer, model) -# mutation = [random.choice([1, 0]) for _ in range(3008)] -# smiles = "CC[N+](C)(C)Cc1ccccc1Br" -# predict = pipeline(mutation, smiles) -# assert isinstance(predict, float) +def test_MutSmiXAttention(): + tokenizer = tkz.MoltxTokenizer.from_pretrain( + conf=AdaMRTokenizerConfig.Prediction + ) + d_model = 768 + model = models.MutSmiXAttention(d_model) + # model.load_ckpt('/path/to/mutsmixattention.ckpt') + pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, + model=model) + mutation = [[random.choice([1, 0]) for _ in range(3008)]] + smiles = "CC[N+](C)(C)Cc1ccccc1Br" + predict = pipeline(mutation, smiles) + print(predict) + assert isinstance(predict, float) -def test_MutSmiFullConnection(datadir): +def test_MutSmiFullConnection(): tokenizer = 
tkz.MoltxTokenizer.from_pretrain( conf=AdaMRTokenizerConfig.Prediction ) d_model = 768 - gene_conf = nets.GeneGNNConfig( - gene_dim=3008, - drug_dim=2048, - num_hiddens_genotype=6, - num_hiddens_drug=[100, 50, 6], - num_hiddens_final=6, - gene2ind_path=f"{datadir}/gene2ind.txt", - ont_path=f"{datadir}/drugcell_ont.txt", - ) - model = models.MutSmiFullConnection(d_model, gene_conf) + model = models.MutSmiFullConnection(d_model) # model.load_ckpt('/path/to/mutsmifullconnection.ckpt') pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model) From eac2322e458d9ad2ed22bc8dd4c766ff45c9430e Mon Sep 17 00:00:00 2001 From: seeuzhouxin <18986114551@163.com> Date: Wed, 22 May 2024 18:27:14 +0800 Subject: [PATCH 2/4] fix: del unused class vars & rename decoder & rewrite smiles/gene encoder --- dooc/models.py | 58 ++++++++++++++++++----------------------- dooc/nets.py | 6 +++-- dooc/pipelines.py | 2 ++ tests/test_models.py | 6 ++--- tests/test_pipelines.py | 11 +++----- 5 files changed, 38 insertions(+), 45 deletions(-) diff --git a/dooc/models.py b/dooc/models.py index 673ba05..37b5a3a 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -16,9 +16,21 @@ def __init__( super().__init__() self.gene_conf = gene_conf self.smiles_conf = smiles_conf - self.smiles_encoder = AdaMR(smiles_conf) + self.d_model = self.smiles_conf.d_model + self.init_smiles_encoder = AdaMR(smiles_conf) - self.gene_encoder = nets.GeneGNN(gene_conf) + self.init_gene_encoder = nets.GeneGNN(gene_conf) + self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model) + + def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor: + smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt) + return smiles_out + # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out + + def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor: + gene_out = self.init_gene_encoder(gene_src) + gene_out = self.gene_fc(gene_out) + return gene_out def load_ckpt(self, *ckpt_files: str) -> None: """load check point model. 
@@ -49,22 +61,17 @@ class MutSmiXAttention(MutSmi): def __init__( self, - d_model: int, nhead: int = 2, num_layers: int = 2, gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - self.d_model = d_model - d_hidden = d_model // 2 - self.nhead = nhead - self.num_layers = num_layers - self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model) + d_hidden = self.d_model // 2 decoder_layer = nn.TransformerDecoderLayer(self.d_model, - self.nhead) - self.decoder = nn.TransformerDecoder(decoder_layer, - num_layers=self.num_layers) + nhead) + self.smixmut_decoder = nn.TransformerDecoder(decoder_layer, + num_layers) self.reg = nn.Sequential( nn.Linear(self.d_model, d_hidden), nn.ReLU(), @@ -75,15 +82,11 @@ def __init__( def forward( self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor ) -> torch.Tensor: - is_batched = smiles_src.dim() == 2 - - smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) - gene_out = self.gene_fc(self.gene_encoder(gene_src)) + assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2 + smiles_out = self.smiles_encoder(smiles_src, smiles_tgt) + gene_out = self.gene_encoder(gene_src) feat = None - if is_batched: - feat = self.decoder(smiles_out, gene_out) - else: - feat = self.decoder(smiles_out.unsqueeze(0), gene_out) + feat = self.smixmut_decoder(smiles_out, gene_out) return self.reg(feat) @@ -93,15 +96,11 @@ class MutSmiFullConnection(MutSmi): def __init__( self, - d_model: int, gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG, smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - self.d_model = d_model - - d_hidden = d_model // 2 - self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model) + d_hidden = self.d_model // 2 self.reg = nn.Sequential( nn.Linear(self.d_model, d_hidden), nn.ReLU(), @@ -112,15 +111,10 @@ def __init__( def forward( self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor ) -> torch.Tensor: - is_batched = smiles_src.dim() == 2 - - smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) - gene_out = self.gene_encoder(gene_src).unsqueeze(1) + smiles_out = self.smiles_encoder(smiles_src, smiles_tgt) + gene_out = self.gene_encoder(gene_src) feat = None - if is_batched: - feat = smiles_out + self.gene_fc(gene_out)[:, 0] - else: - feat = smiles_out[0] + self.gene_fc(gene_out)[0] + feat = smiles_out + gene_out return self.reg(feat) diff --git a/dooc/nets.py b/dooc/nets.py index 9a0150e..3505ac4 100644 --- a/dooc/nets.py +++ b/dooc/nets.py @@ -181,7 +181,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ removed drug layer """ - + x_dim = x.dim() + x = x.unsqueeze(0) if x_dim == 1 else x gene_input = x.narrow(1, 0, self.conf.gene_dim) # drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim) @@ -223,4 +224,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: aux_layer1_out ) - return term_nn_out_map[self.dg_root] + out = term_nn_out_map[self.dg_root] + return out diff --git a/dooc/pipelines.py b/dooc/pipelines.py index 23b1e4a..abe7a39 100644 --- a/dooc/pipelines.py +++ b/dooc/pipelines.py @@ -50,6 +50,8 @@ def __init__( def __call__(self, gene: typing.Sequence[int], smiles: str) -> float: smi_src, smi_tgt, gene_src = self._model_args(gene, smiles) + smi_src = smi_src.unsqueeze(0) + smi_tgt = smi_tgt.unsqueeze(0) pred = 
self.model(smi_src, smi_tgt, gene_src) return pred.item() diff --git a/tests/test_models.py b/tests/test_models.py index 5c8ede1..46fe205 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -9,8 +9,7 @@ def test_MutSmiXAttention(): mutations = [[random.choice([0, 1]) for _ in range(3008)], [random.choice([0, 1]) for _ in range(3008)]] mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu") - d_model = 768 - model = models.MutSmiXAttention(d_model) + model = models.MutSmiXAttention() out = model(smiles_src, smiles_tgt, mutations_src) assert out.shape == (2, 1) @@ -21,7 +20,6 @@ def test_MutSmiFullConnection(): mutations = [[random.choice([0, 1]) for _ in range(3008)], [random.choice([0, 1]) for _ in range(3008)]] mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu") - d_model = 768 - model = models.MutSmiFullConnection(d_model) + model = models.MutSmiFullConnection() out = model(smiles_src, smiles_tgt, mutations_src) assert out.shape == (2, 1) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index e7fa354..c95efa9 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -8,15 +8,13 @@ def test_MutSmiXAttention(): tokenizer = tkz.MoltxTokenizer.from_pretrain( conf=AdaMRTokenizerConfig.Prediction ) - d_model = 768 - model = models.MutSmiXAttention(d_model) + model = models.MutSmiXAttention() # model.load_ckpt('/path/to/mutsmixattention.ckpt') pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, model=model) - mutation = [[random.choice([1, 0]) for _ in range(3008)]] + mutation = [random.choice([1, 0]) for _ in range(3008)] smiles = "CC[N+](C)(C)Cc1ccccc1Br" predict = pipeline(mutation, smiles) - print(predict) assert isinstance(predict, float) @@ -24,12 +22,11 @@ def test_MutSmiFullConnection(): tokenizer = tkz.MoltxTokenizer.from_pretrain( conf=AdaMRTokenizerConfig.Prediction ) - d_model = 768 - model = models.MutSmiFullConnection(d_model) + model = models.MutSmiFullConnection() # model.load_ckpt('/path/to/mutsmifullconnection.ckpt') pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model) - mutation = [[random.choice([1, 0]) for _ in range(3008)]] + mutation = [random.choice([1, 0]) for _ in range(3008)] smiles = "CC[N+](C)(C)Cc1ccccc1Br" predict = pipeline(mutation, smiles) assert isinstance(predict, float) From afe17db698bab4ac154b9fad38ef685a0cb21aae Mon Sep 17 00:00:00 2001 From: seeuzhouxin <18986114551@163.com> Date: Wed, 22 May 2024 22:11:22 +0800 Subject: [PATCH 3/4] fix: put gene fc layer in GNN & del self.d_model --- dooc/models.py | 30 +++++++++--------------------- dooc/nets.py | 5 +++++ requirements.txt | 2 +- 3 files changed, 15 insertions(+), 22 deletions(-) diff --git a/dooc/models.py b/dooc/models.py index 37b5a3a..1964f4f 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -16,21 +16,9 @@ def __init__( super().__init__() self.gene_conf = gene_conf self.smiles_conf = smiles_conf - self.d_model = self.smiles_conf.d_model - self.init_smiles_encoder = AdaMR(smiles_conf) + self.smiles_encoder = AdaMR(smiles_conf) - self.init_gene_encoder = nets.GeneGNN(gene_conf) - self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model) - - def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor: - smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt) - return smiles_out - # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out - - def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor: - 
gene_out = self.init_gene_encoder(gene_src) - gene_out = self.gene_fc(gene_out) - return gene_out + self.gene_encoder = nets.GeneGNN(gene_conf) def load_ckpt(self, *ckpt_files: str) -> None: """load check point model. @@ -67,13 +55,13 @@ def __init__( smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - d_hidden = self.d_model // 2 - decoder_layer = nn.TransformerDecoderLayer(self.d_model, + d_hidden = self.smiles_conf.d_model // 2 + decoder_layer = nn.TransformerDecoderLayer(self.smiles_conf.d_model, nhead) self.smixmut_decoder = nn.TransformerDecoder(decoder_layer, num_layers) self.reg = nn.Sequential( - nn.Linear(self.d_model, d_hidden), + nn.Linear(self.smiles_conf.d_model, d_hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(d_hidden, 1), @@ -83,7 +71,7 @@ def forward( self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor ) -> torch.Tensor: assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2 - smiles_out = self.smiles_encoder(smiles_src, smiles_tgt) + smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) gene_out = self.gene_encoder(gene_src) feat = None feat = self.smixmut_decoder(smiles_out, gene_out) @@ -100,9 +88,9 @@ def __init__( smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - d_hidden = self.d_model // 2 + d_hidden = self.smiles_conf.d_model // 2 self.reg = nn.Sequential( - nn.Linear(self.d_model, d_hidden), + nn.Linear(self.smiles_conf.d_model, d_hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(d_hidden, 1), @@ -111,7 +99,7 @@ def __init__( def forward( self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor ) -> torch.Tensor: - smiles_out = self.smiles_encoder(smiles_src, smiles_tgt) + smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) gene_out = self.gene_encoder(gene_src) feat = None diff --git a/dooc/nets.py b/dooc/nets.py index 3505ac4..8002b18 100644 --- a/dooc/nets.py +++ b/dooc/nets.py @@ -7,6 +7,7 @@ @dataclass class GeneGNNConfig: + d_model: int gene_dim: int drug_dim: int num_hiddens_genotype: int @@ -24,6 +25,7 @@ class GeneGNN(nn.Module): """ DEFAULT_CONFIG = GeneGNNConfig( + d_model=768, gene_dim=3008, drug_dim=2048, num_hiddens_genotype=6, @@ -49,6 +51,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None: self._construct_nn_graph() self._construct_nn_drug() self._construct_final_layer() + self.final_fc_layer = nn.Linear(self.conf.num_hiddens_genotype, + self.conf.d_model) def _contruct_direct_gene_layer(self): """ @@ -225,4 +229,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) out = term_nn_out_map[self.dg_root] + out = self.final_fc_layer(out) return out diff --git a/requirements.txt b/requirements.txt index 1003af0..07a13fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -moltx~=1.0.2 +moltx~=1.0.4 networkx~=3.1 scikit-learn~=1.3.0 From e4733aec74168c4df1c9dec72b5bc46983c6fe96 Mon Sep 17 00:00:00 2001 From: seeuzhouxin <18986114551@163.com> Date: Wed, 22 May 2024 23:12:41 +0800 Subject: [PATCH 4/4] fix: neat syntax --- dooc/models.py | 21 +++++++++------------ dooc/nets.py | 7 ++++--- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/dooc/models.py b/dooc/models.py index 1964f4f..ec7b02f 100644 --- a/dooc/models.py +++ b/dooc/models.py @@ -55,13 +55,12 @@ def __init__( smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - 
d_hidden = self.smiles_conf.d_model // 2 - decoder_layer = nn.TransformerDecoderLayer(self.smiles_conf.d_model, - nhead) - self.smixmut_decoder = nn.TransformerDecoder(decoder_layer, - num_layers) + d_model = self.smiles_conf.d_model + d_hidden = d_model // 2 + layer = nn.TransformerDecoderLayer(d_model, nhead) + self.cross_att = nn.TransformerDecoder(layer, num_layers) self.reg = nn.Sequential( - nn.Linear(self.smiles_conf.d_model, d_hidden), + nn.Linear(d_model, d_hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(d_hidden, 1), @@ -73,8 +72,7 @@ def forward( assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2 smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) gene_out = self.gene_encoder(gene_src) - feat = None - feat = self.smixmut_decoder(smiles_out, gene_out) + feat = self.cross_att(smiles_out, gene_out) return self.reg(feat) @@ -88,9 +86,10 @@ def __init__( smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE, ) -> None: super().__init__(gene_conf, smiles_conf) - d_hidden = self.smiles_conf.d_model // 2 + d_model = self.smiles_conf.d_model + d_hidden = d_model // 2 self.reg = nn.Sequential( - nn.Linear(self.smiles_conf.d_model, d_hidden), + nn.Linear(d_model, d_hidden), nn.ReLU(), nn.Dropout(0.1), nn.Linear(d_hidden, 1), @@ -101,8 +100,6 @@ def forward( ) -> torch.Tensor: smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt) gene_out = self.gene_encoder(gene_src) - - feat = None feat = smiles_out + gene_out return self.reg(feat) diff --git a/dooc/nets.py b/dooc/nets.py index 8002b18..e2d867f 100644 --- a/dooc/nets.py +++ b/dooc/nets.py @@ -38,6 +38,7 @@ class GeneGNN(nn.Module): def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None: super().__init__() self.conf = conf + d_model = self.conf.d_model dg, dg_root, term_size_map, term_direct_gene_map = self._get_params() self.dg, self.dg_root = dg, dg_root @@ -51,8 +52,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None: self._construct_nn_graph() self._construct_nn_drug() self._construct_final_layer() - self.final_fc_layer = nn.Linear(self.conf.num_hiddens_genotype, - self.conf.d_model) + self.out_fc = nn.Linear(self.conf.num_hiddens_genotype, + d_model) def _contruct_direct_gene_layer(self): """ @@ -229,5 +230,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) out = term_nn_out_map[self.dg_root] - out = self.final_fc_layer(out) + out = self.out_fc(out) return out
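
A note on the fusion pattern the series converges on: MutSmiXAttention treats
the AdaMR SMILES feature as the decoder target and the GeneGNN mutation
feature as the memory, so the SMILES representation attends over
mutation-derived keys and values before regression. Below is a minimal,
self-contained sketch of that pattern with synthetic tensors standing in for
both encoders; batch_first=True and the explicit length-1 sequence dimension
are assumptions made here for readability, while the patches feed 2-D tensors
to the decoder directly.

    import torch
    import torch.nn as nn

    d_model, nhead, num_layers = 768, 2, 2

    # SMILES feature = target, gene feature = memory (cross-attention).
    layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
    cross_att = nn.TransformerDecoder(layer, num_layers)
    reg = nn.Sequential(
        nn.Linear(d_model, d_model // 2),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(d_model // 2, 1),
    )

    batch = 2
    smiles_out = torch.randn(batch, 1, d_model)  # stand-in for AdaMR.forward_feature
    gene_out = torch.randn(batch, 1, d_model)    # stand-in for GeneGNN output
    feat = cross_att(tgt=smiles_out, memory=gene_out)
    out = reg(feat).squeeze(1)                   # shape (2, 1): one score per pair

[PATCH 3/4] is what makes this fusion dimensionally consistent: GeneGNN's
drugcell-style genotype output is only num_hiddens_genotype=6 wide, so the
GNN now projects it to d_model=768 (out_fc) before either the cross-attention
here or the additive fusion in MutSmiFullConnection sees it.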
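
End-to-end usage after the series, mirroring tests/test_pipelines.py; the
checkpoint path is the same placeholder the patches use, and an untrained
model is enough to exercise the shapes:

    import random

    from dooc import models, pipelines
    from moltx import tokenizers as tkz
    from moltx.models import AdaMRTokenizerConfig

    tokenizer = tkz.MoltxTokenizer.from_pretrain(
        conf=AdaMRTokenizerConfig.Prediction
    )
    model = models.MutSmiXAttention()
    # model.load_ckpt('/path/to/mutsmixattention.ckpt')
    pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer,
                                          model=model)

    # 3008-dimensional binary mutation vector, as in the tests.
    mutation = [random.choice([0, 1]) for _ in range(3008)]
    smiles = "CC[N+](C)(C)Cc1ccccc1Br"
    prediction = pipeline(mutation, smiles)  # plain Python float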