10
10
from pathlib import Path
11
11
from sklearn import model_selection as sklearn_model_selection
12
12
13
# Sentinel tokens substituted for method names / numeric literals in paths.
METHOD_NAME, NUM = "METHODNAME", "NUM"


def _str2bool(value):
    """Parse a command-line boolean flag value.

    argparse's ``type=bool`` is a well-known trap: ``bool(s)`` is True for
    every non-empty string, so ``--use_nums False`` used to silently parse
    as True.  This parser accepts the usual true/false spellings and
    rejects anything else with a clear error.

    :param value: the raw command-line string (or an existing bool).
    :returns: the parsed boolean.
    :raises argparse.ArgumentTypeError: on an unrecognized spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", required=True, type=str)
parser.add_argument("--valid_p", type=float, default=0.2)
parser.add_argument("--max_path_length", type=int, default=8)
parser.add_argument("--max_path_width", type=int, default=2)
# These previously used ``type=bool``, which treated ANY non-empty string
# (including "False") as True; _str2bool fixes that while keeping the
# same flag names and defaults.
parser.add_argument("--use_method_name", type=_str2bool, default=True)
parser.add_argument("--use_nums", type=_str2bool, default=True)
parser.add_argument("--output_dir", required=True, type=str)
parser.add_argument("--n_jobs", type=int, default=multiprocessing.cpu_count())
parser.add_argument("--seed", type=int, default=239)
25
25
26
26
27
27
def __collect_asts (json_file ):
28
28
asts = []
29
- with open (json_file , 'r' , encoding = ' utf-8' ) as f :
29
+ with open (json_file , "r" , encoding = " utf-8" ) as f :
30
30
for line in f :
31
31
ast = json .loads (line .strip ())
32
32
asts .append (ast )
@@ -42,22 +42,22 @@ def dfs(v):
42
42
43
43
v_node = ast [v ]
44
44
45
- if ' value' in v_node :
45
+ if " value" in v_node :
46
46
if v == node_index : # Top-level func def node.
47
47
if args .use_method_name :
48
48
paths .append ((stack .copy (), METHOD_NAME ))
49
49
else :
50
- v_type = v_node [' type' ]
50
+ v_type = v_node [" type" ]
51
51
52
- if v_type .startswith (' Name' ):
53
- paths .append ((stack .copy (), v_node [' value' ]))
54
- elif args .use_nums and v_type == ' Num' :
52
+ if v_type .startswith (" Name" ):
53
+ paths .append ((stack .copy (), v_node [" value" ]))
54
+ elif args .use_nums and v_type == " Num" :
55
55
paths .append ((stack .copy (), NUM ))
56
56
else :
57
57
pass
58
58
59
- if ' children' in v_node :
60
- for child in v_node [' children' ]:
59
+ if " children" in v_node :
60
+ for child in v_node [" children" ]:
61
61
dfs (child )
62
62
63
63
stack .pop ()
@@ -84,12 +84,13 @@ def __raw_tree_paths(ast, node_index, args):
84
84
85
85
tree_paths = []
86
86
for (v_path , v_value ), (u_path , u_value ) in itertools .combinations (
87
- iterable = tnodes ,
88
- r = 2 ,
87
+ iterable = tnodes ,
88
+ r = 2 ,
89
89
):
90
90
prefix , lca , suffix = __merge_terminals2_paths (v_path , u_path )
91
- if (len (prefix ) + 1 + len (suffix ) <= args .max_path_length ) \
92
- and (abs (len (prefix ) - len (suffix )) <= args .max_path_width ):
91
+ if (len (prefix ) + 1 + len (suffix ) <= args .max_path_length ) and (
92
+ abs (len (prefix ) - len (suffix )) <= args .max_path_width
93
+ ):
93
94
path = prefix + [lca ] + suffix
94
95
tree_path = v_value , path , u_value
95
96
tree_paths .append (tree_path )
@@ -103,49 +104,49 @@ def __delim_name(name):
103
104
104
105
def camel_case_split(identifier):
    """Split *identifier* at camelCase boundaries.

    ``"camelCase"`` -> ``["camel", "Case"]``; a run of capitals such as
    ``"HTTPServer"`` becomes ``["HTTP", "Server"]``.  An empty string
    yields an empty list.
    """
    # The pattern has only zero-width and non-capturing groups, so
    # findall returns the full matches directly.
    return re.findall(
        ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)",
        identifier,
    )
110
111
111
112
blocks = []
112
- for underscore_block in name .split ('_' ):
113
+ for underscore_block in name .split ("_" ):
113
114
blocks .extend (camel_case_split (underscore_block ))
114
115
115
- return '|' .join (block .lower () for block in blocks )
116
+ return "|" .join (block .lower () for block in blocks )
116
117
117
118
118
119
def __collect_sample(ast, fd_index, args):
    """Build one path-context sample for the function rooted at *fd_index*.

    The node at *fd_index* must be a ``FunctionDef``; its name is the
    prediction target and every terminal-to-terminal tree path becomes a
    ``start,TypeA|TypeB|...,finish`` context string.

    :returns: ``"<target> <ctx1> <ctx2> ..."`` or ``None`` when the
        function yields no contexts.
    :raises ValueError: if the node at *fd_index* is not a FunctionDef.
    """
    root = ast[fd_index]
    if root["type"] != "FunctionDef":
        raise ValueError("Wrong node type.")

    target = root["value"]

    contexts = []
    for start, connector, finish in __raw_tree_paths(ast, fd_index, args):
        left = __delim_name(start)
        right = __delim_name(finish)
        # Inner nodes are rendered as a '|'-joined chain of AST types.
        middle = "|".join(ast[node]["type"] for node in connector)
        contexts.append(f"{left},{middle},{right}")

    if not contexts:
        return None

    # code2vec text format: target token, then space-separated contexts.
    return f"{__delim_name(target)} {' '.join(contexts)}"
143
144
144
145
145
146
def __collect_samples (ast , args ):
146
147
samples = []
147
148
for node_index , node in enumerate (ast ):
148
- if node [' type' ] == ' FunctionDef' :
149
+ if node [" type" ] == " FunctionDef" :
149
150
sample = __collect_sample (ast , node_index , args )
150
151
if sample is not None :
151
152
samples .append (sample )
@@ -160,18 +161,18 @@ def __collect_all_and_save(asts, args, output_file):
160
161
samples = parallel (func (ast , args ) for ast in tqdm .tqdm (asts ))
161
162
samples = list (itertools .chain .from_iterable (samples ))
162
163
163
- with open (output_file , 'w' ) as f :
164
+ with open (output_file , "w" ) as f :
164
165
for line_index , line in enumerate (samples ):
165
- f .write (line + ('' if line_index == len (samples ) - 1 else ' \n ' ))
166
+ f .write (line + ("" if line_index == len (samples ) - 1 else " \n " ))
166
167
167
168
168
169
def main ():
169
170
args = parser .parse_args ()
170
171
np .random .seed (args .seed )
171
172
172
173
data_dir = Path (args .data_dir )
173
- trains = __collect_asts (data_dir / ' python100k_train.json' )
174
- evals = __collect_asts (data_dir / ' python50k_eval.json' )
174
+ trains = __collect_asts (data_dir / " python100k_train.json" )
175
+ evals = __collect_asts (data_dir / " python50k_eval.json" )
175
176
176
177
train , valid = sklearn_model_selection .train_test_split (
177
178
trains ,
@@ -182,12 +183,12 @@ def main():
182
183
output_dir = Path (args .output_dir )
183
184
output_dir .mkdir (exist_ok = True )
184
185
for split_name , split in zip (
185
- ( ' train' , ' valid' , ' test' ),
186
- (train , valid , test ),
186
+ ( " train" , " valid" , " test" ),
187
+ (train , valid , test ),
187
188
):
188
- output_file = output_dir / f' { split_name } _output_file.txt'
189
+ output_file = output_dir / f" { split_name } _output_file.txt"
189
190
__collect_all_and_save (split , args , output_file )
190
191
191
192
192
# Script entry point: parse CLI arguments and build the path-context
# train/valid/test datasets.
if __name__ == "__main__":
    main()
0 commit comments