Skip to content

Commit

Permalink
Merge pull request #23 from Goosang-Yu/develop
Browse files Browse the repository at this point in the history
Develop search method in SpCas9 class
  • Loading branch information
Goosang-Yu authored Aug 9, 2023
2 parents 7781bac + 16bdf57 commit 018bbf3
Show file tree
Hide file tree
Showing 9 changed files with 828 additions and 948 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,26 @@ df_out = spcas.predict(list_target)
| 1 | CCTTCGTTTTTTTCCTTCTGCAGGAGGACA | CGTTTTTTTCCTTCTGCAGG | 2.253288 |
| 2 | CTTTCAAGAACTCTTCCACCTCCATGGTGT | CAAGAACTCTTCCACCTCCA | 53.43182 |

Alternatively, you can scan a longer sequence (for example, a whole gene) to find every possible SpCas9 target site in it and obtain a predicted score for each site.
```python
from genet.predict import SpCas9

# Provide the full sequence context in which you want to find Cas9 target sites.
gene = 'ttcagctctacgtctcctccgagagccgcttcaacaccctggccgagttggttcatcatcattcaacggtggccgacgggctcatcaccacgctccattatccagccccaaagcgcaacaagcccactgtctatggtgtgtcccccaactacgacaagtgggagatggaacgcacggacatcaccatgaagcacaagctgggcgggggccagtacggggaggtgtacgagggcgtgtggaagaaatacagcctgacggtggccgtgaagaccttgaaggtagg'

spcas = SpCas9()
df_out = spcas.search(gene)

>>> df_out.head()
```
| | Target | Spacer | Strand | Start | End | SpCas9 |
| - | ------------------------------ | -------------------- | ------ | ----- | --- | -------- |
| 0 | CCTCCGAGAGCCGCTTCAACACCCTGGCCG | CGAGAGCCGCTTCAACACCC | + | 15 | 45 | 67.39446 |
| 1 | GCCGCTTCAACACCCTGGCCGAGTTGGTTC | CTTCAACACCCTGGCCGAGT | + | 24 | 54 | 27.06508 |
| 2 | CCGAGTTGGTTCATCATCATTCAACGGTGG | GTTGGTTCATCATCATTCAA | + | 42 | 72 | 34.11356 |
| 3 | AGTTGGTTCATCATCATTCAACGGTGGCCG | GGTTCATCATCATTCAACGG | + | 45 | 75 | 76.43662 |
| 4 | TCATCATCATTCAACGGTGGCCGACGGGCT | CATCATTCAACGGTGGCCGA | + | 52 | 82 | 29.63767 |


