Skip to content

Commit

Permalink
fix: drugcell net
Browse files Browse the repository at this point in the history
  • Loading branch information
origyZ committed Jun 7, 2024
1 parent 3cd8895 commit 8ddd0e0
Showing 1 changed file with 3 additions and 11 deletions.
14 changes: 3 additions & 11 deletions dooc/nets/drugcell.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def __init__(self, conf: DrugcellConfig = DEFAULT_CONFIG) -> None:
self._cal_term_dim()
self._contruct_direct_gene_layer()
self._construct_nn_graph()
self._construct_final_layer()
self.out_fc = nn.Linear(self.conf.num_hiddens_genotype,
self.conf.d_model)

def _contruct_direct_gene_layer(self):
"""
Expand Down Expand Up @@ -182,15 +183,6 @@ def _construct_nn_graph(self):

self.dg.remove_nodes_from(leaves)

def _construct_final_layer(self):
"""
add modules for final layer
"""
self.add_module(
"final_linear_layer",
nn.Linear(self.conf.num_hiddens_genotype, self.conf.d_model),
)

def _cal_term_dim(self):
"""
calculate the number of values in a state (term)
Expand Down Expand Up @@ -256,7 +248,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
tanh_out
)

out = self._modules['final_linear_layer'](term_nn_out_map[self.dg_root])
out = self.out_fc(term_nn_out_map[self.dg_root])
if x_dim == 1:
out = out.squeeze(0)
return out

0 comments on commit 8ddd0e0

Please sign in to comment.