Commit

fix: put gene fc layer in GNN & del self.d_model
zhouxinseeu committed May 22, 2024
1 parent eac2322 commit afe17db
Showing 3 changed files with 15 additions and 22 deletions.
30 changes: 9 additions & 21 deletions dooc/models.py
@@ -16,21 +16,9 @@ def __init__(
         super().__init__()
         self.gene_conf = gene_conf
         self.smiles_conf = smiles_conf
-        self.d_model = self.smiles_conf.d_model
-        self.init_smiles_encoder = AdaMR(smiles_conf)
+        self.smiles_encoder = AdaMR(smiles_conf)
 
-        self.init_gene_encoder = nets.GeneGNN(gene_conf)
-        self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model)
-
-    def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor:
-        smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        return smiles_out
-        # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out
-
-    def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor:
-        gene_out = self.init_gene_encoder(gene_src)
-        gene_out = self.gene_fc(gene_out)
-        return gene_out
+        self.gene_encoder = nets.GeneGNN(gene_conf)
 
     def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
@@ -67,13 +55,13 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.d_model // 2
-        decoder_layer = nn.TransformerDecoderLayer(self.d_model,
+        d_hidden = self.smiles_conf.d_model // 2
+        decoder_layer = nn.TransformerDecoderLayer(self.smiles_conf.d_model,
                                                    nhead)
         self.smixmut_decoder = nn.TransformerDecoder(decoder_layer,
                                                      num_layers)
         self.reg = nn.Sequential(
-            nn.Linear(self.d_model, d_hidden),
+            nn.Linear(self.smiles_conf.d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -83,7 +71,7 @@ def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
         assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
-        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
         feat = None
         feat = self.smixmut_decoder(smiles_out, gene_out)
@@ -100,9 +88,9 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.d_model // 2
+        d_hidden = self.smiles_conf.d_model // 2
         self.reg = nn.Sequential(
-            nn.Linear(self.d_model, d_hidden),
+            nn.Linear(self.smiles_conf.d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -111,7 +99,7 @@ def __init__(
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
 
         feat = None
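After this refactor the two encoders are plain submodules: callers use `self.smiles_encoder.forward_feature(...)` directly, and `self.gene_encoder(gene_src)` already returns features of width `d_model`, so the downstream decoder and regression head need no extra projection. Below is a minimal shape-level sketch of that downstream path; the tensor shapes, `nhead`, and `num_layers` values are assumptions for illustration, not taken from the collapsed parts of the diff.

```python
import torch
import torch.nn as nn

# Shape-level sketch only; hyperparameters and tensor shapes are assumed.
d_model, nhead, num_layers = 768, 8, 2          # 768 matches GeneGNNConfig's d_model
decoder_layer = nn.TransformerDecoderLayer(d_model, nhead)
smixmut_decoder = nn.TransformerDecoder(decoder_layer, num_layers)

d_hidden = d_model // 2
reg = nn.Sequential(                             # regression head as in the diff
    nn.Linear(d_model, d_hidden),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(d_hidden, 1),
)

# Stand-ins for the two encoder outputs, both already d_model wide:
smiles_out = torch.randn(5, 2, d_model)          # assumed (seq_len, batch, d_model)
gene_out = torch.randn(1, 2, d_model)            # gene feature used as decoder memory
feat = smixmut_decoder(smiles_out, gene_out)     # SMILES tokens cross-attend to the gene memory
print(reg(feat).shape)                           # torch.Size([5, 2, 1])
```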
5 changes: 5 additions & 0 deletions dooc/nets.py
@@ -7,6 +7,7 @@
 
 @dataclass
 class GeneGNNConfig:
+    d_model: int
     gene_dim: int
     drug_dim: int
     num_hiddens_genotype: int
@@ -24,6 +25,7 @@ class GeneGNN(nn.Module):
     """
 
     DEFAULT_CONFIG = GeneGNNConfig(
+        d_model=768,
         gene_dim=3008,
         drug_dim=2048,
         num_hiddens_genotype=6,
@@ -49,6 +51,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
         self._construct_nn_graph()
         self._construct_nn_drug()
         self._construct_final_layer()
+        self.final_fc_layer = nn.Linear(self.conf.num_hiddens_genotype,
+                                        self.conf.d_model)
 
     def _contruct_direct_gene_layer(self):
         """
@@ -225,4 +229,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         out = term_nn_out_map[self.dg_root]
+        out = self.final_fc_layer(out)
         return out
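Net effect in dooc/nets.py: the gene-side projection now lives inside GeneGNN, so the GNN's final hidden vector (num_hiddens_genotype wide) is mapped to d_model before it leaves the module, and the models in dooc/models.py no longer need gene_fc or self.d_model. The toy stand-in below illustrates only that output contract; DummyGeneGNN and its single-layer backbone are placeholders, not dooc's real term-hierarchy network.

```python
import torch
import torch.nn as nn

# Toy stand-in for GeneGNN: only the output-shape contract is modeled here.
class DummyGeneGNN(nn.Module):
    def __init__(self, gene_dim=3008, num_hiddens_genotype=6, d_model=768):
        super().__init__()
        self.backbone = nn.Linear(gene_dim, num_hiddens_genotype)  # placeholder for the GNN body
        # The projection this commit moves into the GNN itself:
        self.final_fc_layer = nn.Linear(num_hiddens_genotype, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.backbone(x)
        return self.final_fc_layer(out)  # (batch, d_model), ready for the decoder

gene_src = torch.randn(2, 3008)
print(DummyGeneGNN()(gene_src).shape)  # torch.Size([2, 768])
```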
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
-moltx~=1.0.2
+moltx~=1.0.4
 networkx~=3.1
 scikit-learn~=1.3.0
