Skip to content

Commit

Permalink
Merge pull request js-ish#7 from zzzseeu/fix-readme
Browse files Browse the repository at this point in the history
fix: rewrite readme
  • Loading branch information
yandy authored May 31, 2024
2 parents 21e1bcd + 91911ba commit cbe3e8d
Showing 1 changed file with 67 additions and 19 deletions.
86 changes: 67 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,92 @@

```python
# Regression train
from moltx import tokenizers
from dooc import models, datasets, nets
import random
import torch
from torch import nn
import torch.optim as optim

from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig

tk = tokenizers.MoltxTokenizer.from_pretrain(models.AdaMRTokenizerConfig.Prediction)
ds = datasets.MutSmiXAttention(tokenizer=tk, device=torch.device('cpu'))
from dooc import models, datasets


# datasets
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
ds = datasets.MutSmi(tokenizer)
smiles = ["c1cccc1c", "CC[N+](C)(C)Cc1ccccc1Br"]
mutations = [[1, 0, 0, ...], [1, 0, 1, ...]]
# e.g.
# import random
# [random.choice([0, 1]) for _ in range(3008)]
mutations = [[random.choice([0, 1]) for _ in range(3008)],
[random.choice([0, 1]) for _ in range(3008)]]
# each entry in mutations is a 0/1 encoding of the genome
values = [0.85, 0.78]
smiles_src, smiles_tgt, mutations_src, out = ds(smiles, mutations, values)

model = models.MutSmiXAttention()
# MutSmiFullConnection train
model = models.MutSmiFullConnection()
model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
lr=1e-04,
foreach=False
)

optimizer.zero_grad()
pred = model(smiles_src, smiles_tgt, mutations_src)
loss = mse_loss(pred, out)
loss.backward()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmifullconnection.ckpt')

crt = nn.MSELoss()
# MutSmiXAttention train
model = models.MutSmiXAttention()
model.load_pretrained_ckpt('/path/to/drugcell.ckpt', '/path/to/moltx.ckpt')
mse_loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters(),
lr=1e-04,
foreach=False
)

optim.zero_grad()
optimizer.zero_grad()
pred = model(smiles_src, smiles_tgt, mutations_src)
loss = crt(pred, out)
loss = mse_loss(pred, out)
loss.backward()
optim.step()
optimizer.step()

torch.save(model.state_dict(), '/path/to/mutsmixattention.ckpt')
```

### Inference

```python
import random
from moltx import tokenizers as tkz
from moltx.models import AdaMRTokenizerConfig
from dooc import pipelines, models
# dooc
model = models.MutSmiXAttention()
model.load_ckpt('/path/to/mutsmixattention.ckpt')
pipeline = pipelines.MutSmiXAttention()
pipeline([1, 0, 0, ...], "C=CC=CC=C")
# 0.85


# MutSmiFullConnection
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmiFullConnection()
model.load_ckpt('/path/to/mutsmifullconnection.ckpt')
pipeline = pipelines.MutSmiFullConnection(smi_tokenizer=tokenizer, model=model)
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutations, smiles) # e.g. 0.85

# MutSmiXAttention
tokenizer = tkz.MoltxTokenizer.from_pretrain(
conf=AdaMRTokenizerConfig.Prediction
)
model = models.MutSmiXAttention()
model.load_ckpt('/path/to/mutsmixattention.ckpt')
pipeline = pipelines.MutSmiXAttention(smi_tokenizer=tokenizer, model=model)
mutations = [random.choice([0, 1]) for _ in range(3008)]
smiles = "CC[N+](C)(C)Cc1ccccc1Br"
predict = pipeline(mutations, smiles) # e.g. 0.85
```

0 comments on commit cbe3e8d

Please sign in to comment.