feat: vae omics net #16

Open
wants to merge 1 commit into base: main
101 changes: 99 additions & 2 deletions dooc/datasets.py
@@ -4,7 +4,7 @@
from moltx import tokenizers, datasets


class _SmiMutBase:
class _SmiBase:
def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None:
self.smi_ds = datasets.Base(smi_tokenizer, device)
self.device = device
@@ -22,7 +22,7 @@ def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> to
"""


class _DrugcellAdamrBase(_SmiMutBase):
class _DrugcellAdamrBase(_SmiBase):
"""Base datasets, convert smiles and genes to torch.Tensor."""

def __init__(
@@ -141,6 +141,72 @@ def __call__(
return mut_x, smi_tgt, out


class _VAEOmicsAdamr2Base(_SmiBase):
"""Base datasets, convert smiles and omics to torch.Tensor."""

def __init__(
self,
smi_tokenizer: tokenizers.MoltxTokenizer,
device: torch.device = torch.device("cpu")
) -> None:
super().__init__(smi_tokenizer, device)
self.smi_tokenizer = smi_tokenizer

def _smi_tokens(
self,
smiles: typing.Sequence[str],
seq_len: int = 200,
) -> torch.Tensor:
tgt = self._smi_tokenize(
[f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)
return tgt

def _omics_tokens(self, omics_seq: typing.Sequence[list]) -> typing.Sequence[torch.Tensor]:
return [torch.tensor(omic, device=self.device) for omic in omics_seq]

def _out(self, values: typing.Sequence[float]) -> torch.Tensor:
return torch.tensor(values, device=self.device)


class _VAEOmicsAdamr2OmicsSmi(_VAEOmicsAdamr2Base):
def __call__(
self,
omics_seq: typing.Sequence[list],
smis: typing.Sequence[str],
vals: typing.Sequence[float],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
assert len(smis) == len(vals) and len(omics_seq[0]) == len(vals)
omics_x = self._omics_tokens(omics_seq)
smi_tgt = self._smi_tokens(smis, seq_len)
out = self._out(vals).unsqueeze(-1)
return omics_x, smi_tgt, out


class _VAEOmicsAdamr2OmicsSmis(_VAEOmicsAdamr2Base):
def __call__(
self,
omics_seq: typing.Sequence[list],
lsmis: typing.Sequence[typing.Sequence[str]],
lvals: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
"""
omics_seq: [omic1, omic2, ...] (one list per omics type); omic1: [omic1_1, omic1_2, ...] (batch size); omic1_1: [gene1, gene2, ...]
lsmis: [[smi11, smi12], [smi21, smi22], ...]
lvals: [[val11, val12], [val21, val22], ...]
"""
assert len(lsmis) == len(lvals) and len(omics_seq[0]) == len(lvals)
omics_x = self._omics_tokens(omics_seq)
batchlen = len(lsmis)
listlen = len(lsmis[0])
smiles = [smi for bsmi in lsmis for smi in bsmi]
smi_tgt = self._smi_tokens(smiles, seq_len)
smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
out = self._out(lvals)
return omics_x, smi_tgt, out


"""
Mutations(Individual Sample) and Smiles Interaction

@@ -170,3 +236,34 @@

class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
pass


"""
Omics (Individual Sample) and Smiles Interaction

OmicsSmiReg
OmicsSmis{Pair/List}wiseRank
OmicssSmi{Pair/List}wiseRank
"""


class OmicsSmiReg(_VAEOmicsAdamr2OmicsSmi):
pass


class OmicsSmisPairwiseRank(_VAEOmicsAdamr2OmicsSmis):
def __call__(
self,
omics_seq: typing.Sequence[list],
lsmiles: typing.Sequence[typing.Sequence[str]],
lvalues: typing.Sequence[typing.Sequence[float]],
seq_len: int = 200
) -> typing.Tuple[torch.Tensor]:
omics_x, smi_tgt, rout = super().__call__(omics_seq, lsmiles, lvalues, seq_len)
out = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
out[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
return omics_x, smi_tgt, out


class OmicsSmisListwiseRank(_VAEOmicsAdamr2OmicsSmis):
pass
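
For orientation, a minimal self-contained sketch of the pairwise label construction performed by OmicsSmisPairwiseRank above, on synthetic value pairs (the shapes and values are illustrative assumptions, not taken from this diff):

import torch

# Synthetic scores: 4 samples, 2 candidate SMILES each (mirrors lvals above).
lvals = torch.tensor([[0.9, 0.1],
                      [0.2, 0.7],
                      [0.5, 0.5],
                      [1.3, 0.4]])

# Same rule as OmicsSmisPairwiseRank.__call__: label 1.0 when the first
# candidate scores strictly higher than the second, otherwise 0.0.
out = torch.zeros(lvals.size(0), dtype=lvals.dtype)
out[(lvals[:, 0] - lvals[:, 1]) > 0.0] = 1.0
print(out)  # tensor([1., 0., 0., 1.])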
116 changes: 116 additions & 0 deletions dooc/loss.py
@@ -1,5 +1,7 @@
import typing
import torch
import torch.nn as nn
from torch.distributions import Normal, kl_divergence


class ListNetLoss(nn.Module):
@@ -11,3 +13,117 @@ def __init__(self, reduction: str = 'mean') -> None:
def forward(self, predict: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
out = - (target.softmax(dim=-1) * predict.log_softmax(dim=-1))
return getattr(out, self.reduction)()


class VAEOmicsLoss(nn.Module):

def __init__(self, loss_type: str, omics_num: int) -> None:
super().__init__()
self.loss_type = loss_type
self.k = omics_num
self.kl_loss_weight = 0.1  # TODO: to be determined

def forward(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
return getattr(self, f"_forward_{self.loss_type}")(x, out_x, **kwargs)

def _forward_generate(self, x: typing.Sequence, out_x: typing.Sequence, labels: torch.Tensor, **kwargs) -> typing.Sequence:
# out_encoder, out_self, out_cross, out_dsc, out_cl = out_x
out_encoder, out_self, out_cross, out_dsc = out_x
self_loss = self._calc_self_vae_loss(x, out_self)
cross_loss, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
cross_infer_loss = self._calc_cross_infer_loss(out_encoder)
dsc_loss = self._calc_dsc_loss(out_dsc)
# contrastive_loss = self._calc_contrastive_loss(out_cl, labels)
generate_loss = (
self_loss + 0.1 * (cross_loss + cross_infer_loss * cross_infer_loss)
- (dsc_loss + cross_infer_dsc_loss) * 0.01 # + contrastive_loss
)
# return generate_loss, self_loss, cross_loss, cross_infer_loss, dsc_loss
return generate_loss

def _forward_dsc(self, x: typing.Sequence, out_x: typing.Sequence, **kwargs) -> float:
out_encoder, out_cross, out_dsc = out_x
_, cross_infer_dsc_loss = self._calc_cross_vae_loss(x, out_cross, out_encoder)
dsc_loss = self._calc_dsc_loss(out_dsc)
return cross_infer_dsc_loss + dsc_loss

def _calc_self_vae_loss(self, x: typing.Sequence, out_self: typing.Sequence) -> float:
loss = 0.
for i, v in enumerate(out_self):
recon_omics, mu, log_var = v
loss += (self.kl_loss_weight * self._kl_loss(mu, log_var, 1.0) + self.reconstruction_loss(x[i], recon_omics))
return loss

def _calc_cross_vae_loss(self, x: typing.Sequence, out_cross: typing.Sequence, out_encoder: typing.Sequence) -> typing.Sequence:
batch_size = x[0].size(0)
device = x[0].device
cross_elbo, cross_infer_loss, cross_kl_loss, cross_dsc_loss = 0, 0, 0, 0
for i, v in enumerate(out_cross):
_, real_mu, real_log_var = out_encoder[i][i]
reconstruct_omic, poe_mu, poe_log_var, pred_real_modal, pred_infer_modal = v
cross_elbo += (
self.kl_loss_weight * self._kl_loss(poe_mu, poe_log_var, 1.0)
+ self.reconstruction_loss(x[i], reconstruct_omic)
)
cross_infer_loss += self.reconstruction_loss(real_mu, poe_mu)
cross_kl_loss += self._kl_divergence(poe_mu, real_mu, poe_log_var, real_log_var)

real_modal = torch.tensor([1 for _ in range(batch_size)]).to(device)
infer_modal = torch.tensor([0 for _ in range(batch_size)]).to(device)
cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_real_modal, real_modal)
cross_dsc_loss += torch.nn.CrossEntropyLoss()(pred_infer_modal, infer_modal)

cross_dsc_loss = cross_dsc_loss.sum(0) / (len(out_cross) * batch_size)
return cross_elbo + cross_infer_loss + self.kl_loss_weight * cross_kl_loss, cross_dsc_loss

def _calc_cross_infer_loss(self, out_encoder: typing.Sequence) -> float:
infer_loss = 0
for i in range(self.k):
_, latent_mu, _ = out_encoder[i][i]
for j in range(self.k):
if i == j:
continue
_, latent_mu_infer, _ = out_encoder[j][i]
infer_loss += self.reconstruction_loss(latent_mu_infer, latent_mu)
return infer_loss / self.k

def _calc_dsc_loss(self, out_dsc: typing.Sequence) -> float:
dsc_loss = 0
batch_size = out_dsc[0].size(0)
for i in range(self.k):
real_modal = torch.tensor([i for _ in range(batch_size)])
dsc_loss += torch.nn.CrossEntropyLoss()(out_dsc[i], real_modal.to(out_dsc[i].device))
return dsc_loss.sum(0) / (self.k * batch_size)

def _calc_contrastive_loss(self, out_cl: typing.Sequence, labels: torch.Tensor) -> float:
margin = 1.0
distances = torch.cdist(out_cl, out_cl)

labels_matrix = labels.view(-1, 1) == labels.view(1, -1)

positive_pair_distances = distances * labels_matrix.float()
negative_pair_distances = distances * (1 - labels_matrix.float())

positive_loss = positive_pair_distances.sum() / labels_matrix.float().sum()
negative_loss = torch.nn.ReLU()(margin - negative_pair_distances).sum() / (1 - labels_matrix.float()).sum()

return positive_loss + negative_loss

def _kl_loss(self, mu, logvar, beta):
# KL divergence loss
kld_1 = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return beta * kld_1

def _kl_divergence(self, mu1, mu2, log_sigma1, log_sigma2):
p = Normal(mu1, torch.exp(log_sigma1))
q = Normal(mu2, torch.exp(log_sigma2))

# compute the KL divergence between the two Normal distributions
kl_loss = kl_divergence(p, q).mean()
return kl_loss

def reconstruction_loss(self, recon_x, x):
# batch_size = recon_x.size(0)
mse = nn.MSELoss() # reduction='sum'
recons_loss = mse(recon_x, x) # / batch_size
return recons_loss
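
As a sanity check, the closed-form term in _kl_loss is the analytic KL divergence between N(mu, exp(log_var)) and a unit Gaussian; a short, self-contained sketch (tensor shapes are illustrative assumptions) confirming it matches torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 8)
log_var = torch.randn(4, 8)

# Closed form used by VAEOmicsLoss._kl_loss (beta = 1.0).
closed_form = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

# Same quantity via torch.distributions, with sigma = exp(0.5 * log_var).
analytic = kl_divergence(
    Normal(mu, torch.exp(0.5 * log_var)),
    Normal(torch.zeros_like(mu), torch.ones_like(mu)),
).sum()

assert torch.allclose(closed_form, analytic, atol=1e-5)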
33 changes: 32 additions & 1 deletion dooc/models.py
@@ -1,8 +1,9 @@
import torch
import typing
from moltx import nets as mnets
from moltx import models as mmodels
from dooc import nets as dnets
from dooc.nets import heads, drugcell
from dooc.nets import heads, drugcell, vaeomic


"""
@@ -42,3 +43,33 @@ def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
assert mut_x.dim() == 1 and smi_tgt.dim() == 2
out = self.forward(mut_x, smi_tgt) # [2]
return (out[0] - out[1]).item()


class OmicsSmiReg(dnets.VAEOmicsAdamr2OmicsSmiXattn):

def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(omics_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(omics, omics_x, smi_tgt)) # [b, 1]


class OmicsSmisRank(dnets.VAEOmicsAdamr2OmicsSmisXattn):

def __init__(self, omics_conf: vaeomic.VAEOmicsConfig = dnets.VAEOmics.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
super().__init__(omics_conf, smi_conf)
self.reg = heads.RegHead(self.smi_conf.d_model)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
return self.reg(super().forward(omics, omics_x, smi_tgt)).squeeze(-1) # [b, n]

def forward_cmp(self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> float:
"""
for infer, no batch dim
"""
assert omics_x[0].dim() == 1 and smi_tgt.dim() == 2
out = self.forward(omics, omics_x, smi_tgt) # [2]
return (out[0] - out[1]).item()
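
A small shape sketch of the two heads above, using nn.Linear as a stand-in for heads.RegHead (an assumption for illustration; only the [..., d_model] -> [..., 1] mapping matters here):

import torch
import torch.nn as nn

d_model, b, n = 16, 4, 3
reg = nn.Linear(d_model, 1)  # stand-in for heads.RegHead(d_model)

# OmicsSmiReg path: one SMILES per sample -> regression output [b, 1].
single = reg(torch.randn(b, d_model))
print(single.shape)   # torch.Size([4, 1])

# OmicsSmisRank path: n candidate SMILES per sample -> ranking scores [b, n].
ranked = reg(torch.randn(b, n, d_model)).squeeze(-1)
print(ranked.shape)   # torch.Size([4, 3])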
86 changes: 86 additions & 0 deletions dooc/nets/__init__.py
@@ -1,7 +1,9 @@
import torch
import typing
from torch import nn
from moltx.models import AdaMR, AdaMR2
from dooc.nets.drugcell import Drugcell
from dooc.nets.vaeomic import VAEOmics


"""
@@ -209,3 +211,87 @@ def forward(
mut_out = self._forward_mut(mut_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]


class _VAEOmicsAdamr2(nn.Module):

def __init__(self, omics_conf, smi_conf) -> None:
super().__init__()
self.omics_conf = omics_conf
self.smi_conf = smi_conf

self.omics_encoder = VAEOmics(omics_conf)
self.smi_encoder = AdaMR2(smi_conf)

def load_ckpt(self, *ckpt_files: str) -> None:
self.load_state_dict(
torch.load(ckpt_files[0], map_location=torch.device("cpu"))
)

def load_pretrained_ckpt(self, omics_ckpt: str, smi_ckpt: str, freeze_omics: bool = False, freeze_smi: bool = False) -> None:
self.omics_encoder.load_ckpt(omics_ckpt)
self.smi_encoder.load_ckpt(smi_ckpt)
if freeze_smi:
self.smi_encoder.requires_grad_(False)
if freeze_omics:
self.omics_encoder.requires_grad_(False)


class VAEOmicsAdamr2OmicsSmiXattn(_VAEOmicsAdamr2):
def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(omics_conf, smi_conf)
d_model = self.smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
"""
omics_x: [b, omics_seqlen]
smi_tgt: [b, smi_seqlen]
"""
omics_out = self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2) # [b, 1, dmodel]
smi_out = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2) # [b, 1, dmodel]
return self.cross_attn(smi_out, omics_out).squeeze(-2) # [b, dmodel]


class VAEOmicsAdamr2OmicsSmisAdd(_VAEOmicsAdamr2):
def _forward_omics(self, omics: dict, omics_x: typing.Sequence[torch.Tensor]) -> torch.Tensor:
"""
omics_x: [b, omics_seqlen]
out: [b, 1, dmodel]
"""
return self.omics_encoder.forward_encoder(omics, *omics_x).unsqueeze(-2)

def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor:
"""
smi_tgt: [b, n, smi_seqlen]
out: [b, n, dmodel]
"""
batched = smi_tgt.dim() == 3
if batched:
n = smi_tgt.shape[1]
smi_tgt = smi_tgt.reshape(-1, smi_tgt.shape[-1])
out = self.smi_encoder.forward_feature(smi_tgt)
return out.reshape(-1, n, out.shape[-1])
return self.smi_encoder.forward_feature(smi_tgt)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
omics_out = self._forward_omics(omics, omics_x)
smi_out = self._forward_smi(smi_tgt)
return smi_out + omics_out # [b, n, dmodel]


class VAEOmicsAdamr2OmicsSmisXattn(VAEOmicsAdamr2OmicsSmisAdd):
def __init__(self, omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
super().__init__(omics_conf, smi_conf)
d_model = smi_conf.d_model
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
self.cross_attn = nn.TransformerDecoder(layer, num_layers)

def forward(
self, omics: dict, omics_x: typing.Sequence[torch.Tensor], smi_tgt: torch.Tensor) -> torch.Tensor:
omics_out = self._forward_omics(omics, omics_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, omics_out) # [b, n, dmodel]
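
For reference, a minimal sketch of the cross-attention wiring used by the Xattn variants above: the SMILES features act as the decoder target (queries) and the fused omics embedding as the memory (keys/values). The d_model, nhead, and batch sizes here are illustrative assumptions; the real model takes d_model from smi_conf.

import torch
import torch.nn as nn

d_model, nhead, num_layers = 32, 2, 2
layer = nn.TransformerDecoderLayer(d_model, nhead, batch_first=True)
cross_attn = nn.TransformerDecoder(layer, num_layers)

smi_out = torch.randn(4, 3, d_model)    # [b, n, dmodel]: n candidate SMILES
omics_out = torch.randn(4, 1, d_model)  # [b, 1, dmodel]: fused omics embedding
print(cross_attn(smi_out, omics_out).shape)  # torch.Size([4, 3, 32])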