Commit: code status
Dylan Bourgeois committed Mar 15, 2019
1 parent 9b624c0 commit a60f752
Showing 5 changed files with 27 additions and 21 deletions.
6 changes: 4 additions & 2 deletions classifier.py
@@ -350,7 +350,7 @@ def _create_examples(self, lines, labels, set_type, prefix=None):

        if FLAGS.shuffle:
            split_text = np.asarray(text.split(' '))
-           shuffle_idx = np.random.permutation(range(1,len(split_text)))
+           shuffle_idx = np.random.permutation(range(1,len(split_text)))
            shuffle_idx = np.insert(shuffle_idx,0,0)
            text = ' '.join(split_text[shuffle_idx])
        adj = np.asarray(adj)
@@ -866,6 +866,7 @@ def main(_):
        files = zip(all_train, all_label)
        print(files)
        train_examples = processor.get_multi_train_examples(files)
+       print(len(train_examples))
    else:
        train_examples = processor.get_train_examples(FLAGS.train_file, FLAGS.train_labels)
    if FLAGS.num_train_epochs == 0:
@@ -1038,7 +1039,8 @@ def main(_):
            for prob in class_prob:
                out_line.append(str(prob))
            writer.write('\t'.join(out_line)+"\n")
-           if i >= num_actual_predict_examples:
+           #if i >= num_actual_predict_examples:
+           if i >= FLAGS.max_nb_preds:
                break
        else:
            output_line = "\t".join(
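Note on the first hunk above: the shuffle pins the first token at position 0 and permutes everything after it. A minimal standalone sketch of that behavior (the toy text is an assumption; the real value comes from _create_examples):

import numpy as np

# Toy input standing in for the example text built in _create_examples.
text = "label tok_a tok_b tok_c"
split_text = np.asarray(text.split(' '))
shuffle_idx = np.random.permutation(range(1, len(split_text)))  # permute indices 1..n-1
shuffle_idx = np.insert(shuffle_idx, 0, 0)                      # pin index 0 at the front
print(' '.join(split_text[shuffle_idx]))  # first token unchanged, rest shuffled
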
12 changes: 6 additions & 6 deletions create_pretraining_data.py
@@ -175,7 +175,7 @@ def gen_snippet_datasetv2(G, feats, var_map, func_map=None, out_path=None, pre='
    open(os.path.join(out_path, pre+name+suffix+'_label_val.txt'), 'w').close()
    if regen_vocab:
        open(os.path.join(out_path, pre+'vocab-code.txt'), 'w').close()
-   if mode=='funcdef':
+   if mode=='methodname':
        open(os.path.join(out_path, pre+'vocab-label.txt'), 'w').close()

    for j in range(nb_snippets):
@@ -193,7 +193,7 @@ def gen_snippet_datasetv2(G, feats, var_map, func_map=None, out_path=None, pre='

        first_tok = snippet[0][0]
        tk = get_name_from_token(feats[first_tok], show_id=False)
-       if mode=='funcdef':
+       if mode=='methodname':
            if tk=='FunctionDef':
                label = gen_func_label(first_tok, func_map)
                if label not in label_voc:
@@ -257,7 +257,7 @@ def gen_snippet_datasetv2(G, feats, var_map, func_map=None, out_path=None, pre='
            sep='\n' if not mode=='magret' else '\n\n'
            f.write(sep)

-       if mode=='funcdef' and (label is not None):
+       if mode=='methodname' and (label is not None):
            with open(os.path.join(out_path, pre+name+suffix+'_label.txt'), 'a') as f:
                f.write(label+'\n')

@@ -300,7 +300,7 @@ def gen_snippet_datasetv2(G, feats, var_map, func_map=None, out_path=None, pre='
            sep='\n' if not mode=='magret' else '\n\n'
            f.write(sep)

-       if mode=='funcdef' and (label is not None):
+       if mode=='methodname' and (label is not None):
            with open(os.path.join(out_path, pre+name+suffix+'_label_val.txt'), 'a') as f:
                f.write(label+'\n')

@@ -351,7 +351,7 @@ def gen_snippet_datasetv2(G, feats, var_map, func_map=None, out_path=None, pre='
        f.write('\n')
    print("Vocabulary length: ", len(voc)+5)

-   if mode=='funcdef':
+   if mode=='methodname':
        with open(os.path.join(out_path, pre+'vocab-label.txt'), 'a') as f:
            for v in label_voc:
                f.write(v)
@@ -363,7 +363,7 @@ def main(args):
    G = json_graph.node_link_graph(G_data)
    var_map = json.load(open(args.path+args.prefix+"-var_map.json"))
    func_map = json.load(open(args.path+args.prefix+"-func_map.json"))
-   if args.mode == 'funcdef':
+   if args.mode == 'methodname':
        regen_vocab = False
    else:
        regen_vocab = args.regen_vocab
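All six edits in this file are one rename: the 'funcdef' mode becomes 'methodname'. The last hunk shows the knock-on effect in main, where that mode forces vocabulary regeneration off. A runnable sketch of that gate (the flag wiring is an assumption; the real script's argument parser is not shown in the diff):

import argparse

# Hypothetical flag wiring standing in for the script's real parser.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='methodname')
parser.add_argument('--regen_vocab', action='store_true')
args = parser.parse_args([])

if args.mode == 'methodname':
    regen_vocab = False   # vocabulary regeneration is forced off in this mode
else:
    regen_vocab = args.regen_vocab
print(regen_vocab)        # False under 'methodname'
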
9 changes: 4 additions & 5 deletions modeling.py
@@ -922,7 +922,6 @@ def transformer_model(input_tensor,
            intermediate_size,
            activation=intermediate_act_fn,
            kernel_initializer=create_initializer(initializer_range))
-
        # Down-project back to `hidden_size` then add the residual.
        with tf.variable_scope("output"):
            layer_output = tf.layers.dense(
@@ -939,10 +938,10 @@ def transformer_model(input_tensor,
        for layer_output in all_layer_outputs:
            final_output = reshape_from_matrix(layer_output, input_shape)
            final_outputs.append(final_output)
-       #for att_output_prob in all_attention_output_probs:
-       #    final_att_output = reshape_from_matrix(att_output_prob, input_shape)
-       #    final_attention_outputs.append(final_att_output)
-       return final_outputs, attention_output_probs #final_attention_outputs
+       for att_output_prob in all_attention_output_probs:
+           #final_att_output = reshape_from_matrix(att_output_prob, input_shape)
+           final_attention_outputs.append(att_output_prob)
+       return final_outputs, tf.concat(final_attention_outputs, axis=-1)
    else:
        final_output = reshape_from_matrix(prev_output, input_shape)
        return final_output, attention_output_probs
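With do_return_all_layers set, the function now returns the per-layer attention probabilities concatenated on the last axis instead of the raw list. A shape-level sketch under assumed dimensions (batch 1, two layers, 64x64 maps; the [batch, seq, seq] shape of each entry is an assumption):

import numpy as np
import tensorflow as tf

# Two stand-in per-layer attention maps.
all_attention_output_probs = [
    tf.constant(np.zeros((1, 64, 64), dtype=np.float32)),
    tf.constant(np.ones((1, 64, 64), dtype=np.float32)),
]

final_attention_outputs = []
for att_output_prob in all_attention_output_probs:
    final_attention_outputs.append(att_output_prob)

# Last-axis concat: two [1, 64, 64] tensors become one [1, 64, 128] tensor,
# which is the layout the attention dump in run_pretraining.py indexes into.
stacked = tf.concat(final_attention_outputs, axis=-1)
print(stacked.shape)  # (1, 64, 128)
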
4 changes: 1 addition & 3 deletions prepare_pretraining_data.py
@@ -266,9 +266,6 @@ def create_training_instances_with_adj(input_files, adj_files, tokenizer, max_se
    all_documents = [x for x in all_documents if x]
    all_adjs = [x for x in all_adjs if x]
    print(len(all_documents), len(all_adjs))
-   # c = list(zip(all_adjs, all_documents))
-   # rng.shuffle(c)
-   # all_adjs, all_documents = zip(*c)

    vocab_words = list(tokenizer.vocab.keys())
    instances = []
@@ -503,6 +500,7 @@ def create_instances_from_document_with_adj(
    segment_ids = []
    # tokens.append("[CLS]")
    # segment_ids.append(0)
+   print(tokens_a)
    for token in tokens_a:
        tokens.append(token)
        segment_ids.append(0)
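The deleted comment block in the first hunk was a joint shuffle of documents and adjacency matrices; the zip idiom it used reorders both lists while keeping them paired index for index. For reference, a runnable sketch with toy data (the stand-in lists and rng seed are assumptions):

import random

# Toy stand-ins for the parallel lists built in create_training_instances_with_adj.
all_documents = [['doc0'], ['doc1'], ['doc2']]
all_adjs = [[[0]], [[1]], [[2]]]

rng = random.Random(12345)            # seed is an assumption
c = list(zip(all_adjs, all_documents))
rng.shuffle(c)
all_adjs, all_documents = zip(*c)
print(all_adjs)
print(all_documents)                  # pairing with all_adjs is preserved
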
17 changes: 12 additions & 5 deletions run_pretraining.py
@@ -275,6 +275,7 @@ def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
"masked_lm_positions": masked_lm_positions,
"label_ids": label_ids,
"input_ids": input_ids,
"adjacency": adjacency,
"attention_outputs": attention_output
}

@@ -512,6 +513,7 @@ def main(_):
    tf.gfile.MakeDirs(FLAGS.output_dir)

    input_files = []
+   print(FLAGS.input_file)
    for input_pattern in FLAGS.input_file.split(","):
        input_files.extend(tf.gfile.Glob(input_pattern))

@@ -662,13 +664,18 @@ def main(_):
    with tf.gfile.GFile(os.path.join(FLAGS.output_dir, 'eval_results_att.txt'), 'w') as writer:
        for r in estimator.predict(input_fn, yield_single_examples=True):
            att = r['attention_outputs']
+           adj = r['adjacency']
            break

-       for i in range(12):
-           for j in range(64):
-               for k in range(64):
-                   writer.write("%s " % str(att[i][j][k]))
-               writer.write("\n")
+       for l in range(3):
+           for i in range(6):
+               for j in range(64):
+                   for k in range(64):
+                       writer.write("%s " % str(att[i][j][64*l+k]))
+                   writer.write("\n")
+   with tf.gfile.GFile(os.path.join(FLAGS.output_dir, 'eval_results_adj.txt'), 'w') as writer:
+       writer.write("%s " % str(list(adj)))
+       writer.write("\n")

if __name__ == "__main__":
    flags.mark_flag_as_required("input_file")
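The rewritten dump loop implies attention_outputs now carries three 64-wide blocks concatenated on its last axis (matching the tf.concat added in modeling.py), with the new outer l loop walking those blocks. A standalone sketch of the indexing, with a zero array standing in for the model output:

import numpy as np

# Stand-in for r['attention_outputs']; the (6, 64, 3*64) shape is inferred
# from the loop bounds in the diff, not confirmed by the source.
att = np.zeros((6, 64, 3 * 64))

with open('eval_results_att.txt', 'w') as writer:
    for l in range(3):                   # which 64-wide block of the last axis
        for i in range(6):               # first axis (layer or head, assumed)
            for j in range(64):          # query position
                for k in range(64):      # key position within block l
                    writer.write("%s " % str(att[i][j][64 * l + k]))
                writer.write("\n")       # one row per (l, i, j)
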
