Skip to content

zzzseeu/DOoC

 
 

Repository files navigation

DOoC

Train

import random
import torch
from torch import nn
import torch.optim as optim

from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig

from dooc import models, datasets

Regression

# Regression datasets
# Build the SMILES tokenizer from the `Prediction` tokenizer configuration.
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
    )
# Regression dataset helper; it is handed the tokenizer for the SMILES input.
ds = datasets.MutSmiReg(smi_tokenizer=tokenizer)
# One SMILES string per sample (two samples in this example).
smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"]
# A single random 0/1 vector of length 3008 is generated once; `* 2` repeats a
# reference to that SAME inner list, so both samples share identical values.
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# mutations contains 0/1 encoding information of the genome
# One target value per sample.
values = [0.85, 0.78]
# Calling the dataset yields three objects: the first two are fed to the model
# and the third is used as the loss target in the training snippet.
mut_x, smi_tgt, out = ds(mutations, smiles, values)

# Regression train
# Create the regression model and load the two checkpoint files into it.
model = models.MutSmiReg()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt',
)

# Adam over all model parameters, paired with a mean-squared-error loss.
optimizer = optim.Adam(model.parameters(), lr=1e-04, foreach=False)
mse_loss = nn.MSELoss()

# One optimisation step: clear gradients, forward, loss, backward, update.
optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = mse_loss(pred, out)
loss.backward()
optimizer.step()

# Persist the trained weights.
torch.save(model.state_dict(), '/path/to/mutsmireg.ckpt')

Pairwise

# Pairwise datasets
# Build the SMILES tokenizer from the `Prediction` tokenizer configuration.
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
)
# Pairwise-ranking dataset helper; it is handed the tokenizer for the SMILES input.
ds = datasets.MutSmisPairwiseRank(smi_tokenizer=tokenizer)
# Each sample is a pair of SMILES strings. The last one contains backslashes
# that are part of the SMILES text, so it is written as a raw string: in a
# normal literal `\C` and `\c` are invalid escape sequences and trigger a
# warning on recent Python versions. The resulting string value is unchanged.
smiles = [
    ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"],
    [
        "CC[N+](C)(C)Cc1ccccc1Br",
        r"CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1",
    ],
]
# A single random 0/1 vector of length 3008 is generated once; `* 2` repeats a
# reference to that same inner list for both samples.
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# mutations contains 0/1 encoding information of the genome
# One pair of target values per sample (the same pair is reused for both).
values = [[0.85, 0.78]] * 2
# The first two results are fed to the model and the third is the loss target
# in the training snippet.
mut_x, smi_tgt, out = ds(mutations, smiles, values)

# Pairwise train
# Create the ranking model and load the two checkpoint files into it.
model = models.MutSmisRank()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt',
)

# Pairwise loss
# Binary cross-entropy on logits, optimised with Adam over all parameters.
loss_func = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-04, foreach=False)

# One optimisation step. The logit passed to the loss is the difference
# between column 0 and column 1 of the model output.
optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = loss_func(pred[:, 0] - pred[:, 1], out)
loss.backward()
optimizer.step()

# Persist the trained weights.
torch.save(model.state_dict(), '/path/to/mutsmipairwise.ckpt')

Listwise

# Listwise datasets
# Build the SMILES tokenizer from the `Prediction` tokenizer configuration.
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
)
# Listwise-ranking dataset helper; it is handed the tokenizer for the SMILES input.
ds = datasets.MutSmisListwiseRank(smi_tokenizer=tokenizer)
# Each sample is a list of three SMILES strings; `* 2` repeats the same list
# for both samples. The last SMILES contains backslashes that are part of the
# text, so it is written as a raw string: in a normal literal `\C` and `\c`
# are invalid escape sequences and trigger a warning on recent Python
# versions. The resulting string value is unchanged.
smiles = [
    [
        "c1cccc1c",
        "CC[N+](C)(C)Cc1ccccc1Br",
        r"CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1",
    ]
] * 2
# A single random 0/1 vector of length 3008 is generated once and shared by
# both samples.
mutations = [[random.choice([0, 1]) for _ in range(3008)]] * 2
# mutations contains 0/1 encoding information of the genome
# One list of three target values per sample (the same list is reused).
values = [[0.85, 0.78, 0.79]] * 2
mut_x, smi_tgt, out = ds(mutations, smiles, values)
# Drop dimension 0 from each result when that dimension has size 1.
mut_x, smi_tgt, out = mut_x.squeeze(0), smi_tgt.squeeze(0), out.squeeze(0)

# Listwise train
# Create the ranking model and load the two checkpoint files into it.
model = models.MutSmisRank()
model.load_pretrained_ckpt(
    mut_ckpt='path/to/drugcell.pt',
    smi_ckpt='path/to/moltx.ckpt'
    )

# Listwise loss
# NOTE(review): `dooc_list_loss` is not imported or defined in any of the
# snippets shown here, so this line raises NameError as written. An import
# providing `ListNetLoss` is needed - confirm the correct module path.
loss_func = dooc_list_loss.ListNetLoss()
# Adam over all model parameters.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-04,
    foreach=False
    )

# One optimisation step: clear gradients, forward, loss, backward, update.
optimizer.zero_grad()
pred = model(mut_x=mut_x, smi_tgt=smi_tgt)
loss = loss_func(pred, out)
loss.backward()
optimizer.step()

# Persist the trained weights.
torch.save(model.state_dict(), '/path/to/mutsmilistwise.ckpt')

Inference

import random
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import pipelines, models

# Regression
# Tokenizer built from the `Prediction` tokenizer configuration.
tokenizer = tkz.MoltxTokenizer.from_pretrain(conf=AdaMRTokenizerConfig.Prediction)

# Regression model restored from the checkpoint written by the training snippet.
model = models.MutSmiReg()
model.load_ckpt('/path/to/mutsmireg.ckpt')

# The pipeline is constructed from the tokenizer and the model.
pipeline = pipelines.MutSmiReg(smi_tokenizer=tokenizer, model=model)

# A single sample: one random 0/1 vector of length 3008 and one SMILES string.
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mut=mutations, smi=smiles)  # e.g. 0.85

# Rank
# Tokenizer built from the `Prediction` tokenizer configuration.
tokenizer = tkz.MoltxTokenizer.from_pretrain(
    conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmisRank()
# NOTE(review): placeholder path; the training snippets save the ranking model
# as 'mutsmipairwise.ckpt' or 'mutsmilistwise.ckpt' - point this at whichever
# checkpoint was actually produced.
model.load_ckpt('/path/to/mutsmirank.ckpt')
# NOTE(review): the pipeline class is `MutSmiXAttention` while the model class
# is `MutSmisRank` - confirm this is the intended pipeline for ranking.
pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, model=model)
# One random 0/1 vector of length 3008 and the candidate SMILES strings.
mutations = [random.choice([0, 1]) for _ in range(3008)]
# The last SMILES contains backslashes that are part of the text, so it is a
# raw string: in a normal literal `\C` and `\c` are invalid escape sequences
# and trigger a warning on recent Python versions. The value is unchanged.
smiles = [
    "c1cccc1c",
    "CC[N+](C)(C)Cc1ccccc1Br",
    r"CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1",
]
# e.g. ["CN(Cc1oc2ccccc2c1C)C(=O)\C=C\c1cnc2NC(=O)CCc2c1", "CC[N+](C)(C)Cc1ccccc1Br", "c1cccc1c"]
predict = pipeline(mut=mutations, smis=smiles)

About

Digital Organoid On Chips

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 100.0%