diff --git a/dooc/models.py b/dooc/models.py
index 37b5a3a..1964f4f 100644
--- a/dooc/models.py
+++ b/dooc/models.py
@@ -16,21 +16,9 @@ def __init__(
         super().__init__()
         self.gene_conf = gene_conf
         self.smiles_conf = smiles_conf
-        self.d_model = self.smiles_conf.d_model
-        self.init_smiles_encoder = AdaMR(smiles_conf)
+        self.smiles_encoder = AdaMR(smiles_conf)
-        self.init_gene_encoder = nets.GeneGNN(gene_conf)
-        self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model)
-
-    def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor:
-        smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        return smiles_out
-        # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out
-
-    def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor:
-        gene_out = self.init_gene_encoder(gene_src)
-        gene_out = self.gene_fc(gene_out)
-        return gene_out
+        self.gene_encoder = nets.GeneGNN(gene_conf)
 
     def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
 
@@ -67,13 +55,13 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.d_model // 2
-        decoder_layer = nn.TransformerDecoderLayer(self.d_model,
+        d_hidden = self.smiles_conf.d_model // 2
+        decoder_layer = nn.TransformerDecoderLayer(self.smiles_conf.d_model,
                                                    nhead)
         self.smixmut_decoder = nn.TransformerDecoder(decoder_layer,
                                                      num_layers)
         self.reg = nn.Sequential(
-            nn.Linear(self.d_model, d_hidden),
+            nn.Linear(self.smiles_conf.d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -83,7 +71,7 @@ def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
         assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
-        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
         feat = None
         feat = self.smixmut_decoder(smiles_out, gene_out)
@@ -100,9 +88,9 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.d_model // 2
+        d_hidden = self.smiles_conf.d_model // 2
         self.reg = nn.Sequential(
-            nn.Linear(self.d_model, d_hidden),
+            nn.Linear(self.smiles_conf.d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -111,7 +99,7 @@ def __init__(
 
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
         feat = None
diff --git a/dooc/nets.py b/dooc/nets.py
index 3505ac4..8002b18 100644
--- a/dooc/nets.py
+++ b/dooc/nets.py
@@ -7,6 +7,7 @@
 
 @dataclass
 class GeneGNNConfig:
+    d_model: int
     gene_dim: int
     drug_dim: int
     num_hiddens_genotype: int
@@ -24,6 +25,7 @@ class GeneGNN(nn.Module):
     """
 
     DEFAULT_CONFIG = GeneGNNConfig(
+        d_model=768,
         gene_dim=3008,
         drug_dim=2048,
         num_hiddens_genotype=6,
@@ -49,6 +51,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
         self._construct_nn_graph()
         self._construct_nn_drug()
         self._construct_final_layer()
+        self.final_fc_layer = nn.Linear(self.conf.num_hiddens_genotype,
+                                        self.conf.d_model)
 
     def _contruct_direct_gene_layer(self):
         """
@@ -225,4 +229,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         out = term_nn_out_map[self.dg_root]
+        out = self.final_fc_layer(out)
         return out
diff --git a/requirements.txt b/requirements.txt
index 1003af0..07a13fc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-moltx~=1.0.2
+moltx~=1.0.4
 networkx~=3.1
 scikit-learn~=1.3.0
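
A minimal usage sketch of the refactored call path (the class name MutSmiXAttention, the toy shapes, and the vocab size below are illustrative assumptions, not taken from this diff):

import torch
from dooc import models, nets

# Hypothetical concrete model; substitute the actual subclass defined in
# dooc/models.py (this diff shows only the shared base plus two heads).
model = models.MutSmiXAttention()

# Assumed toy inputs: SMILES token-id matrices of shape (batch, seq_len),
# and a (batch, gene_dim) mutation vector; gene_dim=3008 comes from
# GeneGNNConfig's DEFAULT_CONFIG in dooc/nets.py.
smiles_src = torch.randint(0, 100, (2, 200))
smiles_tgt = torch.randint(0, 100, (2, 200))
gene_src = torch.randn(2, nets.GeneGNN.DEFAULT_CONFIG.gene_dim)

# forward() now calls smiles_encoder.forward_feature(...) directly, and
# gene_encoder(...) already returns d_model=768 features via its new
# final_fc_layer, so the per-model gene_fc glue is gone.
pred = model(smiles_src, smiles_tgt, gene_src)  # per-sample regression score

Folding the old gene_fc projection into GeneGNN as final_fc_layer keeps both encoders emitting d_model-sized features, which is what lets the model subclasses drop their encoder wrapper methods and call the submodules directly.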