Skip to content

Commit

Permalink
feat: multi omics
Browse files Browse the repository at this point in the history
  • Loading branch information
origyZ committed Jul 19, 2024
1 parent 3d96167 commit 804e539
Show file tree
Hide file tree
Showing 8 changed files with 540 additions and 6 deletions.
109 changes: 107 additions & 2 deletions dooc/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from moltx import tokenizers, datasets


class _SmiMutBase:
class _SmiBase:
def __init__(self, smi_tokenizer: tokenizers.MoltxTokenizer, device: torch.device = torch.device("cpu")) -> None:
self.smi_ds = datasets.Base(smi_tokenizer, device)
self.device = device
Expand All @@ -22,7 +22,7 @@ def _smi_tokenize(self, smiles: typing.Sequence[str], seq_len: int = None) -> to
"""


class _DrugcellAdamrBase(_SmiMutBase):
class _DrugcellAdamrBase(_SmiBase):
"""Base datasets, convert smiles and genes to torch.Tensor."""

def __init__(
Expand Down Expand Up @@ -141,6 +141,87 @@ def __call__(
return mut_x, smi_tgt, out


class _MultiOmicsAdamr2Base(_SmiBase):
    """Base datasets, convert smiles and multi omics to torch.Tensor."""

    def __init__(
        self,
        smi_tokenizer: tokenizers.MoltxTokenizer,
        device: torch.device = torch.device("cpu")
    ) -> None:
        super().__init__(smi_tokenizer, device)
        self.smi_tokenizer = smi_tokenizer

    def _smi_tokens(
        self,
        smiles: typing.Sequence[str],
        seq_len: int = 200,
    ) -> torch.Tensor:
        """Wrap each SMILES with BOS/EOS markers and tokenize to one tensor."""
        tgt = self._smi_tokenize(
            [f"{self.smi_tokenizer.BOS}{smi}{self.smi_tokenizer.EOS}" for smi in smiles], seq_len)
        return tgt

    def _multi_omics_tokens(
        self, muts: typing.Sequence[list], rnas: typing.Sequence[list], pathways: typing.Sequence[list]
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert the three omics feature lists to tensors on the configured device.

        Fix: the original annotation declared a single ``torch.Tensor`` while
        the method returns a 3-tuple of tensors.
        """
        mut_x = torch.tensor(muts, device=self.device)
        rna_x = torch.tensor(rnas, device=self.device)
        pathway_x = torch.tensor(pathways, device=self.device)
        return mut_x, rna_x, pathway_x

    def _out(self, values: typing.Sequence[float]) -> torch.Tensor:
        """Convert regression/ranking target values to a tensor on the device."""
        return torch.tensor(values, device=self.device)


class _MultiOmicsAdamr2MutSmi(_MultiOmicsAdamr2Base):
    """Dataset: one SMILES per sample, paired with mutation/RNA/pathway features."""

    def __call__(
        self,
        muts: typing.Sequence[list],
        rnas: typing.Sequence[list],
        pathways: typing.Sequence[list],
        smis: typing.Sequence[str],
        vals: typing.Sequence[float],
        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor]:
        # All five sequences must describe the same batch of samples.
        n_samples = len(vals)
        assert (
            len(smis) == n_samples and len(muts) == n_samples and
            len(rnas) == n_samples and len(pathways) == n_samples
        )
        mut_x, rna_x, pathway_x = self._multi_omics_tokens(muts, rnas, pathways)
        smi_tgt = self._smi_tokens(smis, seq_len)
        out = self._out(vals).unsqueeze(-1)  # [b] -> [b, 1]
        return mut_x, rna_x, pathway_x, smi_tgt, out


