Skip to content

Commit

Permalink
bump requirements.txt versions, minor tweaks to python scripts
Browse files (browse the repository at this point in the history)
Signed-off-by: Max Fisher <[email protected]>
  • Loading branch information
maxfisher-g committed Aug 28, 2023
1 parent 48603a3 commit 504dc13
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 26 deletions.
44 changes: 24 additions & 20 deletions make_tf_record.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
import argparse
import os
import os.path
import sys
from pathlib import Path

# disable tensorflow warnings on import:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from pathlib import Path
from tensorflow.io import TFRecordWriter


# some code below adapted from
# keras.io/examples/keras_recipes/creating_tfrecords

def bytes_feature(value: str | bytes):
    """Return a ``tf.train.Feature`` holding ``value`` in a one-element bytes_list.

    Args:
        value: The payload to store. A ``str`` is encoded to UTF-8 bytes
            first; any other value is handed to ``BytesList`` unchanged.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` contains exactly one entry.
    """
    if isinstance(value, str):
        # Encode exactly once. The encoding is spelled out for clarity; it
        # matches the default used by str.encode().
        value = value.encode(encoding="utf-8")

    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def int64_feature(value):
    """Return a ``tf.train.Feature`` wrapping one bool / enum / int / uint.

    The value is placed in a single-element ``Int64List``, which becomes the
    feature's ``int64_list``.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def create_example(source_filename: str, is_obfuscated: bool):
# note: not all files are in utf-8, so we don't decode yet
# TODO print warning if file is not in UTF-8
source_file_data = Path(source_filename).read_bytes()

feature = {
Expand All @@ -36,23 +39,24 @@ def create_example(source_filename: str, is_obfuscated: bool):
}
return tf.train.Example(features=tf.train.Features(feature=feature))


def main():
arg_parser = argparse.ArgumentParser(
description="Packs source files into tfrecord format with a label as either obfuscated or non-obfuscated"
)

arg_parser.add_argument("-d", "--dir", required=True, type=Path,
help="output directory for tfrecord files")
help="output directory for tfrecord files")
arg_parser.add_argument("-f", "--files", required=True, type=Path,
help="list of source files to process, one path per line")
help="list of source files to process, one path per line")
arg_parser.add_argument("-l", "--label", required=True, type=int,
help="obfuscation label for files, either 0 (non-obfuscated) or 1 (obfuscated)")
help="obfuscation label for files, either 0 (non-obfuscated) or 1 (obfuscated)")
arg_parser.add_argument("-n", "--nrecords", default="65536", type=int, metavar="NUM",
help="number of files to store in each tfrecord file (default 65536)")
help="number of files to store in each tfrecord file (default 65536)")
arg_parser.add_argument("-p", "--prefix", default="",
help="prefix to add to generated tfrecord filenames")
help="prefix to add to generated tfrecord filenames")
arg_parser.add_argument("-v", "--verbose", action="store_true", default=False,
help="print out verbose progress info")
help="print out verbose progress info")

parsed_args = arg_parser.parse_args()

Expand All @@ -72,41 +76,41 @@ def main():
return 1

if not os.path.exists(files_list):
print(f"path to file list does not exist: {non_obfuscated_js_files_list}")
print(f"path to file list does not exist: {files_list}")
return 1

if not record_dir.exists():
os.mkdir(record_dir)

record_number = 1
record_num = 1
with open(files_list, "rt") as source_files:
eof = False
while not eof:
record_filename = f"{prefix}{record_number}.tfrec"
record_filename = f"{prefix}{record_num}.tfrec"
with TFRecordWriter(bytes(record_dir / record_filename)) as writer:
sample_number = 0
sample_num = 0
while not eof:

source_filename = source_files.readline().strip()
if not source_filename:
eof = True
break

if print_progress:
print(f"[{record_filename} {sample_number % samples_per_record}/{samples_per_record}] {source_filename}")
print(f"[{record_filename} {sample_num % samples_per_record}/{samples_per_record}]",
source_filename)

example = create_example(source_filename, obfuscation_label)
writer.write(example.SerializeToString())

sample_number += 1
if sample_number % samples_per_record == 0:
sample_num += 1
if sample_num % samples_per_record == 0:
break

record_number += 1
record_num += 1

return 0


if __name__ == "__main__":
    # Run the script and use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)

10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ gast==0.4.0
google-auth==2.22.0
google-auth-oauthlib==1.0.0
google-pasta==0.2.0
grpcio==1.56.2
grpcio==1.57.0
h5py==3.9.0
idna==3.4
keras==2.13.1
libclang==16.0.6
Markdown==3.4.3
Markdown==3.4.4
MarkupSafe==2.1.3
numpy==1.24.3
oauthlib==3.2.2
opt-einsum==3.3.0
packaging==23.1
protobuf==4.23.4
protobuf==4.24.2
pyasn1==0.5.0
pyasn1-modules==0.3.0
requests==2.31.0
Expand All @@ -31,9 +31,9 @@ tensorboard==2.13.0
tensorboard-data-server==0.7.1
tensorflow==2.13.0
tensorflow-estimator==2.13.0
tensorflow-io-gcs-filesystem==0.32.0
tensorflow-io-gcs-filesystem==0.33.0
termcolor==2.3.0
typing_extensions==4.5.0
urllib3==1.26.16
Werkzeug==2.3.6
Werkzeug==2.3.7
wrapt==1.15.0
2 changes: 1 addition & 1 deletion tokenizer_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def train_model(input_filenames: Iterable[str], verbose: bool) -> io.BytesIO:

min_log_level = LogLevel.INFO if verbose else LogLevel.WARNING

def filename_to_sentence(filename: str) -> bytes:
def filename_to_sentence(filename: str) -> str:
return Path(filename).read_text(errors="ignore")

model = io.BytesIO()
Expand Down

0 comments on commit 504dc13

Please sign in to comment.