From eac2322e458d9ad2ed22bc8dd4c766ff45c9430e Mon Sep 17 00:00:00 2001
From: seeuzhouxin <18986114551@163.com>
Date: Wed, 22 May 2024 18:27:14 +0800
Subject: [PATCH] fix: remove unused class vars, rename decoder, rewrite smiles/gene encoders
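
Both heads previously took d_model as a constructor argument even though
the smiles encoder config already carries it, so the two values could
silently disagree. MutSmi now reads d_model from smiles_conf, keeps the
raw AdaMR / GeneGNN modules as init_smiles_encoder / init_gene_encoder,
and exposes smiles_encoder() / gene_encoder() methods that both return
d_model-sized features (gene_encoder applies the shared gene_fc
projection). The attention head's decoder is renamed smixmut_decoder,
and the per-head batching branches are gone: GeneGNN.forward()
unsqueezes 1-D input and the pipelines unsqueeze the tokenized smiles
tensors, so the models always see batched 2-D input.

A minimal sketch of the call path after this patch (token ids and
sequence lengths below are illustrative assumptions, not part of the
change; the batch size, gene width, and output shape match the tests):

    import torch
    from dooc import models

    model = models.MutSmiXAttention()  # d_model taken from smiles_conf
    smiles_src = torch.randint(0, 10, (2, 4))  # fake batched token ids
    smiles_tgt = torch.randint(0, 10, (2, 4))
    gene_src = torch.randint(0, 2, (2, 3008)).float()  # mutation flags
    out = model(smiles_src, smiles_tgt, gene_src)
    assert out.shape == (2, 1)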
---
 dooc/models.py          | 58 ++++++++++++++++++-----------------------
 dooc/nets.py            |  6 +++--
 dooc/pipelines.py       |  2 ++
 tests/test_models.py    |  6 ++---
 tests/test_pipelines.py | 11 +++-----
 5 files changed, 38 insertions(+), 45 deletions(-)

diff --git a/dooc/models.py b/dooc/models.py
index 673ba05..37b5a3a 100644
--- a/dooc/models.py
+++ b/dooc/models.py
@@ -16,9 +16,21 @@ def __init__(
         super().__init__()
         self.gene_conf = gene_conf
         self.smiles_conf = smiles_conf
-        self.smiles_encoder = AdaMR(smiles_conf)
+        self.d_model = self.smiles_conf.d_model
+        self.init_smiles_encoder = AdaMR(smiles_conf)
 
-        self.gene_encoder = nets.GeneGNN(gene_conf)
+        self.init_gene_encoder = nets.GeneGNN(gene_conf)
+        self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model)
+
+    def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor:
+        smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt)
+        return smiles_out
+        # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out
+
+    def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor:
+        gene_out = self.init_gene_encoder(gene_src)
+        gene_out = self.gene_fc(gene_out)
+        return gene_out
 
     def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
@@ -49,22 +61,17 @@ class MutSmiXAttention(MutSmi):
 
     def __init__(
         self,
-        d_model: int,
         nhead: int = 2,
         num_layers: int = 2,
         gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        self.d_model = d_model
-        d_hidden = d_model // 2
-        self.nhead = nhead
-        self.num_layers = num_layers
-        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
+        d_hidden = self.d_model // 2
         decoder_layer = nn.TransformerDecoderLayer(self.d_model,
-                                                   self.nhead)
-        self.decoder = nn.TransformerDecoder(decoder_layer,
-                                             num_layers=self.num_layers)
+                                                   nhead)
+        self.smixmut_decoder = nn.TransformerDecoder(decoder_layer,
+                                                     num_layers)
         self.reg = nn.Sequential(
             nn.Linear(self.d_model, d_hidden),
             nn.ReLU(),
@@ -75,15 +82,11 @@ def __init__(
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        is_batched = smiles_src.dim() == 2
-
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        gene_out = self.gene_fc(self.gene_encoder(gene_src))
+        assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
+        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        gene_out = self.gene_encoder(gene_src)
         feat = None
-        if is_batched:
-            feat = self.decoder(smiles_out, gene_out)
-        else:
-            feat = self.decoder(smiles_out.unsqueeze(0), gene_out)
+        feat = self.smixmut_decoder(smiles_out, gene_out)
 
         return self.reg(feat)
 
@@ -93,15 +96,11 @@ class MutSmiFullConnection(MutSmi):
 
     def __init__(
         self,
-        d_model: int,
         gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        self.d_model = d_model
-
-        d_hidden = d_model // 2
-        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
+        d_hidden = self.d_model // 2
         self.reg = nn.Sequential(
             nn.Linear(self.d_model, d_hidden),
             nn.ReLU(),
@@ -112,15 +111,10 @@ def __init__(
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        is_batched = smiles_src.dim() == 2
-
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        gene_out = self.gene_encoder(gene_src).unsqueeze(1)
+        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        gene_out = self.gene_encoder(gene_src)
         feat = None
-        if is_batched:
-            feat = smiles_out + self.gene_fc(gene_out)[:, 0]
-        else:
-            feat = smiles_out[0] + self.gene_fc(gene_out)[0]
+        feat = smiles_out + gene_out
 
         return self.reg(feat)
 
diff --git a/dooc/nets.py b/dooc/nets.py
index 9a0150e..3505ac4 100644
--- a/dooc/nets.py
+++ b/dooc/nets.py
@@ -181,7 +181,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         removed drug layer
         """
-
+        x_dim = x.dim()
+        x = x.unsqueeze(0) if x_dim == 1 else x
         gene_input = x.narrow(1, 0, self.conf.gene_dim)
         # drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim)
 
@@ -223,4 +224,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             aux_layer1_out
         )
 
-        return term_nn_out_map[self.dg_root]
+        out = term_nn_out_map[self.dg_root]
+        return out
diff --git a/dooc/pipelines.py b/dooc/pipelines.py
index 23b1e4a..abe7a39 100644
--- a/dooc/pipelines.py
+++ b/dooc/pipelines.py
@@ -50,6 +50,8 @@ def __init__(
 
     def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
         smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
+        smi_src = smi_src.unsqueeze(0)
+        smi_tgt = smi_tgt.unsqueeze(0)
         pred = self.model(smi_src, smi_tgt, gene_src)
         return pred.item()
 
diff --git a/tests/test_models.py b/tests/test_models.py
index 5c8ede1..46fe205 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -9,8 +9,7 @@ def test_MutSmiXAttention():
     mutations = [[random.choice([0, 1]) for _ in range(3008)],
                  [random.choice([0, 1]) for _ in range(3008)]]
     mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
-    d_model = 768
-    model = models.MutSmiXAttention(d_model)
+    model = models.MutSmiXAttention()
     out = model(smiles_src, smiles_tgt, mutations_src)
     assert out.shape == (2, 1)
 
@@ -21,7 +20,6 @@ def test_MutSmiFullConnection():
     mutations = [[random.choice([0, 1]) for _ in range(3008)],
                  [random.choice([0, 1]) for _ in range(3008)]]
     mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
-    d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    model = models.MutSmiFullConnection()
     out = model(smiles_src, smiles_tgt, mutations_src)
     assert out.shape == (2, 1)
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index e7fa354..c95efa9 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -8,15 +8,13 @@ def test_MutSmiXAttention():
     tokenizer = tkz.MoltxTokenizer.from_pretrain(
         conf=AdaMRTokenizerConfig.Prediction
     )
-    d_model = 768
-    model = models.MutSmiXAttention(d_model)
+    model = models.MutSmiXAttention()
     # model.load_ckpt('/path/to/mutsmixattention.ckpt')
     pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer,
                                           model=model)
-    mutation = [[random.choice([1, 0]) for _ in range(3008)]]
+    mutation = [random.choice([1, 0]) for _ in range(3008)]
     smiles = "CC[N+](C)(C)Cc1ccccc1Br"
     predict = pipeline(mutation, smiles)
-    print(predict)
     assert isinstance(predict, float)
 
 
@@ -24,12 +22,11 @@ def test_MutSmiFullConnection():
     tokenizer = tkz.MoltxTokenizer.from_pretrain(
         conf=AdaMRTokenizerConfig.Prediction
     )
-    d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    model = models.MutSmiFullConnection()
     # model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
     pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer,
                                               model=model)
-    mutation = [[random.choice([1, 0]) for _ in range(3008)]]
+    mutation = [random.choice([1, 0]) for _ in range(3008)]
     smiles = "CC[N+](C)(C)Cc1ccccc1Br"
     predict = pipeline(mutation, smiles)
     assert isinstance(predict, float)