Commit

feat: vae omics net
origyZ committed Jul 17, 2024
1 parent 3d96167 commit c48b8d5
Showing 6 changed files with 714 additions and 6 deletions.
101 changes: 99 additions & 2 deletions dooc/datasets.py
@@ -4,7 +4,7 @@
from moltx import tokenizers, datasets


class _SmiMutBase:
class _SmiBase:
def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None:
self.smi_ds = datasets.Base(smi_tokenizer, device)
self.device = device
@@ -22,7 +22,7 @@ def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> to
"""


class _DrugcellAdamrBase(_SmiMutBase):
class _DrugcellAdamrBase(_SmiBase):
"""Base datasets, convert smiles and genes to torch.Tensor."""

def __init__(
@@ -141,6 +141,72 @@ def __call__(
return mut_x, smi_tgt, out


class _VAEOmicsAdamr2Base(_SmiBase):
"""Base datasets, convert smiles and omics to torch.Tensor."""

def __init__(
self,
smi_tokenizer: tokenizers.MoltxTokenizer,
device: torch.device = torch.device("cpu")
) -> None:
super().__init__(smi_tokenizer, device)
self.smi_tokenizer = smi_tokenizer

def _smi_tokens(
self,
smiles: typing.Sequence[str],
seq_len: int = 200,
) -> torch.Tensor:
tgt = self._smi_tokenize(
[f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)
return tgt

def _omics_tokens(self, omics_seq: typing.Sequence[list]) -> typing.Sequence[torch.Tensor]:
return [torch.tensor(omic, device=self.device) for omic in omics_seq]

def _out(self, values: typing.Sequence[float]) -> torch.Tensor:
return torch.tensor(values, device=self.device)


class _VAEOmicsAdamr2OmicsSmi(_VAEOmicsAdamr2Base):
def __call__(
self,
omics_seq: typing.Sequence[list],
smis: typing.Sequence[str],
vals: typing.Sequence[float],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
assert len(smis) == len(vals) and len(omics_seq[0]) == len(vals)
omics_x = self._omics_tokens(omics_seq)
smi_tgt = self._smi_tokens(smis, seq_len)
out = self._out(vals).unsqueeze(-1)
return omics_x, smi_tgt, out


class _VAEOmicsAdamr2OmicsSmis(_VAEOmicsAdamr2Base):
def __call__(
self,
omics_seq: typing.Sequence[list],
lsmis: typing.Sequence[typing.Sequence[str]],
lvals: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
"""
omics_seq: [omic1, omic2, ...] (one entry per omics type); omic1: [omic1_1, omic1_2, ...] (batch size); omic1_1: [gene1, gene2, ...]
lsmis: [[smi11, smi12], [smi21, smi22], ...]
lvals: [[val11, val12], [val21, val22], ...]
"""
assert len(lsmis) == len(lvals) and len(omics_seq[0]) == len(lvals)
omics_x = self._omics_tokens(omics_seq)
batchlen = len(lsmis)
listlen = len(lsmis[0])
smiles = [smi for bsmi in lsmis for smi in bsmi]
smi_tgt = self._smi_tokens(smiles, seq_len)
smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
out = self._out(lvals)
return omics_x, smi_tgt, out


"""
Mutations(Individual Sample) and Smiles Interaction
@@ -170,3 +236,34 @@ def __call__(

class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
pass


"""
Omics(Individual Sample) and Smiles Interaction
OmicsSmiReg
OmicsSmis{Pair/List}wiseRank
"""


class OmicsSmiReg(_VAEOmicsAdamr2OmicsSmi):
pass


class OmicsSmisPairwiseRank(_VAEOmicsAdamr2OmicsSmis):
def __call__(
self,
omics_seq: typing.Sequence[list],
lsmiles: typing.Sequence[typing.Sequence[str]],
lvalues: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
omics_x, smi_tgt, rout = super().__call__(omics_seq, lsmiles, lvalues, seq_len)
out = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
return omics_x, smi_tgt, out


class OmicsSmisListwiseRank(_VAEOmicsAdamr2OmicsSmis):
pass
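
A minimal usage sketch of the new pairwise-rank dataset (illustrative only, not part of the diff; it assumes a moltx tokenizers.MoltxTokenizer instance named smi_tk has already been constructed, and the omics values below are placeholders):

from dooc import datasets

ds = datasets.OmicsSmisPairwiseRank(smi_tk)  # smi_tk: pre-built MoltxTokenizer (assumed)

# two omics types, a batch of two samples, four features each (placeholder numbers)
omics_seq = [
    [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],  # omics type 1
    [[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0]],  # omics type 2
]
lsmiles = [["CCO", "c1ccccc1"], ["CCN", "CCC"]]    # two candidate SMILES per sample
lvalues = [[0.9, 0.1], [0.2, 0.8]]                 # one score per candidate

omics_x, smi_tgt, out = ds(omics_seq, lsmiles, lvalues, seq_len=200)
# omics_x: one [2, 4] tensor per omics type; smi_tgt: [2, 2, L] with L the tokenized length;
# out: [2] binary labels, 1.0 where the first candidate's value exceeds the second's
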
116 changes: 116 additions & 0 deletions dooc/loss.py
@@ -1,5 +1,7 @@
import typing
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence


class ListNetLoss(nn.Module):
@@ -11,3 +13,117 @@ def __init__(self, reduction: str = 'mean') -> None:
def forward(self, predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
out = - (target.softmax(dim=-1) * predict.log_softmax(dim=-1))
return getattr(out, self.reduction)()


class VAEOmicsLoss(nn.Module):

def __init__(self, loss_type: str, omics_num: int) -> None:
super().__init__()
self.loss_type = loss_type
self.k = omics_num
self.kl_loss_weight = 0.1  # TODO: to be decided

def forward(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
return getattr(self, f"_forward_{self.loss_type}")(x, out_x, **kwargs)

def _forward_generate(self, x: typing.Sequence, out_x: typing.Sequence, labels: torch.Tensor, **kwargs) -> typing.Sequence:
# out_encoder, out_self, out_cross, out_dsc, out_cl = out_x
out_encoder, out_self, out_cross, out_dsc = out_x
self_loss = self._calc_self_vae_loss(x, out_self)
cross_loss, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
cross_infer_loss = self._calc_cross_infer_loss(out_encoder)
dsc_loss = self._calc_dsc_loss(out_dsc)
# contrastive_loss = self._calc_contrastive_loss(out_cl, labels)
generate_loss = (
self_loss + 0.1 * (cross_loss + cross_infer_loss * cross_infer_loss)
- (dsc_loss + cross_infer_dsc_loss) * 0.01 # + contrastive_loss
)
# return generate_loss, self_loss, cross_loss, cross_infer_loss, dsc_loss
return generate_loss

def _forward_dsc(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
out_encoder, out_cross, out_dsc = out_x
_, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
dsc_loss = self._calc_dsc_loss(out_dsc)
return cross_infer_dsc_loss + dsc_loss

def _calc_self_vae_loss(self, x: typing.Sequence, out_self: typing.Sequence) -> float:
loss = 0.
for i, v in enumerate(out_self):
recon_omics, mu, log_var = v
loss += (self.kl_loss_weight * self._kl_loss(mu, log_var, 1.0) + self.reconstruction_loss(x[i], recon_omics))
return loss

def _calc_cross_vae_loss(self, x: typing.Sequence, out_cross: typing.Sequence, out_encoder: typing.Sequence) -> typing.Sequence:
batch_size = x[0].size(0)
device = x[0].device
cross_elbo, cross_infer_loss, cross_kl_loss, cross_dsc_loss = 0, 0, 0, 0
for i, v in enumerate(out_cross):
_, real_mu, real_log_var = out_encoder[i][i]
reconstruct_omic, poe_mu, poe_log_var, pred_real_modal, pred_infer_modal = v
cross_elbo += (
self.kl_loss_weight * self._kl_loss(poe_mu, poe_log_var, 1.0)
+ self.reconstruction_loss(x[i], reconstruct_omic)
)
cross_infer_loss += self.reconstruction_loss(real_mu, poe_mu)
cross_kl_loss += self._kl_divergence(poe_mu, real_mu, poe_log_var, real_log_var)

real_modal = torch.tensor([1 for _ in range(batch_size)]).to(device)
infer_modal = torch.tensor([0 for _ in range(batch_size)]).to(device)
cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_real_modal, real_modal)
cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_infer_modal, infer_modal)

cross_dsc_loss = cross_dsc_loss.sum(0) / (len(out_cross) * batch_size)
return cross_elbo + cross_infer_loss + self.kl_loss_weight * cross_kl_loss, cross_dsc_loss

def _calc_cross_infer_loss(self, out_encoder: typing.Sequence) -> float:
infer_loss = 0
for i in range(self.k):
_, latent_mu, _ = out_encoder[i][i]
for j in range(self.k):
if i == j:
continue
_, latent_mu_infer, _ = out_encoder[j][i]
infer_loss += self.reconstruction_loss(latent_mu_infer, latent_mu)
return infer_loss / self.k

def _calc_dsc_loss(self, out_dsc: typing.Sequence) -> float:
dsc_loss = 0
batch_size = out_dsc[0].size(0)
for i in range(self.k):
real_modal = torch.tensor([i for _ in range(batch_size)])
dsc_loss += torch.nn.CrossEntropyLoss()(out_dsc[i], real_modal.to(out_dsc[i].device))
return dsc_loss.sum(0) / (self.k * batch_size)

def _calc_contrastive_loss(self, out_cl: typing.Sequence, labels: torch.Tensor) -> float:
margin = 1.0
distances = torch.cdist(out_cl, out_cl)

labels_matrix = labels.view(-1, 1) == labels.view(1, -1)

positive_pair_distances = distances * labels_matrix.float()
negative_pair_distances = distances * (1 - labels_matrix.float())

positive_loss = positive_pair_distances.sum() / labels_matrix.float().sum()
negative_loss = torch.nn.ReLU()(margin - negative_pair_distances).sum() / (1 - labels_matrix.float()).sum()

return positive_loss + negative_loss

def _kl_loss(self, mu, logvar, beta):
# KL divergence loss
kld_1 = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return beta * kld_1

def _kl_divergence(self, mu1, mu2, log_sigma1, log_sigma2):
p = Normal(mu1, torch.exp(log_sigma1))
q = Normal(mu2, torch.exp(log_sigma2))

# compute the KL loss between the two Gaussians
kl_loss = kl_divergence(p, q).mean()
return kl_loss

def reconstruction_loss(self, recon_x, x):
# batch_size = recon_x.size(0)
mse = nn.MSELoss() # reduction='sum'
recons_loss = mse(recon_x, x) # / batch_size
return recons_loss
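
As an aside (not part of this commit), the closed form in _kl_loss is the standard KL divergence from a diagonal Gaussian N(mu, exp(log_var)) to a standard normal; a small self-contained check against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(8, 16)
log_var = torch.randn(8, 16)

# closed form: -0.5 * sum(1 + log_var - mu^2 - exp(log_var))
closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# reference: element-wise KL(N(mu, sigma) || N(0, 1)) summed, with sigma = exp(log_var / 2)
reference = kl_divergence(Normal(mu, (0.5 * log_var).exp()), Normal(0.0, 1.0)).sum()

assert torch.allclose(closed_form, reference, atol=1e-4)
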
33 changes: 32 additions & 1 deletion dooc/models.py
@@ -1,8 +1,9 @@
import torch
import typing
from moltx import nets as mnets
from moltx import models as mmodels
from dooc import nets as dnets
from dooc.nets import heads, drugcell
from dooc.nets import heads, drugcell, vaeomic


"""
@@ -42,3 +43,33 @@ def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
assert mut_x.dim() == 1 and smi_tgt.dim() == 2
out = self.forward(mut_x, smi_tgt) # [2]
return (out[0] - out[1]).item()


class OmicsSmiReg(dnets.VAEOmicsAdamr2OmicsSmiXattn):

def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(omics_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(omics, omics_x, smi_tgt)) # [b, 1]


class OmicsSmisRank(dnets.VAEOmicsAdamr2OmicsSmisXattn):

def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(omics_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(omics, omics_x, smi_tgt)).squeeze(-1) # [b, n]

def forward_cmp(self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> float:
"""
for infer, no batch dim
"""
assert omics_x[0].dim() == 1 and smi_tgt.dim() == 2
out = self.forward(omics, omics_x, smi_tgt) # [2]
return (out[0] - out[1]).item()
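
For intuition on the two heads' output shapes, a tiny sketch using a plain nn.Linear(d_model, 1) as a stand-in for heads.RegHead (the stand-in and the width are illustrative assumptions, not the actual head):

import torch
from torch import nn

d_model = 768                              # arbitrary width for the demo
reg = nn.Linear(d_model, 1)                # stand-in for heads.RegHead (illustration only)

fused_single = torch.randn(4, d_model)     # OmicsSmiReg input: [b, d_model] from the xattn net
print(reg(fused_single).shape)             # torch.Size([4, 1])

fused_list = torch.randn(4, 2, d_model)    # OmicsSmisRank input: [b, n, d_model]
print(reg(fused_list).squeeze(-1).shape)   # torch.Size([4, 2])
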
86 changes: 86 additions & 0 deletions dooc/nets/__init__.py
@@ -1,7 +1,9 @@
import torch
import typing
from torch import nn
from moltx.models import AdaMR, AdaMR2
from dooc.nets.drugcell import Drugcell
from dooc.nets.vaeomic import VAEOmics


"""
@@ -209,3 +211,87 @@ def forward(
mut_out = self._forward_mut(mut_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]


class _VAEOmicsAdamr2(nn.Module):

def __init__(self, omics_conf, smi_conf) -> None:
super().__init__()
self.omics_conf = omics_conf
self.smi_conf = smi_conf

self.omics_encoder = VAEOmics(omics_conf)
self.smi_encoder = AdaMR2(smi_conf)

def load_ckpt(self, *ckpt_files: str) -> None:
self.load_state_dict(
torch.load(ckpt_files[0], map_location=torch.device("cpu"))
)

def load_pretrained_ckpt(self, omics_ckpt: str, smi_ckpt: str, freeze_omics: bool = False, freeze_smi: bool = False) -> None:
self.omics_encoder.load_ckpt(omics_ckpt)
self.smi_encoder.load_ckpt(smi_ckpt)
if freeze_smi:
self.smi_encoder.requires_grad_(False)
if freeze_omics:
self.omics_encoder.requires_grad_(False)


class VAEOmicsAdamr2OmicsSmiXattn(_VAEOmicsAdamr2):
def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(omics_conf, smi_conf)
d_model = self.smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
"""
omics_x: sequence of [b, omics_seqlen] tensors, one per omics type
smi_tgt: [b, smi_seqlen]
"""
omics_out = self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2) # [b, 1, dmodel]
smi_out = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2) # [b, 1, dmodel]
return self.cross_attn(smi_out, omics_out).squeeze(-2) # [b, dmodel]


class VAEOmicsAdamr2OmicsSmisAdd(_VAEOmicsAdamr2):
def _forward_omics(self, omics: dict, omics_x: typing.Sequence[torch.Tensor]) -> torch.Tensor:
"""
omics_x: sequence of [b, omics_seqlen] tensors, one per omics type
out: [b, 1, dmodel]
"""
return self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2)

def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor:
"""
smi_tgt: [b, n, smi_seqlen]
out: [b, n, dmodel]
"""
batched = smi_tgt.dim() == 3
if batched:
n = smi_tgt.shape[1]
smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1])
out = self.smi_encoder.forward_feature(smi_tgt)
return out.reshape(-1, n, out.shape[-1])
return self.smi_encoder.forward_feature(smi_tgt)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
omics_out = self._forward_omics(omics, omics_x)
smi_out = self._forward_smi(smi_tgt)
return smi_out + omics_out # [b, n, dmodel]


class VAEOmicsAdamr2OmicsSmisXattn(VAEOmicsAdamr2OmicsSmisAdd):
def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(omics_conf, smi_conf)
d_model = smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
omics_out = self._forward_omics(omics, omics_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, omics_out) # [b, n, dmodel]
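
For reference, a self-contained sketch of the cross-attention fusion pattern used by the Xattn classes above (random tensors stand in for the encoder outputs; the real d_model comes from smi_conf):

import torch
from torch import nn

b, n, d_model, nhead, num_layers = 4, 2, 256, 2, 2
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
cross_attn = nn.TransformerDecoder(layer, num_layers)

# OmicsSmiXattn case: one SMILES feature queries one fused omics feature
smi_out = torch.randn(b, 1, d_model)                     # query
omics_out = torch.randn(b, 1, d_model)                   # memory
print(cross_attn(smi_out, omics_out).squeeze(-2).shape)  # torch.Size([4, 256])

# OmicsSmisXattn case: n candidate SMILES features query the same omics feature
smis_out = torch.randn(b, n, d_model)
print(cross_attn(smis_out, omics_out).shape)             # torch.Size([4, 2, 256])
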
(the remaining 2 changed files are not shown here)
