-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathtransform.py
371 lines (321 loc) · 15.9 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
# (C) Copyright IBM Corp. 2024.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import atexit
import json
import os
import shutil
import subprocess
import uuid
from argparse import ArgumentParser, Namespace
from typing import Any
import pyarrow as pa
from data_processing.transform import (
AbstractBinaryTransform,
AbstractTableTransform,
TransformConfiguration,
)
from data_processing.utils import CLIArgumentProvider, get_logger
from dpk_code_profiler.higher_order_concepts import *
from dpk_code_profiler.profiler_report import *
from dpk_code_profiler.semantic_concepts import *
from tree_sitter import Language
from tree_sitter import Parser as TSParser
from tree_sitter_languages import get_language
from dpk_code_profiler.UAST_parser import UASTParser, uast_read
# Short transform name; also used to derive the CLI argument prefix below.
short_name = "CodeProfiler"
# Prefix expected by CLIArgumentProvider.capture_parameters in apply_input_params.
cli_prefix = f"{short_name}_"
# Default names of the input-table columns: the programming-language label
# column and the code-snippet column.
language = "language"
contents = "contents"
class CodeProfilerTransform(AbstractTableTransform):
    """
    Profiles rows of code snippets in a pyarrow Table.

    For each row it parses the snippet into a UAST (Universal Abstract Syntax
    Tree) with tree-sitter, extracts imported packages, maps those packages to
    semantic concepts via a knowledge base, and appends higher-order syntactic
    metrics — all as new columns on the table.
    """

    def __init__(self, config: dict[str, Any]):
        """
        Initialize based on the dictionary of configuration information.

        Recognized config keys:
          contents       - name of the column holding code snippets (default "contents")
          language       - name of the column holding the language label (default "language")
          ikb_file       - path to the internal knowledge base CSV
          null_libs_file - path to the CSV recording unresolved libraries
          metrics_list   - higher-order metrics to compute
        """
        super().__init__(config)
        self.contents = self.config.get("contents", "contents")
        self.language = self.config.get("language", "language")
        if not isinstance(self.contents, str):
            raise ValueError(f"'contents' should be a string, got {type(self.contents).__name__}")
        # Robustness fix: the original validated only 'contents'; 'language'
        # is used identically as a column name, so validate it the same way.
        if not isinstance(self.language, str):
            raise ValueError(f"'language' should be a string, got {type(self.language).__name__}")

        def ensure_tree_sitter_bindings():
            """Clone the pre-built tree-sitter bindings into a unique local directory."""
            # Get the directory where the script is located
            script_dir = os.path.dirname(os.path.abspath(__file__))
            # A UUID-suffixed directory so concurrent instances never collide
            bindings_dir = os.path.join(script_dir, f"tree-sitter-bindings-{uuid.uuid4()}")
            # Clone the bindings only if the unique directory does not exist
            if not os.path.exists(bindings_dir):
                print(f"Cloning tree-sitter bindings into {bindings_dir}...")
                result = subprocess.run(
                    [
                        "git",
                        "clone",
                        "https://github.com/pankajskku/tree-sitter-bindings.git",
                        bindings_dir,
                    ]
                )
                if result.returncode != 0:
                    raise RuntimeError(f"Failed to clone tree-sitter bindings into {bindings_dir}")
            return bindings_dir

        # Fetch the shared-object bindings before any Language() is constructed.
        self.bindings_dir = ensure_tree_sitter_bindings()
        # Use the correct architecture for runtime
        RUNTIME_HOST_ARCH = os.environ.get("RUNTIME_HOST_ARCH", "x86_64")
        bindings_path = self.bindings_dir + "/" + RUNTIME_HOST_ARCH  # for MAC: mach-arm64
        print(f"Bindings bindings_dir: {self.bindings_dir}")
        print(f"Bindings path: {bindings_path}")
        # Check if the bindings path exists
        if not os.path.exists(bindings_path):
            raise FileNotFoundError(f"Bindings path does not exist: {bindings_path}")
        try:
            # Languages shipped with tree_sitter_languages come from get_language();
            # the rest load from the cloned .so bindings.
            AGDA_LANGUAGE = Language(os.path.join(bindings_path, "agda-bindings.so"), "agda")
            C_LANGUAGE = get_language("c")
            CPP_LANGUAGE = get_language("cpp")
            CSHARP_LANGUAGE = Language(os.path.join(bindings_path, "c_sharp-bindings.so"), "c_sharp")
            D_LANGUAGE = Language(os.path.join(bindings_path, "d-bindings.so"), "d")
            DART_LANGUAGE = Language(os.path.join(bindings_path, "dart-bindings.so"), "dart")
            ELM_LANGUAGE = Language(os.path.join(bindings_path, "elm-bindings.so"), "elm")
            GOLANG_LANGUAGE = Language(os.path.join(bindings_path, "go-bindings.so"), "go")
            HASKELL_LANGUAGE = Language(os.path.join(bindings_path, "haskell-bindings.so"), "haskell")
            JAVA_LANGUAGE = get_language("java")
            JAVASCRIPT_LANGUAGE = Language(os.path.join(bindings_path, "js-bindings.so"), "javascript")
            KOTLIN_LANGUAGE = Language(os.path.join(bindings_path, "kotlin-bindings.so"), "kotlin")
            NIM_LANGUAGE = Language(os.path.join(bindings_path, "nim-bindings.so"), "nim")
            # OBJECTIVE_C_LANGUAGE = Language(os.path.join(bindings_path, 'objc-bindings.so'), 'objc')
            OCAML_LANGUAGE = get_language("ocaml")
            PERL_LANGUAGE = get_language("perl")
            PY_LANGUAGE = get_language("python")
            QMLJS_LANGUAGE = Language(os.path.join(bindings_path, "qmljs-bindings.so"), "qmljs")
            RUST_LANGUAGE = get_language("rust")
            SCALA_LANGUAGE = Language(os.path.join(bindings_path, "scala-bindings.so"), "scala")
            TYPESCRIPT_LANGUAGE = get_language("typescript")
        except Exception as e:
            # Remove the cloned bindings even on failure, then re-raise.
            self.clean_bindings()
            raise Exception("Bindings are not loaded", e)
        # The shared objects are loaded into memory at this point, so the
        # on-disk clone is no longer needed.
        self.clean_bindings()
        # Language map for supported languages
        self.language_map = {
            "Agda": AGDA_LANGUAGE,
            "C": C_LANGUAGE,
            "C#": CSHARP_LANGUAGE,
            "Csharp": CSHARP_LANGUAGE,
            "Cpp": CPP_LANGUAGE,
            "D": D_LANGUAGE,
            "Dart": DART_LANGUAGE,
            "Elm": ELM_LANGUAGE,
            "Go": GOLANG_LANGUAGE,
            "Haskell": HASKELL_LANGUAGE,
            "Java": JAVA_LANGUAGE,
            "JavaScript": JAVASCRIPT_LANGUAGE,
            "Kotlin": KOTLIN_LANGUAGE,
            "Nim": NIM_LANGUAGE,
            "Ocaml": OCAML_LANGUAGE,
            # "Objective-C": OBJECTIVE_C_LANGUAGE,
            "Perl": PERL_LANGUAGE,
            "Python": PY_LANGUAGE,
            "Qmljs": QMLJS_LANGUAGE,
            "Rust": RUST_LANGUAGE,
            "Scala": SCALA_LANGUAGE,
            "TypeScript": TYPESCRIPT_LANGUAGE,
        }
        # Maps the same language labels to the ruleset keys used by UASTParser.
        self.uast_language_map = {
            "Agda": "agda",
            "C": "c",
            "C#": "c_sharp",
            "Csharp": "c_sharp",
            "C++": "cpp",
            "Cpp": "cpp",
            "D": "d",
            "Dart": "dart",
            "Elm": "elm",
            "Go": "go",
            "Haskell": "haskell",
            "Java": "java",
            "JavaScript": "js",
            "Kotlin": "kotlin",
            "Nim": "nim",
            "Ocaml": "ocaml",
            # "Objective-C": 'objc',
            "Perl": "perl",
            "Python": "py",
            "Qmljs": "qmljs",
            "Rust": "rust",
            "Scala": "scala",
            "TypeScript": "typescript",
        }
        self.logger = get_logger(__name__)
        self.ruleset_file = os.path.dirname(os.path.abspath(__file__))
        # Semantic profiling related inits
        self.ikb_file = config.get("ikb_file", "semantic-ruleset/ikb_model.csv")
        self.null_libs_file = config.get("null_libs_file", "semantic-ruleset/null_libs.csv")
        src_file_dir = os.path.abspath(os.path.dirname(__file__))
        # Check if the file exists; if not, update the default path
        if not os.path.exists(self.ikb_file):
            print(f"File not found at {self.ikb_file}. Updating to '../semantic-ruleset/ikb_model.csv'")
            self.ikb_file = os.path.join(src_file_dir, "semantic-ruleset/ikb_model.csv")
            # Raise an error if the file still doesn't exist
            if not os.path.exists(self.ikb_file):
                raise FileNotFoundError(f"File not found: {self.ikb_file}")
        # Check if the file exists; if not, update the default path
        if not os.path.exists(self.null_libs_file):
            print(f"File not found at {self.null_libs_file}. Updating to '../semantic-ruleset/null_libs.csv'")
            self.null_libs_file = os.path.join(src_file_dir, "semantic-ruleset/null_libs.csv")
            # Raise an error if the file still doesn't exist
            if not os.path.exists(self.null_libs_file):
                raise FileNotFoundError(f"File not found: {self.null_libs_file}")
        # Higher order semantic features
        self.metrics_list = config.get("metrics_list", ["CCR", "code_snippet_len", "avg_fn_len_in_snippet"])

    def transform(self, table: pa.Table, file_name: str = None) -> tuple[list[pa.Table], dict[str, Any]]:
        """
        Extract syntactic constructs, semantic concepts and higher-order
        metrics for every row of `table`.

        Returns a one-element list with the augmented table plus a stats dict.
        """
        print("Transforming the input dataframe")  # fixed duplicated word in the log message
        ts_parser = TSParser()
        uast_parser = UASTParser()
        # Hoisted out of get_uast_json: these maps are invariant across rows,
        # but the original rebuilt them once per row.
        language_map_lower = {key.lower(): value for key, value in self.language_map.items()}
        uast_language_map_lower = {key.lower(): value for key, value in self.uast_language_map.items()}

        def get_uast_json(code, lang):
            """Parse `code` written in `lang`; return the UAST JSON dict, or None if unsupported."""
            # Language labels are matched case-insensitively.
            lang_lower = lang.lower()
            if lang_lower in language_map_lower:
                ts_parser.set_language(language_map_lower[lang_lower])
                uast_parser.set_language(uast_language_map_lower[lang_lower])
                ast = ts_parser.parse(bytes(code, encoding="utf8"))
                uast = uast_parser.parse(ast, code)
                return uast.get_json()
            return None

        def extract_packages_from_uast(uast_json):
            """Extract package names from the UAST JSON where node_type is 'uast_package'.

            Always returns a comma-separated string (possibly empty).
            """
            package_list = []
            try:
                uast_data = json.loads(uast_json)
                if uast_data is None:
                    print("Warning: uast_data is None. Check the data source or initialization process.")
                    # Fix: the original did a bare `return` (None) here, breaking
                    # the comma-separated-string contract of this helper.
                    return ""
                nodes = uast_data.get("nodes", {})
                # Iterate through nodes to find nodes with type 'uast_package'
                for node_data in nodes.values():
                    if node_data.get("node_type") == "uast_package":
                        # The package name follows 'uast_package ' in the snippet.
                        parts = node_data["code_snippet"].split(" ")
                        if len(parts) > 1:  # guard against malformed snippets (was an IndexError)
                            package_list.append(parts[1])
            except json.JSONDecodeError as e:
                print(f"Failed to parse UAST JSON: {e}")
            return ",".join(package_list)  # Return as a comma-separated string

        def get_uast_parquet(tmp_table):
            """Append 'UAST' and 'UAST_Package_List' columns to `tmp_table`."""
            # Extract language and content arrays from the table using PyArrow
            print(self.language)
            lang_array = tmp_table.column(self.language)
            content_array = tmp_table.column(self.contents)
            # Ensure both arrays have the same length
            assert len(lang_array) == len(content_array)
            # Generate UASTs using a list comprehension
            uasts = [
                json.dumps(get_uast_json(content_array[i].as_py(), lang_array[i].as_py()))
                for i in range(len(content_array))
            ]
            # Extract package lists from the UAST column
            package_lists = [extract_packages_from_uast(uast) for uast in uasts]
            # Add the UAST array as a new column in the PyArrow table
            uast_column = pa.array(uasts)
            package_list_column = pa.array(package_lists)
            tmp_table_with_uast = tmp_table.append_column("UAST", uast_column)
            # Add the uast_package column
            return tmp_table_with_uast.append_column("UAST_Package_List", package_list_column)

        table_with_uast = get_uast_parquet(table)

        ## Semantic profiling
        self.logger.debug(f"Semantic profiling of one table with {len(table_with_uast)} rows")
        # Load Knowledge Base
        print(self.ikb_file)
        print(self.null_libs_file)
        ikb = knowledge_base(self.ikb_file, self.null_libs_file)
        ikb.load_ikb_trie()
        # Extract concept from IKB
        libraries = table_with_uast.column("UAST_Package_List").to_pylist()
        # Fix: use the configured language column name instead of the
        # hard-coded "language" (consistent with get_uast_parquet above).
        languages = table_with_uast.column(self.language).to_pylist()
        concepts = [concept_extractor(lib, lang, ikb) for lib, lang in zip(libraries, languages)]
        # Append concepts column to table and record unknown libraries
        table_with_uast = table_with_uast.append_column("Concepts", pa.array(concepts))
        ikb.write_null_files()

        # Higher order syntactic profiler
        self.logger.debug(f"Transforming one table with {len(table_with_uast)} rows")
        if self.metrics_list is not None:
            uasts = [uast_read(uast_json) for uast_json in table_with_uast["UAST"].to_pylist()]
            ccrs = []
            code_snippet_len = []
            avg_fn_len_in_snippet = []
            for uast in uasts:
                if "CCR" in self.metrics_list:
                    ccrs.append(extract_ccr(uast))
                if "code_snippet_len" in self.metrics_list:
                    code_snippet_len.append(extract_code_snippet_length(uast))
                if "avg_fn_len_in_snippet" in self.metrics_list:
                    avg_fn_len_in_snippet.append(extract_code_avg_fn_len_in_snippet(uast))
            if "CCR" in self.metrics_list:
                table_with_uast = table_with_uast.append_column("CCR", pa.array(ccrs))
            if "code_snippet_len" in self.metrics_list:
                table_with_uast = table_with_uast.append_column("code_snippet_len", pa.array(code_snippet_len))
            if "avg_fn_len_in_snippet" in self.metrics_list:
                table_with_uast = table_with_uast.append_column(
                    "avg_fn_len_in_snippet", pa.array(avg_fn_len_in_snippet)
                )
        self.logger.debug(f"Transformed one table with {len(table_with_uast)} rows")
        metadata = {"nfiles": 1, "nrows": len(table_with_uast)}
        # Report generation
        if "UAST" in table_with_uast.schema.names and "Concepts" in table_with_uast.schema.names:
            generate_report(table_with_uast, self.metrics_list)
        self.logger.debug(f"Transformed one table with {len(table_with_uast)} rows")
        # report statistics
        # NOTE(review): these report column counts, not document counts, despite
        # the key names — kept as-is for backward compatibility; confirm intent.
        stats = {
            "source_documents": table.num_columns,
            "result_documents": table_with_uast.num_columns,
        }
        return [table_with_uast], stats

    def clean_bindings(self):
        """Remove the per-instance tree-sitter bindings clone from disk."""
        try:
            # shutil.rmtree replaces the original `rm -rf` subprocess call:
            # portable, no external process, and shutil is already imported.
            shutil.rmtree(self.bindings_dir)
            print(f"Successfully deleted: {self.bindings_dir}")
        except FileNotFoundError:
            # Already gone -- matches `rm -rf`'s silent-success semantics
            # (this method is intentionally called twice on the success path).
            print(f"Successfully deleted: {self.bindings_dir}")
        except OSError as e:
            print(f"Error deleting {self.bindings_dir}: {e}")
class CodeProfilerTransformConfiguration(TransformConfiguration):
    """Wires the code-profiler transform into the runtime's CLI machinery."""

    def __init__(self, transform_class: type[AbstractBinaryTransform] = CodeProfilerTransform):
        super().__init__(name=short_name, transform_class=transform_class)

    def add_input_params(self, parser: ArgumentParser) -> None:
        """Register this transform's command-line options on `parser`."""
        # Table-driven registration: (flag name, default value, help text).
        options = (
            (language, "language", "Column name that denotes the programming language"),
            (contents, "contents", "Column name that contains code snippets"),
        )
        for flag, default_value, help_text in options:
            parser.add_argument(f"--{flag}", type=str, default=default_value, help=help_text)

    def apply_input_params(self, args: Namespace) -> bool:
        """Capture the parsed CLI arguments (under `cli_prefix`) into self.params."""
        self.params = CLIArgumentProvider.capture_parameters(args, cli_prefix, False)
        return True