Skip to content

Commit

Permalink
fix: rewrite readme
Browse files Browse the repository at this point in the history
  • Loading branch information
zzzseeu committed May 31, 2024
1 parent 21e1bcd commit 91911ba
Showing 1 changed file with 67 additions and 19 deletions.
86 changes: 67 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,92 @@

```python
# Regression train
from moltx import tokenizers
from dooc import models, datasets, nets
import random
import torch
from torch import nn
import torch.optim as optim

from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig

tk = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Prediction)
ds = datasets.MutSmiXAttention(tokenizer=tk, device=torch.device('cpu'))
from dooc import models, datasets


# datasets
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
ds = datasets.MutSmi(tokenizer)
smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"]
mutations = [[1, 0, 0, ...], [1, 0, 1, ...]]
# e.g.
# import random
# [random.choice([0, 1]) for _ in range(3008)]
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
# each entry in mutations is a 0/1 encoding of the genome
values = [0.85, 0.78]
smiles_src, smiles_tgt, mutations_src, out = ds(smiles, mutations, values)

model = models.MutSmiXAttention()
# MutSmiFullConnection train
model = models.MutSmiFullConnection()
model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
lr=1e-04,
foreach=False
)

optimizer.zero_grad()
pred = model(smiles_src, smiles_tgt, mutations_src)
loss = mse_loss(pred, out)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmifullconnection.ckpt')

crt = nn.MSELoss()
# MutSmiXAttention train
model = models.MutSmiXAttention()
model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
lr=1e-04,
foreach=False
)

optim.zero_grad()
optimizer.zero_grad()
pred = model(smiles_src, smiles_tgt, mutations_src)
loss = crt(pred, out)
loss = mse_loss(pred, out)
loss.backward()
optim.step()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmixattention.ckpt')
```

### Inference

```python
import random
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import pipelines, models
# dooc
model = models.MutSmiXAttention()
model.load_ckpt('/path/to/mutsmixattention.ckpt')
pipeline = pipelines.MutSmiXAttention()
pipeline([1, 0, 0, ...], "C=CC=CC=C")
# 0.85


# MutSmiFullConnection
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmiFullConnection()
model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model)
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutations, smiles) # e.g. 0.85

# MutSmiXAttention
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmiXAttention()
model.load_ckpt('/path/to/mutsmixattention.ckpt')
pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, model=model)
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutations, smiles) # e.g. 0.85
```

0 comments on commit 91911ba

Please sign in to comment.