class _MultiOmicsAdamr2MutSmis(_MultiOmicsAdamr2Base):
    """Dataset: a list of SMILES candidates per sample, for ranking objectives."""

    def __call__(
        self,
        muts: typing.Sequence[list],
        rnas: typing.Sequence[list],
        pathways: typing.Sequence[list],
        lsmis: typing.Sequence[typing.Sequence[str]],
        lvals: typing.Sequence[typing.Sequence[float]],
        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor]:
        """
        muts: [mut1, mut2, ...] mut1: [gene1, gene2, ...]
        rnas: [rna1, rna2, ...] rna1: [gene1, gene2, ...]
        pathways: [pathway1, pathway2, ...] pathway1: [gene1, gene2, ...]
        lsmis: [[smi11, smi12], [smi21, smi22], ...]
        lvals: [[val11, val12], [val21, val22], ...]
        """
        batchlen = len(lvals)
        assert (
            len(lsmis) == batchlen and len(muts) == batchlen and
            len(rnas) == batchlen and len(pathways) == batchlen
        )
        mut_x, rna_x, pathway_x = self._multi_omics_tokens(muts, rnas, pathways)
        # Flatten [b, n] smiles lists to tokenize in one pass, then restore
        # the candidate dimension afterwards.
        listlen = len(lsmis[0])  # assumes equal candidate count per sample
        flat_smiles = [smi for group in lsmis for smi in group]
        smi_tgt = self._smi_tokens(flat_smiles, seq_len)
        smi_tgt = smi_tgt.reshape(batchlen, listlen, smi_tgt.size(-1))
        out = self._out(lvals)
        return mut_x, rna_x, pathway_x, smi_tgt, out


