
Commit eac2322

fix: del unused class vars & rename decoder & rewrite smiles/gene encoder
zhouxinseeu committed May 22, 2024
1 parent e165fd9 commit eac2322
Showing 5 changed files with 38 additions and 45 deletions.
58 changes: 26 additions & 32 deletions dooc/models.py
@@ -16,9 +16,21 @@ def __init__(
         super().__init__()
         self.gene_conf = gene_conf
         self.smiles_conf = smiles_conf
-        self.smiles_encoder = AdaMR(smiles_conf)
+        self.d_model = self.smiles_conf.d_model
+        self.init_smiles_encoder = AdaMR(smiles_conf)
 
-        self.gene_encoder = nets.GeneGNN(gene_conf)
+        self.init_gene_encoder = nets.GeneGNN(gene_conf)
+        self.gene_fc = nn.Linear(gene_conf.num_hiddens_genotype, self.d_model)
+
+    def smiles_encoder(self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor) -> torch.Tensor:
+        smiles_out = self.init_smiles_encoder.forward_feature(smiles_src, smiles_tgt)
+        return smiles_out
+        # return smiles_out.unsqueeze(0) if smiles_out.dim() == 1 else smiles_out
+
+    def gene_encoder(self, gene_src: torch.Tensor) -> torch.Tensor:
+        gene_out = self.init_gene_encoder(gene_src)
+        gene_out = self.gene_fc(gene_out)
+        return gene_out
 
     def load_ckpt(self, *ckpt_files: str) -> None:
         """load check point model.
@@ -49,22 +61,17 @@ class MutSmiXAttention(MutSmi):

     def __init__(
         self,
-        d_model: int,
         nhead: int = 2,
         num_layers: int = 2,
         gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        self.d_model = d_model
-        d_hidden = d_model // 2
-        self.nhead = nhead
-        self.num_layers = num_layers
-        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
+        d_hidden = self.d_model // 2
         decoder_layer = nn.TransformerDecoderLayer(self.d_model,
-                                                   self.nhead)
-        self.decoder = nn.TransformerDecoder(decoder_layer,
-                                             num_layers=self.num_layers)
+                                                   nhead)
+        self.smixmut_decoder = nn.TransformerDecoder(decoder_layer,
+                                                     num_layers)
         self.reg = nn.Sequential(
             nn.Linear(self.d_model, d_hidden),
             nn.ReLU(),
@@ -75,15 +82,11 @@ def __init__(
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        is_batched = smiles_src.dim() == 2
 
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        gene_out = self.gene_fc(self.gene_encoder(gene_src))
+        assert smiles_src.dim() == 2 and smiles_tgt.dim() == 2
+        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        gene_out = self.gene_encoder(gene_src)
-        feat = None
-        if is_batched:
-            feat = self.decoder(smiles_out, gene_out)
-        else:
-            feat = self.decoder(smiles_out.unsqueeze(0), gene_out)
+        feat = self.smixmut_decoder(smiles_out, gene_out)
 
         return self.reg(feat)

@@ -93,15 +96,11 @@ class MutSmiFullConnection(MutSmi):

     def __init__(
         self,
-        d_model: int,
         gene_conf: nets.GeneGNNConfig = nets.GeneGNN.DEFAULT_CONFIG,
         smiles_conf: AbsPosEncoderDecoderConfig = AdaMR.CONFIG_BASE,
     ) -> None:
         super().__init__(gene_conf, smiles_conf)
-        self.d_model = d_model
-
-        d_hidden = d_model // 2
-        self.gene_fc = nn.Linear(self.gene_conf.num_hiddens_genotype, self.d_model)
+        d_hidden = self.d_model // 2
         self.reg = nn.Sequential(
             nn.Linear(self.d_model, d_hidden),
             nn.ReLU(),
@@ -112,15 +111,10 @@ def __init__(
     def forward(
         self, smiles_src: torch.Tensor, smiles_tgt: torch.Tensor, gene_src: torch.Tensor
     ) -> torch.Tensor:
-        is_batched = smiles_src.dim() == 2
 
-        smiles_out = self.smiles_encoder.forward_feature(smiles_src, smiles_tgt)
-        gene_out = self.gene_encoder(gene_src).unsqueeze(1)
+        smiles_out = self.smiles_encoder(smiles_src, smiles_tgt)
+        gene_out = self.gene_encoder(gene_src)
 
-        feat = None
-        if is_batched:
-            feat = smiles_out + self.gene_fc(gene_out)[:, 0]
-        else:
-            feat = smiles_out[0] + self.gene_fc(gene_out)[0]
+        feat = smiles_out + gene_out
 
         return self.reg(feat)
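For orientation, a minimal usage sketch of the refactored models after this commit. Class and default configs are taken from the diff above; the SMILES sequence length and token values below are placeholders, and the gene dimension 3008 and output shape (2, 1) follow the tests further down:

```python
import torch
from dooc import models

# d_model is no longer passed in; it is now read from smiles_conf.d_model
# inside MutSmi.__init__ (self.d_model = self.smiles_conf.d_model).
model = models.MutSmiXAttention()        # previously: models.MutSmiXAttention(768)

# forward() now asserts batched (2-D) SMILES tensors.
smiles_src = torch.zeros(2, 200, dtype=torch.long)    # (batch, seq_len), placeholder token ids
smiles_tgt = torch.zeros(2, 200, dtype=torch.long)
gene_src = torch.zeros(2, 3008, dtype=torch.float)    # (batch, gene_dim), as in the tests

out = model(smiles_src, smiles_tgt, gene_src)
assert out.shape == (2, 1)
```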
6 changes: 4 additions & 2 deletions dooc/nets.py
@@ -181,7 +181,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
removed drug layer
"""

x_dim = x.dim()
x = x.unsqueeze(0) if x_dim == 1 else x
gene_input = x.narrow(1, 0, self.conf.gene_dim)
# drug_input = x.narrow(1, self.conf.gene_dim, self.conf.drug_dim)

@@ -223,4 +224,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 aux_layer1_out
             )
 
-        return term_nn_out_map[self.dg_root]
+        out = term_nn_out_map[self.dg_root]
+        return out
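The guard added to GeneGNN.forward promotes a single 1-D mutation vector to a batch of one before the gene_dim slice. A small stand-alone illustration of that pattern (not the GeneGNN code itself; the 3008 gene dimension is taken from the tests below):

```python
import torch

def ensure_batched(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the guard above: a 1-D gene vector becomes shape (1, gene_dim),
    # while an already-batched (batch, gene_dim) tensor passes through unchanged.
    return x.unsqueeze(0) if x.dim() == 1 else x

single = torch.zeros(3008)        # one mutation profile
batch = torch.zeros(4, 3008)      # four profiles
assert ensure_batched(single).shape == (1, 3008)
assert ensure_batched(batch).shape == (4, 3008)
```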
2 changes: 2 additions & 0 deletions dooc/pipelines.py
@@ -50,6 +50,8 @@ def __init__(

     def __call__(self, gene: typing.Sequence[int], smiles: str) -> float:
         smi_src, smi_tgt, gene_src = self._model_args(gene, smiles)
+        smi_src = smi_src.unsqueeze(0)
+        smi_tgt = smi_tgt.unsqueeze(0)
         pred = self.model(smi_src, smi_tgt, gene_src)
         return pred.item()
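The two unsqueeze(0) calls exist because the pipeline scores one SMILES string at a time, while MutSmiXAttention.forward now asserts 2-D (batched) SMILES tensors; each single-sample tensor therefore gains a leading batch dimension of one. A rough sketch of the shape handling, with placeholder token ids:

```python
import torch

smi_src = torch.tensor([12, 7, 55, 3])   # (seq_len,) token ids for one SMILES, placeholder values
smi_src = smi_src.unsqueeze(0)           # -> (1, seq_len), satisfies the smiles_src.dim() == 2 assert
assert smi_src.shape == (1, 4)
```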

6 changes: 2 additions & 4 deletions tests/test_models.py
@@ -9,8 +9,7 @@ def test_MutSmiXAttention():
     mutations = [[random.choice([0, 1]) for _ in range(3008)],
                  [random.choice([0, 1]) for _ in range(3008)]]
     mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
-    d_model = 768
-    model = models.MutSmiXAttention(d_model)
+    model = models.MutSmiXAttention()
     out = model(smiles_src, smiles_tgt, mutations_src)
     assert out.shape == (2, 1)

@@ -21,7 +20,6 @@ def test_MutSmiFullConnection():
     mutations = [[random.choice([0, 1]) for _ in range(3008)],
                  [random.choice([0, 1]) for _ in range(3008)]]
     mutations_src = torch.tensor(mutations, dtype=torch.float).to("cpu")
-    d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    model = models.MutSmiFullConnection()
     out = model(smiles_src, smiles_tgt, mutations_src)
     assert out.shape == (2, 1)
11 changes: 4 additions & 7 deletions tests/test_pipelines.py
@@ -8,28 +8,25 @@ def test_MutSmiXAttention():
     tokenizer = tkz.MoltxTokenizer.from_pretrain(
         conf=AdaMRTokenizerConfig.Prediction
     )
-    d_model = 768
-    model = models.MutSmiXAttention(d_model)
+    model = models.MutSmiXAttention()
     # model.load_ckpt('/path/to/mutsmixattention.ckpt')
     pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer,
                                           model=model)
-    mutation = [[random.choice([1, 0]) for _ in range(3008)]]
+    mutation = [random.choice([1, 0]) for _ in range(3008)]
     smiles = "CC[N+](C)(C)Cc1ccccc1Br"
     predict = pipeline(mutation, smiles)
-    print(predict)
     assert isinstance(predict, float)


 def test_MutSmiFullConnection():
     tokenizer = tkz.MoltxTokenizer.from_pretrain(
         conf=AdaMRTokenizerConfig.Prediction
     )
-    d_model = 768
-    model = models.MutSmiFullConnection(d_model)
+    model = models.MutSmiFullConnection()
     # model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
     pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer,
                                               model=model)
-    mutation = [[random.choice([1, 0]) for _ in range(3008)]]
+    mutation = [random.choice([1, 0]) for _ in range(3008)]
     smiles = "CC[N+](C)(C)Cc1ccccc1Br"
     predict = pipeline(mutation, smiles)
     assert isinstance(predict, float)
