
Commit

fix: neat syntax
zhouxinseeu committed May 22, 2024
1 parent afe17db commit e4733ae
Showing 2 changed files with 13 additions and 15 deletions.
21 changes: 9 additions & 12 deletions dooc/models.py

@@ -55,13 +55,12 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.smiles_conf.d_model // 2
-        decoder_layer = nn.TransformerDecoderLayer(self.smiles_conf.d_model,
-                                                   nhead)
-        self.smixmut_decoder = nn.TransformerDecoder(decoder_layer,
-                                                     num_layers)
+        d_model = self.smiles_conf.d_model
+        d_hidden = d_model // 2
+        layer = nn.TransformerDecoderLayer(d_model, nhead)
+        self.cross_att = nn.TransformerDecoder(layer, num_layers)
         self.reg = nn.Sequential(
-            nn.Linear(self.smiles_conf.d_model, d_hidden),
+            nn.Linear(d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -73,8 +72,7 @@ def forward(
         assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
         smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
-        feat = None
-        feat = self.smixmut_decoder(smiles_out, gene_out)
+        feat = self.cross_att(smiles_out, gene_out)
 
         return self.reg(feat)
 
@@ -88,9 +86,10 @@ def __init__(
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        d_hidden = self.smiles_conf.d_model // 2
+        d_model = self.smiles_conf.d_model
+        d_hidden = d_model // 2
         self.reg = nn.Sequential(
-            nn.Linear(self.smiles_conf.d_model, d_hidden),
+            nn.Linear(d_model, d_hidden),
             nn.ReLU(),
             nn.Dropout(0.1),
             nn.Linear(d_hidden, 1),
@@ -101,8 +100,6 @@ def forward(
     ) -> torch.Tensor:
         smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
         gene_out = self.gene_encoder(gene_src)
-
-        feat = None
         feat = smiles_out + gene_out
 
         return self.reg(feat)
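For reference, the two fusion patterns touched in this file are sketched below as a standalone example: a nn.TransformerDecoder used as cross-attention between SMILES and gene features, and a plain element-wise sum, each feeding the same small regression head. The class names, default sizes (d_model=768, nhead=8, num_layers=2), the batch_first layout, and the random stand-in encoder outputs are illustration-only assumptions, not the repository's actual model classes or encoder calls.

```python
import torch
from torch import nn


class CrossAttentionFusion(nn.Module):
    """Sketch of the cross_att + reg pattern (sizes and shapes are assumptions)."""

    def __init__(self, d_model: int = 768, nhead: int = 8, num_layers: int = 2) -> None:
        super().__init__()
        d_hidden = d_model // 2
        layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
        # TransformerDecoder(tgt, memory): SMILES features attend to gene features.
        self.cross_att = nn.TransformerDecoder(layer, num_layers)
        self.reg = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, 1),
        )

    def forward(self, smiles_out: torch.Tensor, gene_out: torch.Tensor) -> torch.Tensor:
        feat = self.cross_att(smiles_out, gene_out)  # (batch, seq, d_model)
        return self.reg(feat)


class AdditiveFusion(nn.Module):
    """Sketch of the element-wise-sum variant; inputs must share a shape."""

    def __init__(self, d_model: int = 768) -> None:
        super().__init__()
        d_hidden = d_model // 2
        self.reg = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_hidden, 1),
        )

    def forward(self, smiles_out: torch.Tensor, gene_out: torch.Tensor) -> torch.Tensor:
        feat = smiles_out + gene_out
        return self.reg(feat)


if __name__ == "__main__":
    # Random tensors stand in for the SMILES and gene encoder outputs.
    smiles_out = torch.randn(4, 1, 768)
    gene_out = torch.randn(4, 1, 768)
    print(CrossAttentionFusion()(smiles_out, gene_out).shape)  # torch.Size([4, 1, 1])
    print(AdditiveFusion()(smiles_out, gene_out).shape)        # torch.Size([4, 1, 1])
```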
7 changes: 4 additions & 3 deletions dooc/nets.py

@@ -38,6 +38,7 @@ class GeneGNN(nn.Module):
     def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
         super().__init__()
         self.conf = conf
+        d_model = self.conf.d_model
 
         dg, dg_root, term_size_map, term_direct_gene_map = self._get_params()
         self.dg, self.dg_root = dg, dg_root
@@ -51,8 +52,8 @@ def __init__(self, conf: GeneGNNConfig = DEFAULT_CONFIG) -> None:
         self._construct_nn_graph()
         self._construct_nn_drug()
         self._construct_final_layer()
-        self.final_fc_layer = nn.Linear(self.conf.num_hiddens_genotype,
-                                        self.conf.d_model)
+        self.out_fc = nn.Linear(self.conf.num_hiddens_genotype,
+                                d_model)
 
     def _contruct_direct_gene_layer(self):
         """
@@ -229,5 +230,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
 
         out = term_nn_out_map[self.dg_root]
-        out = self.final_fc_layer(out)
+        out = self.out_fc(out)
         return out
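The renamed out_fc projects the root-term output (width num_hiddens_genotype) up to d_model, presumably so the gene features match the width of the SMILES features they are summed with or attended to in dooc/models.py. A minimal sketch of that final projection, with hypothetical dimensions standing in for the real GeneGNNConfig values:

```python
import torch
from torch import nn

# Hypothetical dimensions; the real values come from GeneGNNConfig.
num_hiddens_genotype = 6
d_model = 768

out_fc = nn.Linear(num_hiddens_genotype, d_model)

root_term_out = torch.randn(4, num_hiddens_genotype)  # stand-in for term_nn_out_map[dg_root]
gene_out = out_fc(root_term_out)                       # (4, d_model), same width as SMILES features
print(gene_out.shape)
```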
