bunch of edits to plotting and interpretation/ code organization
annashcherbina committed May 1, 2019
1 parent a661fc1 commit 2996e82
Showing 34 changed files with 2,514 additions and 2,407 deletions.
13 changes: 13 additions & 0 deletions dragonn/interpret/__init__.py
@@ -0,0 +1,13 @@
from dragonn.interpret.ism import *
from dragonn.interpret.deeplift import *
from dragonn.interpret.input_grad import *

def multi_method_interpret():
"""
Arguments:
model
input
generate_plots
"""
return
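
The multi_method_interpret stub above is not yet implemented; the following is a minimal sketch of what the wrapper might eventually dispatch to, based only on the imports at the top of the file (the commented-out calls are assumptions, not the module's API):

def multi_method_interpret(model, X, generate_plots=True):
    """Sketch: run several interpretation methods on one batch of inputs."""
    results = {}
    # In-silico mutagenesis (dragonn.interpret.ism) takes the in-memory model.
    results["ism"] = in_silico_mutagenesis(model, X)
    # deeplift() (dragonn.interpret.deeplift) expects a path to an exported
    # hdf5 model, so a full implementation would also need that path:
    # results["deeplift"] = deeplift("model.hdf5", X, reference="shuffled_ref")
    # A gradient-based score from dragonn.interpret.input_grad could be added
    # here as well; that module's function names are not shown in this diff.
    # generate_plots is kept to match the stub's docstring; the plotting
    # helpers live in dragonn.vis.
    return results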

15 changes: 13 additions & 2 deletions dragonn/interpret/deeplift.py
@@ -1,4 +1,5 @@
import deeplift
import numpy as np

def deeplift_zero_ref(X,score_func,batch_size=200,task_idx=0):
# use a 40% GC reference
@@ -45,7 +46,17 @@ def deeplift_shuffled_ref(X,score_func,batch_size=200,task_idx=0,num_refs_per_se

def deeplift(model, X, batch_size=200,target_layer_idx=-2,task_idx=0, num_refs_per_seq=10,reference="shuffled_ref",one_hot_func=None):
"""
Returns (num_task, num_samples, 1, num_bases, sequence_length) deeplift score array.
Arguments:
model -- a string containing the path to the hdf5 exported model
X -- numpy array with shape (n_samples, 1, n_bases_in_sample,4) or list of FASTA sequences
batch_size -- number of samples to interpret at once
target_layer_idx -- should be -2 for classification; -1 for regression
task_idx -- index indicating which task to perform interpretation on
reference -- one of 'shuffled_ref','gc_ref','zero_ref'
num_refs_per_seq -- integer indicating the number of references to use for each input sequence if the reference is set to 'shuffled_ref'; if 'zero_ref' or 'gc_ref' is used, this argument is ignored.
one_hot_func -- one hot function to use for encoding FASTA string inputs; if the inputs are already one-hot-encoded, use the default of None
Returns:
(num_task, num_samples, 1, num_bases, sequence_length) deeplift score array.
"""
assert reference in ["shuffled_ref","gc_ref","zero_ref"]
if one_hot_func==None:
@@ -66,7 +77,7 @@ def deeplift(model, X, batch_size=200,target_layer_idx=-2,task_idx=0, num_refs_p
elif reference=="zero_ref":
deeplift_scores=deeplift_zero_ref(X,score_func,batch_size,task_idx)
else:
raise Exception("supported DeepLIFT references are 'shuffled_ref' and 'gc_ref'")
raise Exception("supported DeepLIFT references are 'shuffled_ref','gc_ref', 'zero_ref'")
return np.asarray(deeplift_scores)
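
For context, a brief usage sketch of the updated deeplift wrapper based on the docstring above; the model path is a placeholder and the input array is synthetic:

import numpy as np
from dragonn.interpret.deeplift import deeplift

# Synthetic one-hot input of shape (n_samples, 1, seq_len, 4); every base is
# set to "A" purely as a placeholder for real encoded sequences.
X = np.zeros((8, 1, 1000, 4), dtype="float32")
X[:, :, :, 0] = 1.0

scores = deeplift("my_model.hdf5", X,        # "my_model.hdf5" is a placeholder path
                  batch_size=200,
                  target_layer_idx=-2,       # -2 for classification, -1 for regression
                  task_idx=0,
                  reference="shuffled_ref",
                  num_refs_per_seq=10)
print(scores.shape)  # documented as (num_task, num_samples, 1, num_bases, sequence_length)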


11 changes: 6 additions & 5 deletions dragonn/interpret/ism.py
@@ -1,4 +1,5 @@
#utilities for running in-silico mutagenesis within dragonn.
import numpy as np

def get_logit_function(model):
from keras import backend as K
@@ -33,22 +34,22 @@ def in_silico_mutagenesis(model, X):
#3. Iterate through all tasks, positions
for task_index in range(output_dim[0]):
for sample_index in range(output_dim[1]):
print(str(task_index)+":"+str(sample_index))
print("task:"+str(task_index)+" sample:"+str(sample_index))
#fill in wild type logit values into an array of dim (task,sequence_length,num_bases)
wt_logit_for_task_sample=wild_type_logits[task_index][sample_index]
wt_expanded[task_index][sample_index]=np.tile(wt_logit_for_task_sample,(output_dim[2],output_dim[3]))
#mutagenize each position
for base_pos in range(output_dim[2]):
#for each position, iterate through the 4 bases
for base_letter in range(output_dim[3]):
cur_base=empty_onehot
cur_base=np.copy(empty_onehot)
cur_base[base_letter]=1
Xtmp=np.expand_dims(X[sample_index],axis=0)
Xtmp=np.copy(np.expand_dims(X[sample_index],axis=0))
Xtmp[0][0][base_pos]=cur_base
#get the logit of Xtmp
mutants_expanded[task_index][sample_index][base_pos][base_letter]=np.squeeze(get_logit(functor,Xtmp)[task_index])
#subtract mutants_expanded from wt_expanded
ism_vals=wt_expanded-mutants_expanded
#subtract wt_expanded from mutants_expanded
ism_vals=mutants_expanded-wt_expanded
#For each position subtract the mean ISM score for that position from each of the 4 values
ism_vals_mean=np.expand_dims(np.mean(ism_vals,axis=3),axis=3)
ism_vals_normed=ism_vals-ism_vals_mean
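
After this sign fix, the scores are mutant minus wild-type logit, then mean-centered at each position, so a positive value marks a substitution that raises the task logit. A short usage sketch, assuming a trained Keras model and one-hot input already exist:

from dragonn.interpret.ism import in_silico_mutagenesis

# model and X are assumed to already exist: a trained Keras model and one-hot
# sequences of shape (n_samples, 1, seq_len, 4).
ism_scores = in_silico_mutagenesis(model, X)

# Assuming the function returns the normalized array computed above, it indexes
# as [task, sample, position, base]; after mean-centering, the four scores at
# any position sum to approximately zero.
per_position_sums = ism_scores[0, 0].sum(axis=-1)
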
72 changes: 58 additions & 14 deletions dragonn/vis/__init__.py
@@ -2,25 +2,53 @@
from matplotlib import pyplot as plt
import numpy as np
from dragonn.vis.plot_letters import *
from dragonn.vis.plot_kmers import *

def plot_sequence_filters(model):
fig = plt.figure(figsize=(15, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
conv_filters=model.layers[0].get_weights()[0]
#transpose for plotting
conv_filters=np.transpose(conv_filters,(3,1,2,0)).squeeze(axis=-1)
num_plots_per_axis = int(len(conv_filters)**0.5) + 1
for i, conv_filter in enumerate(conv_filters):
ax = fig.add_subplot(num_plots_per_axis, num_plots_per_axis, i+1)
add_letters_to_axis(ax, conv_filter)
ax.axis("off")
ax.set_title("Filter %s" % (str(i+1)))


def plot_filters(model,simulation_data):
print("Plotting simulation motifs...")
plot_motifs(simulation_data)
plt.show()
print("Visualizing convolutional sequence filters in SequenceDNN...")
plot_sequence_filters(model)
plt.show()



def plot_motif_scores(motif_scores,title="",figsize=(20,3),ymin=0,ymax=20):
plt.figure(figsize=figsize)
plt.plot(pos_motif_scores, "-o")
f=plt.figure(figsize=figsize)
plt.plot(motif_scores, "-o")
plt.xlabel("Sequence base")
plt.ylabel("Motif scan score")
#threshold motif scores at 0; any negative scores are noise that we do not need to visualize
plt.ylim(ymin, ymax)
plt.title(title)
plt.show()
return f,f.axes

def plot_model_weights(model,layer_idx=-2):
W_dense, b_dense = model.layers[layer_idx].get_weights()
f=plt.figure()
plt.plot(W_dense,'-o')
plt.xlabel('Filter index')
plt.ylabel('Weight value')
plt.show()

def plot_ism(ism_mat,title="",figsize=(20,5)):
return f,f.get_axes()

def plot_ism(ism_mat,title="", xlim=None, ylim=None, figsize=(20,5)):
""" Plot the 4xL heatmap and also the identity and score of the highest scoring (mean subtracted) allele at each position
Args:
@@ -30,26 +58,39 @@ def plot_ism(ism_mat,title="",figsize=(20,5)):
Returns:
generates a heatmap and letter plot of the ISM matrix
"""
if len(ism_mat.shape)!=2:
print("Warning! The input matrix should represent a single input sequence for ISM, and as such should have dimensions : n_positions x 4. Running np.squeeze to remove extra dimensions.")
ism_mat=np.squeeze(ism_mat)
assert len(ism_mat.shape)==2
assert ism_mat.shape[1]==4

highest_scoring_pos=np.argmax(np.abs(ism_mat),axis=1)
zero_map=np.zeros(ism_mat.shape)
zero_map[:,highest_scoring_pos]=1
product=zero_map*ism_mat

fig,axes=plt.subplots(2, 1,sharex='row',figsize=figsize)
for i in range(zero_map.shape[0]):
zero_map[i][highest_scoring_pos[i]]=1
product=zero_map*ism_mat
f,axes=plt.subplots(2, 1,sharex='row',figsize=figsize)
axes[0]=plot_bases_on_ax(product,axes[0],show_ticks=False)
axes[0].set_title(title)
extent = [0, ism_mat.shape[0], 0, 100*ism_mat.shape[1]]
ymin=np.amin(ism_mat)
ymax=np.amax(ism_mat)
axes[1].imshow(ism_mat.T,extent=extent,vmin=ymin, vmax=ymax, interpolation='nearest',aspect='auto')
hmap=axes[1].imshow(ism_mat.T,extent=extent,vmin=ymin, vmax=ymax, interpolation='nearest',aspect='auto')
axes[1].set_yticks(np.arange(50,100*ism_mat.shape[1],100),("A","C","G","T"))
axes[1].set_xlabel("Sequence base")
axes[1].set_ylabel("ISM Score")
axes[1].set_title(title)
axes[1].set_yticks(np.arange(50,100*ism_mat.shape[1],100),("A","C","G","T"))
if xlim!=None:
axes[0].set_xlim(xlim)
axes[1].set_xlim(xlim)
if ylim!=None:
axes[0].set_ylim(ylim)
axes[1].set_ylim(ylim)

plt.set_cmap('RdBu')
plt.tight_layout()
plt.colorbar()
plt.colorbar(hmap,ax=axes[1],orientation='horizontal')
plt.show()

return f,axes
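
A usage sketch for the revised plot_ism, assuming ism_scores is the per-task, per-sample ISM array from in_silico_mutagenesis; the function now returns its figure and axes so callers can save or further adjust the plot:

from dragonn.vis import plot_ism

# ism_scores is assumed to be the [task, sample, position, base] array from
# in_silico_mutagenesis; plot_ism expects a single n_positions x 4 matrix and
# squeezes any extra singleton dimensions with a warning.
f, axes = plot_ism(ism_scores[0, 0], title="ISM: task 0, sample 0",
                   xlim=(0, 200))   # optional zoom into the first 200 bases
f.savefig("ism_task0_sample0.png")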

def plot_seq_importance(grads, x, xlim=None, ylim=None, figsize=(25, 3),title="",snp_pos=0):
"""Plot sequence importance score
@@ -69,12 +110,13 @@ def plot_seq_importance(grads, x, xlim=None, ylim=None, figsize=(25, 3),title=""
xlim = (0, seq_len)
if ylim is None:
ylim= (np.amin(vals_to_plot),np.amax(vals_to_plot))
seqlogo_fig(vals_to_plot, figsize=figsize)
f,ax=plot_bases(vals_to_plot, figsize=figsize,ylab="")
plt.xticks(list(range(xlim[0], xlim[1], 5)))
plt.xlim(xlim)
plt.ylim(ylim)
plt.title(title)
plt.axvline(x=snp_pos, color='k', linestyle='--')
return f,ax

def plot_learning_curve(history):
train_losses=history.history['loss']
Expand All @@ -90,3 +132,5 @@ def plot_learning_curve(history):
ax.set_ylim((min(train_losses+valid_losses),max(train_losses+valid_losses)))
ax.set_xlabel("Epoch")
plt.show()
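
A short usage sketch for plot_learning_curve, assuming a standard Keras training run with a validation set (model, data, and hyperparameters are placeholders):

from dragonn.vis import plot_learning_curve

# model, X_train/y_train, and X_valid/y_valid are placeholders; fit() must be
# given validation data so that history.history also contains a validation loss.
history = model.fit(X_train, y_train,
                    validation_data=(X_valid, y_valid),
                    epochs=20, batch_size=128)
plot_learning_curve(history)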


22 changes: 0 additions & 22 deletions dragonn/vis/plot_kmers.py
@@ -28,25 +28,3 @@ def plot_motifs(simulation_data):
for motif_name in simulation_data.motif_names:
plot_motif(motif_name, figsize=(10, 4), ylab=motif_name)

def plot_sequence_filters(model):
fig = plt.figure(figsize=(15, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)
conv_filters=model.layers[0].get_weights()[0]
#transpose for plotting
conv_filters=np.transpose(conv_filters,(3,1,2,0)).squeeze(axis=-1)
num_plots_per_axis = int(len(conv_filters)**0.5) + 1
for i, conv_filter in enumerate(conv_filters):
ax = fig.add_subplot(num_plots_per_axis, num_plots_per_axis, i+1)
add_letters_to_axis(ax, conv_filter)
ax.axis("off")
ax.set_title("Filter %s" % (str(i+1)))


def plot_filters(model,simulation_data):
print("Plotting simulation motifs...")
plot_motifs(simulation_data)
plt.show()
print("Visualizing convolutional sequence filters in SequenceDNN...")
plot_sequence_filters(model)
plt.show()

6 changes: 3 additions & 3 deletions dragonn/vis/plot_letters.py
@@ -245,7 +245,7 @@ def plot_bases_on_ax(letter_heights, ax, show_ticks=True):
top=False, # ticks along the top edge are off
labelbottom=False)
ax.set_aspect(aspect='auto', adjustable='box')
#ax.autoscale_view()
ax.autoscale_view()
return ax

def plot_bases(letter_heights, figsize=(12, 6), ylab='bits'):
Expand All @@ -265,11 +265,11 @@ def plot_bases(letter_heights, figsize=(12, 6), ylab='bits'):

fig = pyplot.figure(figsize=figsize)
ax = fig.add_subplot(111)
ax.set_xlabel('pos')
ax.set_xlabel('base pair position')
ax.set_ylabel(ylab)
plot_bases_on_ax(letter_heights, ax)

return fig
return fig,ax
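
Because plot_bases now returns both the figure and the axis, callers can adjust the plot after it is drawn. A small sketch, assuming letter_heights is an L x 4 array of per-base heights:

import numpy as np
from dragonn.vis.plot_letters import plot_bases

letter_heights = np.random.rand(50, 4)   # placeholder heights for 50 positions
fig, ax = plot_bases(letter_heights, figsize=(12, 3), ylab="score")
ax.set_title("Example letter plot")
fig.savefig("letter_plot.png")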



2 changes: 1 addition & 1 deletion setup.py
@@ -7,7 +7,7 @@
'packages': ['dragonn'],
'setup_requires': [],
'install_requires': ['numpy>=1.15', 'keras>=2.2.0','tensorflow>=1.6','deeplift>=0.6.9.0', 'shapely', 'matplotlib',
'scikit-learn>=0.20.0', 'pydot_ng==1.0.0', 'h5py','concise','seqdataloader>=0.124','simdna_dragonn','abstention'],
'scikit-learn>=0.20.0', 'pydot_ng==1.0.0', 'h5py','seqdataloader>=0.124','simdna_dragonn','abstention'],
'extras_requires':{'tensorflow with gpu':['tensorflow-gpu>=1.7']},
'dependency_links': [],
'scripts': [],