diff --git a/make_tf_record.py b/make_tf_record.py index 26c2af7..b5a148a 100644 --- a/make_tf_record.py +++ b/make_tf_record.py @@ -1,32 +1,35 @@ import argparse +import os import os.path import sys +from pathlib import Path # disable tensorflow warnings on import: -import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' import tensorflow as tf -from pathlib import Path from tensorflow.io import TFRecordWriter # some code below adapted from # keras.io/examples/keras_recipes/creating_tfrecords -def bytes_feature(value: str|bytes): +def bytes_feature(value: str | bytes): """Returns a bytes_list from a string / byte.""" if isinstance(value, str): - value = value.encode() + value = value.encode(encoding="utf-8") return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + def int64_feature(value): """Returns an int64_list from a bool / enum / int / uint.""" return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + def create_example(source_filename: str, is_obfuscated: bool): # note: not all files are in utf-8, so we don't decode yet + # TODO print warning if file is not in UTF-8 source_file_data = Path(source_filename).read_bytes() feature = { @@ -36,23 +39,24 @@ def create_example(source_filename: str, is_obfuscated: bool): } return tf.train.Example(features=tf.train.Features(feature=feature)) + def main(): arg_parser = argparse.ArgumentParser( description="Packs source files into tfrecord format with a label as either obfuscated or non-obfuscated" ) arg_parser.add_argument("-d", "--dir", required=True, type=Path, - help="output directory for tfrecord files") + help="output directory for tfrecord files") arg_parser.add_argument("-f", "--files", required=True, type=Path, - help="list of source files to process, one path per line") + help="list of source files to process, one path per line") arg_parser.add_argument("-l", "--label", required=True, type=int, - help="obfuscation label for files, either 0 (non-obfuscated) or 1 (obfuscated)") + 
help="obfuscation label for files, either 0 (non-obfuscated) or 1 (obfuscated)") arg_parser.add_argument("-n", "--nrecords", default="65536", type=int, metavar="NUM", - help="number of files to store in each tfrecord file (default 65536)") + help="number of files to store in each tfrecord file (default 65536)") arg_parser.add_argument("-p", "--prefix", default="", - help="prefix to add to generated tfrecord filenames") + help="prefix to add to generated tfrecord filenames") arg_parser.add_argument("-v", "--verbose", action="store_true", default=False, - help="print out verbose progress info") + help="print out verbose progress info") parsed_args = arg_parser.parse_args() @@ -72,41 +76,42 @@ def main(): return 1 if not os.path.exists(files_list): - print(f"path to file list does not exist: {non_obfuscated_js_files_list}") + print(f"path to file list does not exist: {files_list}") return 1 if not record_dir.exists(): os.mkdir(record_dir) - record_number = 1 + record_num = 1 with open(files_list, "rt") as source_files: eof = False while not eof: - record_filename = f"{prefix}{record_number}.tfrec" + record_filename = f"{prefix}{record_num}.tfrec" with TFRecordWriter(bytes(record_dir / record_filename)) as writer: - sample_number = 0 + sample_num = 0 while not eof: source_filename = source_files.readline().strip() if not source_filename: eof = True break if print_progress: - print(f"[{record_filename} {sample_number % samples_per_record}/{samples_per_record}] {source_filename}") + print(f"[{record_filename} {sample_num % samples_per_record}/{samples_per_record}]", + source_filename) example = create_example(source_filename, obfuscation_label) writer.write(example.SerializeToString()) - sample_number += 1 - if sample_number % samples_per_record == 0: + sample_num += 1 + if sample_num % samples_per_record == 0: break - record_number += 1 + record_num += 1 return 0 + if __name__ == "__main__": ret = main() sys.exit(ret) - diff --git a/requirements.txt b/requirements.txt
index 51078cd..568a162 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,18 +8,18 @@ gast==0.4.0 google-auth==2.22.0 google-auth-oauthlib==1.0.0 google-pasta==0.2.0 -grpcio==1.56.2 +grpcio==1.57.0 h5py==3.9.0 idna==3.4 keras==2.13.1 libclang==16.0.6 -Markdown==3.4.3 +Markdown==3.4.4 MarkupSafe==2.1.3 numpy==1.24.3 oauthlib==3.2.2 opt-einsum==3.3.0 packaging==23.1 -protobuf==4.23.4 +protobuf==4.24.2 pyasn1==0.5.0 pyasn1-modules==0.3.0 requests==2.31.0 @@ -31,9 +31,9 @@ tensorboard==2.13.0 tensorboard-data-server==0.7.1 tensorflow==2.13.0 tensorflow-estimator==2.13.0 -tensorflow-io-gcs-filesystem==0.32.0 +tensorflow-io-gcs-filesystem==0.33.0 termcolor==2.3.0 typing_extensions==4.5.0 urllib3==1.26.16 -Werkzeug==2.3.6 +Werkzeug==2.3.7 wrapt==1.15.0 diff --git a/tokenizer_training.py b/tokenizer_training.py index 91558c4..10c942b 100644 --- a/tokenizer_training.py +++ b/tokenizer_training.py @@ -28,7 +28,7 @@ def train_model(input_filenames: Iterable[str], verbose: bool) -> io.BytesIO: min_log_level = LogLevel.INFO if verbose else LogLevel.WARNING - def filename_to_sentence(filename: str) -> bytes: + def filename_to_sentence(filename: str) -> str: return Path(filename).read_text(errors="ignore") model = io.BytesIO()