From 91911baf6e0ed2b60c50dbeadc75a33afdebc609 Mon Sep 17 00:00:00 2001
From: zzzseeu
Date: Fri, 31 May 2024 02:22:43 +0000
Subject: [PATCH] fix: rewrite readme

---
 README.md | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++----------
 1 file changed, 90 insertions(+), 19 deletions(-)

diff --git a/README.md b/README.md
index 8d18e68..9e0a882 100644
--- a/README.md
+++ b/README.md
@@ -8,30 +8,74 @@
 
 ```python
 # Regression train
-from moltx import tokenizers
-from dooc import models, datasets, nets
+import random
+import torch
+from torch import nn
+import torch.optim as optim
+from moltx import tokenizers as tkz
+from moltx.models import AdaMRTokenizerConfig
 
-tk = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Prediction)
-ds = datasets.MutSmiXAttention(tokenizer=tk, device=torch.device('cpu'))
+from dooc import models, datasets
+
+
+# Build the regression dataset
+tokenizer = tkz.MoltxTokenizer.from_pretrain(
+    conf=AdaMRTokenizerConfig.Prediction
+    )
+ds = datasets.MutSmi(tokenizer)
 smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"]
-mutations = [[1, 0, 0, ...], [1, 0, 1, ...]]
-# e.g.
-# import random
-# [random.choice([0, 1]) for _ in range(3008)]
+# each mutation vector is a 0/1 encoding of the genome (3008 entries)
+mutations = [[random.choice([0, 1]) for _ in range(3008)],
+             [random.choice([0, 1]) for _ in range(3008)]]
 values = [0.85, 0.78]
 smiles_src, smiles_tgt, mutations_src, out = ds(smiles, mutations, values)
 
-model = models.MutSmiXAttention()
+# MutSmiFullConnection train
+model = models.MutSmiFullConnection()
 model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
+mse_loss = nn.MSELoss()
+optimizer = optim.Adam(model.parameters(),
+                       lr=1e-04,
+                       foreach=False
+                       )
+
+optimizer.zero_grad()
+pred = model(smiles_src, smiles_tgt, mutations_src)
+loss = mse_loss(pred, out)
+loss.backward()
+optimizer.step()
+
+torch.save(model.state_dict(), '/path/to/mutsmifullconnection.ckpt')
 
-crt = nn.MSELoss()
+# MutSmiXAttention train
+model = models.MutSmiXAttention()
+model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
+mse_loss = nn.MSELoss()
+optimizer = optim.Adam(model.parameters(),
+                       lr=1e-04,
+                       foreach=False
+                       )
 
-optim.zero_grad()
+optimizer.zero_grad()
 pred = model(smiles_src, smiles_tgt, mutations_src)
-loss = crt(pred, out)
+loss = mse_loss(pred, out)
 loss.backward()
-optim.step()
+optimizer.step()
 
 torch.save(model.state_dict(), '/path/to/mutsmixattention.ckpt')
 ```
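+
+The snippets above run a single optimization step to keep the example short;
+a real training run would loop over epochs (and mini-batches). A minimal
+sketch, reusing the names defined above with an illustrative `n_epochs`:
+
+```python
+n_epochs = 10  # illustrative value, not a project default
+for _ in range(n_epochs):
+    optimizer.zero_grad()
+    pred = model(smiles_src, smiles_tgt, mutations_src)
+    loss = mse_loss(pred, out)
+    loss.backward()
+    optimizer.step()
+```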
@@ -39,13 +83,40 @@ torch.save(model.state_dict(), '/path/to/mutsmixattention.ckpt')
 
 ### Inference
 
 ```python
+import random
+from moltx import tokenizers as tkz
+from moltx.models import AdaMRTokenizerConfig
 from dooc import pipelines, models
 
-# dooc
-model = models.MutSmiXAttention()
-model.load_ckpt('/path/to/mutsmixattention.ckpt')
-pipeline = pipelines.MutSmiXAttention()
-pipeline([1, 0, 0, ...], "C=CC=CC=C")
-# 0.85
+# MutSmiFullConnection inference
+tokenizer = tkz.MoltxTokenizer.from_pretrain(
+    conf=AdaMRTokenizerConfig.Prediction
+    )
+model = models.MutSmiFullConnection()
+model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
+pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model)
+mutations = [random.choice([0, 1]) for _ in range(3008)]
+smiles = "CC[N+](C)(C)Cc1ccccc1Br"
+predict = pipeline(mutations, smiles)  # e.g. 0.85
+
+# MutSmiXAttention inference
+tokenizer = tkz.MoltxTokenizer.from_pretrain(
+    conf=AdaMRTokenizerConfig.Prediction
+    )
+model = models.MutSmiXAttention()
+model.load_ckpt('/path/to/mutsmixattention.ckpt')
+pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, model=model)
+mutations = [random.choice([0, 1]) for _ in range(3008)]
+smiles = "CC[N+](C)(C)Cc1ccccc1Br"
+predict = pipeline(mutations, smiles)  # e.g. 0.85
 ```
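+
+To score several candidate drugs against the same mutation profile, call the
+pipeline in a loop (a minimal sketch reusing the `pipeline` built above; the
+candidate SMILES are illustrative):
+
+```python
+candidates = ["CC[N+](C)(C)Cc1ccccc1Br", "c1ccccc1"]
+scores = {smi: pipeline(mutations, smi) for smi in candidates}
+```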