Skip to content

Commit

Permalink
Added qtwindow option to save GREMLIN SSM plot
Browse files Browse the repository at this point in the history
  • Loading branch information
niklases committed Nov 19, 2024
1 parent 306d540 commit 5d48bd6
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 12 deletions.
30 changes: 25 additions & 5 deletions gui/qt_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,21 @@ def __init__(
self.button_dca_inference_gremlin = QtWidgets.QPushButton("MSA optimization (GREMLIN)")
self.button_dca_inference_gremlin.setMinimumWidth(80)
self.button_dca_inference_gremlin.setToolTip(
"Generating DCA parameters using GREMLIN (\"MSA optimization\"), "
"you have to provide an MSA in FASTA or A2M format"
"Generating DCA parameters using GREMLIN (\"MSA optimization\"); "
"requires an MSA in FASTA or A2M format"
)
self.button_dca_inference_gremlin.clicked.connect(self.pypef_gremlin)
self.button_dca_inference_gremlin.setStyleSheet(button_style)

self.button_dca_inference_gremlin_msa_info = QtWidgets.QPushButton("GREMLIN SSM prediction")
self.button_dca_inference_gremlin_msa_info.setMinimumWidth(80)
self.button_dca_inference_gremlin_msa_info.setToolTip(
"Generating DCA parameters using GREMLIN (\"MSA optimization\") and save plots of "
"visualized results; requires an MSA in FASTA or A2M format"
)
self.button_dca_inference_gremlin_msa_info.clicked.connect(self.pypef_gremlin_msa_info)
self.button_dca_inference_gremlin_msa_info.setStyleSheet(button_style)

self.button_dca_test_dca = QtWidgets.QPushButton("Test (DCA)")
self.button_dca_test_dca.setMinimumWidth(80)
self.button_dca_test_dca.setToolTip(
Expand Down Expand Up @@ -208,8 +217,9 @@ def __init__(

layout.addWidget(self.dca_text, 3, 1, 1, 1)
layout.addWidget(self.button_dca_inference_gremlin, 4, 1, 1, 1)
layout.addWidget(self.button_dca_test_dca, 5, 1, 1, 1)
layout.addWidget(self.button_dca_predict_dca, 6, 1, 1, 1)
layout.addWidget(self.button_dca_inference_gremlin_msa_info, 5, 1, 1, 1)
layout.addWidget(self.button_dca_test_dca, 6, 1, 1, 1)
layout.addWidget(self.button_dca_predict_dca, 7, 1, 1, 1)

layout.addWidget(self.hybrid_text, 3, 2, 1, 1)
layout.addWidget(self.button_hybrid_train_dca, 4, 2, 1, 1)
Expand All @@ -222,7 +232,7 @@ def __init__(
layout.addWidget(self.button_supervised_train_test_dca, 4, 3, 1, 1)
layout.addWidget(self.button_supervised_train_test_onehot, 5, 3, 1, 1)

layout.addWidget(self.textedit_out, 7, 0, 1, -1)
layout.addWidget(self.textedit_out, 8, 0, 1, -1)

self.process = QtCore.QProcess(self)
self.process.setProcessChannelMode(QtCore.QProcess.MergedChannels)
Expand All @@ -233,6 +243,8 @@ def __init__(
self.process.finished.connect(lambda: self.button_mklsts.setEnabled(True))
self.process.started.connect(lambda: self.button_dca_inference_gremlin.setEnabled(False))
self.process.finished.connect(lambda: self.button_dca_inference_gremlin.setEnabled(True))
self.process.started.connect(lambda: self.button_dca_inference_gremlin_msa_info.setEnabled(False))
self.process.finished.connect(lambda: self.button_dca_inference_gremlin_msa_info.setEnabled(True))
self.process.started.connect(lambda: self.button_dca_test_dca.setEnabled(False))
self.process.finished.connect(lambda: self.button_dca_test_dca.setEnabled(True))
self.process.started.connect(lambda: self.button_dca_predict_dca.setEnabled(False))
Expand Down Expand Up @@ -290,6 +302,14 @@ def pypef_gremlin(self):
self.version_text.setText("Running GREMLIN (DCA) optimization on MSA...")
self.exec_pypef(f'param_inference --wt {wt_fasta_file} --msa {msa_file}') # --opt_iter 100

@QtCore.Slot()
def pypef_gremlin_msa_info(self):
    """
    Ask for a WT FASTA file and an MSA file and run the PyPEF
    'save_msa_info' command on them.

    Nothing is executed when one of the file dialogs yields an empty
    path. If the WT dialog already yields no path, the MSA dialog is
    not opened at all, since the command could not be run anyway.
    """
    wt_fasta_file = QtWidgets.QFileDialog.getOpenFileName(self, "Select WT FASTA File")[0]
    if not wt_fasta_file:
        # No WT file chosen: do not bother the user with a second dialog.
        return
    msa_file = QtWidgets.QFileDialog.getOpenFileName(
        self, "Select Multiple Sequence Alignment (MSA) file (in FASTA or A2M format)")[0]
    if not msa_file:
        return
    # NOTE(review): both paths are interpolated unquoted into the command
    # string; how exec_pypef splits that string is not visible here --
    # confirm that paths containing spaces are handled correctly.
    self.version_text.setText("Running GREMLIN (DCA) optimization on MSA...")
    self.exec_pypef(f'save_msa_info --wt {wt_fasta_file} --msa {msa_file}')

@QtCore.Slot()
def pypef_dca_test(self):
Expand Down
3 changes: 2 additions & 1 deletion pypef/dca/dca_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pypef.utils.variant_data import read_csv, get_wt_sequence
from pypef.dca.plmc_encoding import save_plmc_dca_encoding_model
from pypef.dca.hybrid_model import get_model_and_type, performance_ls_ts, predict_ps, generate_model_and_save_pkl
from pypef.dca.gremlin_inference import save_gremlin_as_pickle, save_corr_csv, plot_all_corr_mtx
from pypef.dca.gremlin_inference import save_gremlin_as_pickle, save_corr_csv, plot_all_corr_mtx, plot_predicted_ssm
from pypef.utils.low_n_mutation_extrapolation import performance_mutation_extrapolation, low_n


Expand Down Expand Up @@ -128,6 +128,7 @@ def run_pypef_hybrid_modeling(arguments):
)
save_corr_csv(gremlin)
plot_all_corr_mtx(gremlin)
plot_predicted_ssm(gremlin)

else:
performance_ls_ts(
Expand Down
56 changes: 54 additions & 2 deletions pypef/dca/gremlin_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@
from scipy.special import logsumexp
from scipy.stats import boxcox
import pandas as pd
from tqdm import tqdm
import tensorflow as tf
tf.get_logger().setLevel('DEBUG')
tf.get_logger().setLevel('WARNING')
# Uncomment to hide GPU devices
#environ['CUDA_VISIBLE_DEVICES'] = '-1'

Expand Down Expand Up @@ -718,7 +719,7 @@ def save_gremlin_as_pickle(alignment: str, wt_seq: str, opt_iter: int = 100):
},
open('Pickles/GREMLIN', 'wb')
)
logger.info(f"Saved GREMLIN model as Pickle file ({os.path.abspath('Pickles/GREMLIN')})...")
logger.info(f"Saved GREMLIN model as Pickle file as {os.path.abspath('Pickles/GREMLIN')}...")
return gremlin


Expand All @@ -733,3 +734,54 @@ def save_corr_csv(gremlin: GREMLIN, min_distance: int = 0, sort_by: str = 'apc')
min_distance=min_distance, sort_by=sort_by
)
df_mtx_sorted_mindist.to_csv(f"coevolution_{sort_by}_sorted.csv")


def plot_predicted_ssm(gremlin: GREMLIN):
    """
    Plot and save the predicted single-substitution landscape of the
    WT/input sequence.

    At every position of ``gremlin.wt_seq`` each character of
    ``gremlin.char_alphabet`` (gap character "-" removed, sorted) is
    substituted in, the resulting sequence is scored with
    ``gremlin.get_score`` and the WT score (``gremlin.get_wt_score``)
    is subtracted; e.g.:
    M1A, M1C, M1E, ..., D2A, D2C, D2E, ..., ..., T300V, T300W, T300Y

    The loop runs over the full gap-free alphabet, so if the WT residue
    is part of that alphabet the identity "substitution" (e.g. M1M,
    whose sequence equals the WT sequence) is included as well.

    Two files are written to the current working directory:
    'SSM_landscape.png' (heatmap with one column per position and one
    row per substituted character, each cell annotated with variant
    name and rounded score difference) and 'SSM_landscape.csv'
    (columns 'Variant', 'Sequence', 'Variant_Score').
    """
    wt_sequence = gremlin.wt_seq
    # Both scoring calls return something indexable; only the first
    # element is used as the score.
    wt_score = gremlin.get_wt_score()[0]
    aas = "".join(sorted(gremlin.char_alphabet.replace("-", "")))
    # Outer lists: one entry per sequence position;
    # inner lists: one entry per substituted character (order of `aas`).
    variantss, variant_sequencess, variant_scoress = [], [], []
    for i, aa_wt in enumerate(tqdm(wt_sequence)):
        variants, variant_sequences, variant_scores = [], [], []
        for aa_sub in aas:
            variant_sequence = wt_sequence[:i] + aa_sub + wt_sequence[i + 1:]
            variants.append(f'{aa_wt}{i + 1}{aa_sub}')  # 1-based position label
            variant_sequences.append(variant_sequence)
            variant_scores.append(gremlin.get_score(variant_sequence)[0] - wt_score)
        variantss.append(variants)
        variant_sequencess.append(variant_sequences)
        variant_scoress.append(variant_scores)
    fig, ax = plt.subplots(figsize=(30, 3))
    try:
        # Transposed so that positions run along x and characters along y.
        ax.imshow(np.array(variant_scoress).T)
        for i_vss, vss in enumerate(variant_scoress):
            for i_vs, vs in enumerate(vss):
                ax.text(
                    i_vss, i_vs,
                    f'{variantss[i_vss][i_vs]}\n{round(vs, 1)}',
                    size=2, va='center', ha='center'
                )
        ax.set_xticks(
            range(len(wt_sequence)),
            [f'{aa}{i + 1}' for i, aa in enumerate(wt_sequence)],
            size=6, rotation=90
        )
        ax.set_yticks(range(len(aas)), aas, size=6)
        fig.tight_layout()
        fig.savefig('SSM_landscape.png', dpi=300)
    finally:
        # Release the figure even if plotting/saving fails; otherwise every
        # call leaves one more open figure registered with pyplot.
        plt.close(fig)
    pd.DataFrame(
        {
            'Variant': np.array(variantss).flatten(),
            'Sequence': np.array(variant_sequencess).flatten(),
            'Variant_Score': np.array(variant_scoress).flatten()
        }
    ).to_csv('SSM_landscape.csv', sep=',')
    logger.info(f"Saved SSM landscape as {os.path.abspath('SSM_landscape.png')} "
                f"and CSV data as {os.path.abspath('SSM_landscape.csv')}...")
2 changes: 1 addition & 1 deletion pypef/dca/hybrid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,6 @@ def save_model_to_dict_pickle(
model_type = 'MODEL'

pkl_path = os.path.abspath(f'Pickles/{model_type}')
logger.info(f'Saving model as Pickle file ({pkl_path})...')
pickle.dump(
{
'model': model,
Expand All @@ -698,6 +697,7 @@ def save_model_to_dict_pickle(
},
open(f'Pickles/{model_type}', 'wb')
)
logger.info(f'Saved model as Pickle file ({pkl_path})...')


global_model = None
Expand Down
2 changes: 1 addition & 1 deletion pypef/ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -959,10 +959,10 @@ def save_model(
if model_type in ['PLMC', 'GREMLIN'] and encoding not in ['aaidx', 'onehot']:
name = 'ML' + model_type.lower()
f_name = os.path.abspath(os.path.join(path, 'Pickles', name))
logger.info(f'Saving model ({f_name})...')
file = open(f_name, 'wb')
pickle.dump(regressor_, file)
file.close()
logger.info(f'Saved model as {f_name}...')

except IndexError:
raise IndexError
Expand Down
2 changes: 1 addition & 1 deletion pypef/utils/low_n_mutation_extrapolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def performance_mutation_extrapolation(
logger.info('Fitting regressor on lvl 1 substitution data...')
regressor.fit(x_train, y_train)
if save_model:
logger.info(f'Saving model as Pickle file: ML_LVL_1')
pickle.dump(regressor, open(os.path.join('Pickles', 'ML_LVL_1'), 'wb'))
logger.info(f'Saved model as Pickle file: ML_LVL_1')
for i, _ in enumerate(tqdm(collected_levels)):
if i < len(collected_levels) - 1: # not last i else error, last entry is: lvl 1 --> all higher variants
test_idx = collected_levels[i + 1]
Expand Down
2 changes: 1 addition & 1 deletion pypef/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,6 @@ def plot_y_true_vs_y_pred(
# i += 1 # iterate until finding an unused file name
# file_name = f'DCA_Hybrid_Model_LS_TS_Performance({i}).png'
plt.colorbar()
logger.info(f'Saving plot ({os.path.abspath(file_name)})...')
plt.savefig(file_name, dpi=500)
plt.close('all')
logger.info(f'Saved plot as {os.path.abspath(file_name)}...')

0 comments on commit 5d48bd6

Please sign in to comment.