"""
Mutations(Individual Sample) and Smiles Interaction
Expand Down Expand Up @@ -170,3 +251,27 @@ def __call__(

class MutSmisListwiseRank(_DrugcellAdamr2MutSmis):
    """Listwise-ranking dataset over mutation + SMILES-list samples; inherits all behavior."""


class MultiOmicsSmisListwiseRank(_MultiOmicsAdamr2MutSmis):
    """Listwise-ranking dataset over multi-omics + SMILES-list samples; inherits all behavior."""


class MultiOmicsSmiReg(_MultiOmicsAdamr2MutSmi):
    """Regression dataset over multi-omics + single-SMILES samples; inherits all behavior."""


class MultiOmicsSmisPairwiseRank(_MultiOmicsAdamr2MutSmis):
    """Pairwise-ranking dataset: converts per-pair target values into a binary label."""

    def __call__(
        self,
        muts: typing.Sequence[list],
        rnas: typing.Sequence[list],
        pathways: typing.Sequence[list],
        lsmiles: typing.Sequence[typing.Sequence[str]],
        lvalues: typing.Sequence[typing.Sequence[float]],
        seq_len: int = 200
    ) -> typing.Tuple[torch.Tensor]:
        # NOTE(review): only columns 0 and 1 of the values are compared —
        # presumably each inner list holds exactly two candidates; confirm.
        mut_x, rna_x, pathway_x, smi_tgt, rout = super().__call__(
            muts, rnas, pathways, lsmiles, lvalues, seq_len)
        # Label is 1.0 when the first candidate's value beats the second, else 0.0.
        label = torch.zeros(rout.size(0), dtype=rout.dtype, device=self.device)
        label[(rout[:, 0] - rout[:, 1]) > 0.0] = 1.0
        return mut_x, rna_x, pathway_x, smi_tgt, label
32 changes: 31 additions & 1 deletion dooc/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from moltx import nets as mnets
from moltx import models as mmodels
from dooc import nets as dnets
from dooc.nets import heads, drugcell
from dooc.nets import heads, drugcell, multiomics


"""
Expand Down Expand Up @@ -42,3 +42,33 @@ def forward_cmp(self, mut_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
assert mut_x.dim() == 1 and smi_tgt.dim() == 2
out = self.forward(mut_x, smi_tgt) # [2]
return (out[0] - out[1]).item()


class MultiOmicsSmiReg(dnets.MultiOmicsAdamr2MutSmiXattn):
    """Regression model: multi-omics + one SMILES -> scalar prediction per sample."""

    def __init__(self, multi_omics_conf: multiomics.MultiOmicsConfig = dnets.MultiOmicsEncoder.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
        super().__init__(multi_omics_conf, smi_conf)
        # Regression head projects the fused feature down to a single value.
        self.reg = heads.RegHead(self.smi_conf.d_model)

    def forward(
            self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
        fused = super().forward(mut_x, rna_x, pathway_x, smi_tgt)
        return self.reg(fused)  # [b, 1]


class MultiOmicsSmisRank(dnets.MultiOmicsAdamr2MutSmisXattn):
    """Ranking model: multi-omics + a list of SMILES -> one score per candidate."""

    def __init__(self, multi_omics_conf: multiomics.MultiOmicsConfig = dnets.MultiOmicsEncoder.DEFAULT_CONFIG, smi_conf: mnets.AbsPosEncoderCausalConfig = mmodels.AdaMR2.CONFIG_LARGE) -> None:
        super().__init__(multi_omics_conf, smi_conf)
        # Shared scoring head applied to every candidate's fused feature.
        self.reg = heads.RegHead(self.smi_conf.d_model)

    def forward(
            self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor, smi_tgt: torch.Tensor) -> torch.Tensor:
        fused = super().forward(mut_x, rna_x, pathway_x, smi_tgt)
        return self.reg(fused).squeeze(-1)  # [b, n]

    def forward_cmp(self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor, smi_tgt: torch.Tensor) -> float:
        """
        for infer, no batch dim
        """
        assert mut_x.dim() == 1 and rna_x.dim() == 1 and pathway_x.dim() == 1 and smi_tgt.dim() == 2
        scores = self.forward(mut_x, rna_x, pathway_x, smi_tgt)  # [2]
        return (scores[0] - scores[1]).item()
81 changes: 81 additions & 0 deletions dooc/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch import nn
from moltx.models import AdaMR, AdaMR2
from dooc.nets.drugcell import Drugcell
from dooc.nets.multiomics import MultiOmicsEncoder


"""
Expand Down Expand Up @@ -209,3 +210,83 @@ def forward(
mut_out = self._forward_mut(mut_x)
smi_out = self._forward_smi(smi_tgt)
return self.cross_attn(smi_out, mut_out) # [b, n, dmodel]


class _MultiOmicsAdamr2(nn.Module):
    """Pairs a multi-omics encoder with an AdaMR2 SMILES encoder."""

    def __init__(self, multi_omics_conf, smi_conf) -> None:
        super().__init__()
        self.multi_omics_conf = multi_omics_conf
        self.smi_conf = smi_conf
        self.multi_omics_encoder = MultiOmicsEncoder(multi_omics_conf)
        self.smi_encoder = AdaMR2(smi_conf)

    def load_ckpt(self, *ckpt_files: str) -> None:
        """Load a whole-model checkpoint; only the first file is used."""
        state = torch.load(ckpt_files[0], map_location=torch.device("cpu"))
        self.load_state_dict(state)

    def load_pretrained_ckpt(self, mut_ckpt: str, smi_ckpt: str, freeze_mut: bool = False, freeze_smi: bool = False) -> None:
        """Load separately pretrained encoder checkpoints, optionally freezing each encoder."""
        self.smi_encoder.load_ckpt(smi_ckpt)
        if freeze_smi:
            self.smi_encoder.requires_grad_(False)
        self.multi_omics_encoder.load_pretrained_ckpt(mut_ckpt, freeze_mut)


class MultiOmicsAdamr2MutSmiXattn(_MultiOmicsAdamr2):
    """Fuses one SMILES per sample with multi-omics features via cross attention."""

    def __init__(self, multi_omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
        super().__init__(multi_omics_conf, smi_conf)
        decoder_layer = nn.TransformerDecoderLayer(
            self.smi_conf.d_model, nhead, batch_first=True)
        self.cross_attn = nn.TransformerDecoder(decoder_layer, num_layers)

    def forward(
        self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        """
        mut_x: [b, mut_seqlen]
        rna_x: [b, rna_seqlen]
        pathway_x: [b, pathway_seqlen]
        smi_tgt: [b, smi_seqlen]
        """
        # Insert a length-1 sequence dim so both features attend as sequences.
        omics_feat = self.multi_omics_encoder(mut_x, rna_x, pathway_x).unsqueeze(-2)  # [b, 1, dmodel]
        smi_feat = self.smi_encoder.forward_feature(smi_tgt).unsqueeze(-2)  # [b, 1, dmodel]
        fused = self.cross_attn(smi_feat, omics_feat)  # [b, 1, dmodel]
        return fused.squeeze(-2)  # [b, dmodel]


class MultiOmicsAdamr2MutSmisXattn(_MultiOmicsAdamr2):
    """Fuses a list of SMILES candidates per sample with multi-omics features."""

    def __init__(self, multi_omics_conf, smi_conf, nhead: int = 2, num_layers: int = 2) -> None:
        super().__init__(multi_omics_conf, smi_conf)
        decoder_layer = nn.TransformerDecoderLayer(
            smi_conf.d_model, nhead, batch_first=True)
        self.cross_attn = nn.TransformerDecoder(decoder_layer, num_layers)

    def _forward_multi_omics(self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor) -> torch.Tensor:
        """
        mut_x: [b, mut_seqlen]
        rna_x: [b, rna_seqlen]
        pathway_x: [b, pathway_seqlen]
        out: [b, 1, dmodel]
        """
        return self.multi_omics_encoder(mut_x, rna_x, pathway_x).unsqueeze(-2)

    def _forward_smi(self, smi_tgt: torch.Tensor) -> torch.Tensor:
        """
        smi_tgt: [b, n, smi_seqlen]
        out: [b, n, dmodel]
        """
        if smi_tgt.dim() != 3:
            # Unbatched candidate list: encode directly.
            return self.smi_encoder.forward_feature(smi_tgt)
        # Fold the candidate dim into the batch dim for one encoder pass,
        # then restore it on the way out.
        n_candidates = smi_tgt.shape[1]
        flat = smi_tgt.reshape(-1, smi_tgt.shape[-1])
        feat = self.smi_encoder.forward_feature(flat)
        return feat.reshape(-1, n_candidates, feat.shape[-1])

    def forward(
        self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor, smi_tgt: torch.Tensor
    ) -> torch.Tensor:
        omics_feat = self._forward_multi_omics(mut_x, rna_x, pathway_x)  # [b, 1, dmodel]
        smi_feat = self._forward_smi(smi_tgt)
        return self.cross_attn(smi_feat, omics_feat)  # [b, n, dmodel]
51 changes: 51 additions & 0 deletions dooc/nets/multiomics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from dataclasses import dataclass
from dooc.nets.drugcell import Drugcell, DrugcellConfig


@dataclass
class MultiOmicsConfig:
    """Configuration for MultiOmicsEncoder."""

    # Output feature size of the fused multi-omics embedding (out_fc's out dim).
    d_model: int
    # Mutation feature length — presumably matches the Drugcell input dim;
    # not read directly by MultiOmicsEncoder. TODO confirm against Drugcell.
    mut_dim: int
    # RNA feature length; contributes to the concatenated hidden dim.
    rna_dim: int
    # Pathway feature length; contributes to the concatenated hidden dim.
    pathway_dim: int
    # Config for the inner Drugcell encoder that processes mutations.
    drug_cell_conf: DrugcellConfig


class MultiOmicsEncoder(nn.Module):
    """Encodes mutation, RNA and pathway features into a single d_model vector."""

    DEFAULT_CONFIG = MultiOmicsConfig(
        d_model=768,
        mut_dim=3008,
        rna_dim=5537,
        pathway_dim=3793,
        drug_cell_conf=Drugcell.DEFAULT_CONFIG,
    )

    def __init__(self, conf: MultiOmicsConfig = DEFAULT_CONFIG) -> None:
        super().__init__()
        self.conf = conf
        self.mut_encoder = Drugcell(self.conf.drug_cell_conf)
        # RNA and pathway features bypass the mutation encoder and are
        # concatenated with its output before the final projection.
        fused_dim = (
            self.conf.rna_dim
            + self.conf.drug_cell_conf.d_model
            + self.conf.pathway_dim
        )
        self.out_fc = nn.Linear(fused_dim, self.conf.d_model)

    def forward(
        self, mut_x: torch.Tensor, rna_x: torch.Tensor, pathway_x: torch.Tensor
    ) -> torch.Tensor:
        # Accept unbatched (1-D) inputs by adding a temporary batch dim.
        unbatched = mut_x.dim() == 1
        if mut_x.dim() == 1:
            mut_x = mut_x.unsqueeze(0)
        if rna_x.dim() == 1:
            rna_x = rna_x.unsqueeze(0)
        if pathway_x.dim() == 1:
            pathway_x = pathway_x.unsqueeze(0)
        mut_feat = self.mut_encoder(mut_x)
        fused = torch.concat((mut_feat, rna_x, pathway_x), dim=1)
        projected = self.out_fc(fused)
        # NOTE(review): the final squeeze keys only on mut_x's original rank —
        # assumes all three inputs are batched/unbatched consistently; confirm.
        return projected.squeeze(0) if unbatched else projected

    def load_pretrained_ckpt(self, mut_ckpt: str, freeze_mut: bool = False) -> None:
        """Load the Drugcell mutation-encoder checkpoint, optionally freezing it."""
        self.mut_encoder.load_ckpt(mut_ckpt)
        if freeze_mut:
            self.mut_encoder.requires_grad_(False)
Loading

0 comments on commit 804e539

Please sign in to comment.