## Tutorial 2: Predict SpCas9variants activity (by DeepSpCas9variants)
DeepSpCas9variants is a prediction model developed to evaluate the indel frequencies introduced by sgRNAs at specific target sites mediated by SpCas9 PAM variants ([Kim et al. Nat.Biotechnol. 2020](https://doi.org/10.1038/s41587-020-0537-9)). The model was developed on TensorFlow (version >= 2.6). Any dependent packages will be installed along with the GenET package.
Expand Down
8 changes: 6 additions & 2 deletions genet/models/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@
'SpCas9': {
'type': 'DeepSpCas9',
'repo': 'Goosang-Yu/genet-models/main/genet_models',
'path': 'DeepSpCas9'
'path': 'DeepSpCas9',
'regex': {'+': '[ATGC]{25}GG[ATGC]{3}',
'-': '[ATGC]{3}CC[ATGC]{25}',},
},

# DeepSpCas9variants
'SpCas9-NG': {
'type': 'DeepSpCas9variants',
'repo': 'Goosang-Yu/genet-models/main/genet_models',
'path': 'DeepSpCas9variants/PAM_variant_NG'
'path': 'DeepSpCas9variants/PAM_variant_NG',
'regex': {'+': '[ATGC]{25}G[ATGC]{4}',
'-': '[ATGC]{4}C[ATGC]{25}',},
},
'SpCas9-NRCH': {
'type': 'DeepSpCas9variants',
Expand Down
10 changes: 5 additions & 5 deletions genet/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ def __init__(self, model:str, effector:str, cell_type=None):

# 이 모델이 genet에서 지원하는 것인지 확인하기
try:
self.model_info = models.constants.dict_model_info[model_type]
self.info = models.constants.dict_model_info[model_type]
except:
print('[Warning] Not available model in GenET!')
sys.exit()

# model_dir:
self.model_dir = inspect.getfile(models).replace('__init__.py', '') + self.model_info['path']
self.model_dir = inspect.getfile(models).replace('__init__.py', '') + self.info['path']

# 만약 모델이 아직 다운로드 되지 않았다면, 다운로드 하기.
if not os.path.exists(self.model_dir):
Expand All @@ -42,9 +42,9 @@ def __init__(self, model:str, effector:str, cell_type=None):
dict_files = models.constants.dict_model_requests

self.download_from_github(
repo = self.model_info['repo'],
path = self.model_info['path'],
files = dict_files[self.model_info['type']],
repo = self.info['repo'],
path = self.info['path'],
files = dict_files[self.info['type']],
save_dir = self.model_dir,
)

Expand Down
59 changes: 51 additions & 8 deletions genet/predict/DeepSpCas9.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os, sys
import os, sys, regex
import numpy as np
import pandas as pd

Expand Down Expand Up @@ -26,19 +26,18 @@ def __init__(self, gpu_env=0):
'CCTTCGTTTTTTTCCTTCTGCAGGAGGACA',
'CTTTCAAGAACTCTTCCACCTCCATGGTGT',
]
>>> list_out = spcas9_score(list_target30)
>>> deepspcas9 = genet.predict.SpCas9()
>>> list_out = [2.80322408676147, 2.25273704528808, 53.4233360290527]
>>> spcas_score = deepspcas9(list_target30)
'''

# TensorFlow config
self.conf = tf.compat.v1.ConfigProto()
self.conf.gpu_options.allow_growth = True
os.environ['CUDA_VISIBLE_DEVICES'] = '%d' % gpu_env

model_info = LoadModel('DeepSpCas9', 'SpCas9')
model_dir = model_info.model_dir
self.model = LoadModel('DeepSpCas9', 'SpCas9')
model_dir = self.model.model_dir
best_model = 'PreTrain-Final-3-5-7-100-70-40-0.001-550-80-60'

self.model_save = '%s/%s' % (model_dir, best_model)
Expand All @@ -59,12 +58,12 @@ def predict(self, list_target30: list) -> pd.DataFrame:

with tf.compat.v1.Session(config=self.conf) as sess:
sess.run(tf.compat.v1.global_variables_initializer())
model = DeepCas9(self.params[0], self.params[1], 80, 60, self.params[2])
interpreter = DeepCas9(self.params[0], self.params[1], 80, 60, self.params[2])

saver = tf.compat.v1.train.Saver()
saver.restore(sess, self.model_save)

list_score = Model_Finaltest(sess, seq_processed, model)
list_score = Model_Finaltest(sess, seq_processed, interpreter)

df_out = pd.DataFrame()
df_out['Target'] = list_target30
Expand All @@ -73,6 +72,50 @@ def predict(self, list_target30: list) -> pd.DataFrame:

return df_out

def search(self, seq: str) -> pd.DataFrame:
    '''Find every candidate target site in ``seq`` and score each one.

    Both strands are scanned with the strand-specific patterns stored in
    ``self.model.info['regex']``; overlapping hits are all reported. A hit
    on the '-' strand is reverse-complemented before it is stored, so every
    reported target is written in the same orientation as a '+' strand hit.

    Args:
        seq: Sequence context to scan. It is upper-cased before matching
            and the upper-cased copy is kept on ``self.seq``.

    Returns:
        DataFrame with one row per hit and the columns
        ``Target`` (the matched window, '-' hits reverse-complemented),
        ``Spacer`` (``Target[4:24]``),
        ``Strand`` ('+' or '-'),
        ``Start`` / ``End`` (span of the regex match on the upper-cased
        input, for both strands), and
        ``SpCas9`` (score returned by ``Model_Finaltest``).
        If nothing matches, an empty DataFrame with the same columns is
        returned and the model is not loaded.
    '''

    self.seq = seq.upper()
    dict_re = self.model.info['regex']

    seq_target, seq_guide, seq_strand, pos_start, pos_end = [], [], [], [], []

    for strand in ['+', '-']:
        ptn = dict_re[strand]

        # overlapped=True: neighbouring target windows share bases, so
        # non-overlapping matching would miss most sites.
        for re_idx in regex.finditer(ptn, self.seq, overlapped=True):
            if strand == '+':
                match = re_idx.group()
            else:
                match = reverse_complement(re_idx.group())

            seq_target.append(match)
            seq_guide.append(match[4:24])
            seq_strand.append(strand)
            pos_start.append(re_idx.start())
            pos_end.append(re_idx.end())

    # Nothing to score: skip session creation and checkpoint restore.
    if not seq_target:
        return pd.DataFrame(columns=['Target', 'Spacer', 'Strand',
                                     'Start', 'End', 'SpCas9'])

    seq_processed = preprocess_seq(seq_target, 30)

    with tf.compat.v1.Session(config=self.conf) as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        interpreter = DeepCas9(self.params[0], self.params[1], 80, 60, self.params[2])

        saver = tf.compat.v1.train.Saver()
        saver.restore(sess, self.model_save)

        list_score = Model_Finaltest(sess, seq_processed, interpreter)

    df_out = pd.DataFrame({'Target': seq_target,
                           'Spacer': seq_guide,
                           'Strand': seq_strand,
                           'Start' : pos_start,
                           'End'   : pos_end,
                           'SpCas9': list_score})

    return df_out


def Model_Finaltest(sess, TEST_X, model):
test_batch = 500
Expand Down
58 changes: 50 additions & 8 deletions genet/predict/DeepSpCas9Variants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Reference: https://blog.naver.com/PostView.naver?blogId=seodaewoo&logNo=222043145688&parentCategoryNo=&categoryNo=62&viewDate=&isShowPopularPosts=false&from=postView

import tensorflow as tf
import regex
import numpy as np
import pandas as pd
import tensorflow as tf

from genet.predict.PredUtils import *
from genet.models import LoadModel
Expand All @@ -20,13 +20,11 @@ def __init__(self, effector:str):
'CCTTCGTTTTTTTCCTTCTGCAGGAGGACA',
'CTTTCAAGAACTCTTCCACCTCCATGGTGT',
]
\n
'''

self.effector = effector

self.model_info = LoadModel('DeepSpCas9variants', effector)
self.model_dir = self.model_info.model_dir
self.effector = effector
self.model = LoadModel('DeepSpCas9variants', effector)
self.model_dir = self.model.model_dir


def predict(self, list_target30: list) -> pd.DataFrame:
Expand Down Expand Up @@ -71,7 +69,51 @@ def predict(self, list_target30: list) -> pd.DataFrame:

df_out[self.effector] = list_out

return df_out
return df_out

def search(self, seq: str) -> pd.DataFrame:
    '''Find every candidate target site in ``seq`` and score each one.

    Both strands are scanned with the strand-specific patterns stored in
    ``self.model.info['regex']``; overlapping hits are all reported. A hit
    on the '-' strand is reverse-complemented before it is stored, so every
    reported target is written in the same orientation as a '+' strand hit.

    Scoring is delegated to ``self.predict`` so that this method uses the
    same model-loading path as the rest of the class; the instance does not
    carry the ``conf`` / ``params`` / ``model_save`` attributes that the
    SpCas9 implementation of ``search`` relies on.

    Args:
        seq: Sequence context to scan. It is upper-cased before matching
            and the upper-cased copy is kept on ``self.seq``.

    Returns:
        DataFrame with one row per hit and the columns
        ``Target`` (the matched window, '-' hits reverse-complemented),
        ``Spacer`` (``Target[4:24]``),
        ``Strand`` ('+' or '-'),
        ``Start`` / ``End`` (span of the regex match on the upper-cased
        input, for both strands), and a score column named after
        ``self.effector``, matching the column written by ``predict``.
        If nothing matches, an empty DataFrame with the same columns is
        returned and ``predict`` is not called.
    '''

    self.seq = seq.upper()
    dict_re = self.model.info['regex']

    seq_target, seq_guide, seq_strand, pos_start, pos_end = [], [], [], [], []

    for strand in ['+', '-']:
        ptn = dict_re[strand]

        # overlapped=True: neighbouring target windows share bases, so
        # non-overlapping matching would miss most sites.
        for re_idx in regex.finditer(ptn, self.seq, overlapped=True):
            if strand == '+':
                match = re_idx.group()
            else:
                match = reverse_complement(re_idx.group())

            seq_target.append(match)
            seq_guide.append(match[4:24])
            seq_strand.append(strand)
            pos_start.append(re_idx.start())
            pos_end.append(re_idx.end())

    # Nothing to score: do not invoke the model at all.
    if not seq_target:
        return pd.DataFrame(columns=['Target', 'Spacer', 'Strand',
                                     'Start', 'End', self.effector])

    # NOTE(review): assumes predict() returns one row per input target, in
    # input order, with the scores in the `self.effector` column - confirm.
    df_pred = self.predict(seq_target)
    list_score = df_pred[self.effector].tolist()

    df_out = pd.DataFrame({'Target': seq_target,
                           'Spacer': seq_guide,
                           'Strand': seq_strand,
                           'Start' : pos_start,
                           'End'   : pos_end,
                           self.effector: list_score})

    return df_out



Expand Down
11 changes: 10 additions & 1 deletion genet/predict/PredUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,13 @@ def preprocess_seq(data, seq_length):
print("[Input Error] Non-ATGC character " + data[l])
sys.exit()

return seq_onehot
return seq_onehot

def reverse_complement(sSeq):
    '''Return the reverse complement of a nucleotide sequence.

    Case is preserved (upper-case input bases give upper-case output and
    lower-case give lower-case). 'N'/'n', '.', '*' and 'U' map to
    themselves, so the result always has the same length as the input.

    Args:
        sSeq: Sequence to reverse-complement.

    Returns:
        The complemented sequence, reversed.

    Raises:
        KeyError: If sSeq contains a character that is not in the table.
    '''
    # 'n' maps to 'n' (it used to map to '', which silently dropped the base
    # and shortened the result).
    # NOTE(review): 'U' maps to itself rather than to 'A' - kept as-is; confirm intent.
    dict_sBases = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N', 'U': 'U', 'n': 'n',
                   '.': '.', '*': '*', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}
    return ''.join(dict_sBases[sBase] for sBase in reversed(sSeq))

# def END: reverse_complement
32 changes: 1 addition & 31 deletions genet/predict/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# from genet.utils import *
import genet
import genet.utils
from genet.predict.PredUtils import *

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -116,28 +117,6 @@ def Model_Finaltest(sess, TEST_X, model):



def preprocess_seq(data, seq_length):
    '''One-hot encode the first ``seq_length`` characters of each sequence.

    Args:
        data: Indexable collection of sequences (``data[row][pos]`` is read).
        seq_length: Number of leading positions to encode per sequence.

    Returns:
        Float array of shape ``(len(data), 1, seq_length, 4)``. The last
        axis is ordered A, C, G, T (either case). Positions holding X/x,
        N/n or '.' are left all-zero.

    Any other character prints an input-error message and terminates the
    process through ``sys.exit()``.
    '''
    n_seqs = len(data)
    seq_onehot = np.zeros((n_seqs, 1, seq_length, 4), dtype=float)

    # (accepted letters, channel index), in channel order.
    channels = (("Aa", 0), ("Cc", 1), ("Gg", 2), ("Tt", 3))

    for row in range(n_seqs):
        for pos in range(seq_length):
            try:
                data[row][pos]
            except Exception:
                # Debug aid only: the lookup just below fails again for the
                # same position, so the error still propagates.
                print(data[row], pos, seq_length, n_seqs)

            base = data[row][pos]

            for letters, channel in channels:
                if base in letters:
                    seq_onehot[row, 0, pos, channel] = 1
                    break
            else:
                # Placeholder / ambiguous symbols stay as an all-zero row.
                if base in "Xx" or base in "Nn.":
                    continue
                print("[Input Error] Non-ATGC character " + data[row])
                sys.exit()

    return seq_onehot

def spcas9_score_tf2(list_target30:list, gpu_env=0):
'''Tensorflow2 version function
Expand Down Expand Up @@ -245,15 +224,6 @@ def spcas9_score(list_target30:list, gpu_env=0):
return list_score


def reverse_complement(sSeq):
    '''Return the reverse complement of a nucleotide sequence.

    Case is preserved (upper-case input bases give upper-case output and
    lower-case give lower-case). 'N'/'n', '.', '*' and 'U' map to
    themselves, so the result always has the same length as the input.

    Args:
        sSeq: Sequence to reverse-complement.

    Returns:
        The complemented sequence, reversed.

    Raises:
        KeyError: If sSeq contains a character that is not in the table.
    '''
    # 'n' maps to 'n' (it used to map to '', which silently dropped the base
    # and shortened the result).
    # NOTE(review): 'U' maps to itself rather than to 'A' - kept as-is; confirm intent.
    dict_sBases = {'A': 'T', 'C': 'G', 'G': 'C', 'T': 'A', 'N': 'N', 'U': 'U', 'n': 'n',
                   '.': '.', '*': '*', 'a': 't', 'c': 'g', 'g': 'c', 't': 'a'}
    return ''.join(dict_sBases[sBase] for sBase in reversed(sSeq))

# def END: reverse_complement

def set_alt_position_window(sStrand, sAltKey, nAltIndex, nIndexStart, nIndexEnd, nAltLen):
if sStrand == '+':

Expand Down
Loading

0 comments on commit 018bbf3

Please sign in to comment.