DOCS-2661: Updates to some custom training script functions
npentrel committed Aug 13, 2024
1 parent 6731760 commit 0189986
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions docs/services/ml/upload-training-script.md
@@ -85,15 +85,15 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
 # The data_json variable will contain the metadata for the dataset
 # that you should use to train the model.
 def parse_args():
-    """Dataset file and model output directory are required parameters. These
-    must be parsed as command line arguments and then used as the model input
-    and output, respectively.
+    """Returns dataset file, model output directory, and num_epochs if present. These must be parsed as command line
+    arguments and then used as the model input and output, respectively. The number of epochs can be used to optionally override the default.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument("--dataset_file", dest="data_json", type=str)
     parser.add_argument("--model_output_directory", dest="model_dir", type=str)
+    parser.add_argument("--num_epochs", dest="num_epochs", type=int)
     args = parser.parse_args()
-    return args.data_json, args.model_dir
+    return args.data_json, args.model_dir, args.num_epochs

 # This is used for parsing the dataset file (produced and stored in Viam),
 # parse it to get the label annotations
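Review note: the new docstring says the epoch count "can be used to optionally override the default", but this hunk does not show where the default is applied. The sketch below is one way a caller might consume the three-value return. `DEFAULT_EPOCHS`, `resolve_epochs`, the `argv` parameter, and the sample paths are hypothetical names introduced for illustration and are not part of the commit.

```python
import argparse

# Hypothetical default; the commit does not state the template's actual value.
DEFAULT_EPOCHS = 200


def parse_args(argv=None):
    """Same flags as in the diff; argv is added here only so the sketch can run without a real command line."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_file", dest="data_json", type=str)
    parser.add_argument("--model_output_directory", dest="model_dir", type=str)
    parser.add_argument("--num_epochs", dest="num_epochs", type=int)
    args = parser.parse_args(argv)
    return args.data_json, args.model_dir, args.num_epochs


def resolve_epochs(num_epochs, default=DEFAULT_EPOCHS):
    """Return the supplied epoch count if it is a positive integer, otherwise the default."""
    if num_epochs is None or num_epochs <= 0:
        return default
    return num_epochs


# --num_epochs is omitted, so argparse leaves it as None and the default applies.
data_json, model_dir, num_epochs = parse_args(
    ["--dataset_file", "dataset.jsonl", "--model_output_directory", "out"]
)
epochs = resolve_epochs(num_epochs)
```

Because `--num_epochs` has no `default=` in the diff, argparse returns `None` when the flag is absent, so the caller has to handle `None` explicitly before using the value.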
@@ -102,6 +102,7 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
     filename: str, all_labels: ty.List[str], model_type: str
 ) -> ty.Tuple[ty.List[str], ty.List[str]]:
     """Load and parse JSON file to return image filenames and corresponding labels.
+    The JSON file contains lines, where each line has the key "image_path" and "classification_annotations".
     Args:
         filename: JSONLines file containing filenames and labels
         all_labels: list of all N_LABELS
@@ -121,6 +122,8 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
                 if model_type == multi_label:
                     if annotation["annotation_label"] in all_labels:
                         labels.append(annotation["annotation_label"])
+                # For single label model, we want at most one label.
+                # If multiple valid labels are present, we arbitrarily select the last one.
                 if model_type == single_label:
                     if annotation["annotation_label"] in all_labels:
                         labels = [annotation["annotation_label"]]
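Review note: the added comments describe a "last valid label wins" rule for single-label models. The sketch below isolates that rule on one JSON line so the behavior is easy to check. It assumes each annotation is an object with an `annotation_label` key, as the diff suggests. The helper `select_labels`, the string values assigned to `single_label` and `multi_label`, the empty starting list, and the sample record are all hypothetical; the commit shows only the constant names.

```python
import json

# Assumed values; the diff references these names but does not define them.
single_label = "MODEL_TYPE_SINGLE_LABEL_CLASSIFICATION"
multi_label = "MODEL_TYPE_MULTI_LABEL_CLASSIFICATION"


def select_labels(annotations, all_labels, model_type):
    """Hypothetical helper mirroring the loop in the diff for a single record."""
    labels = []
    for annotation in annotations:
        label = annotation["annotation_label"]
        if label not in all_labels:
            continue  # labels outside all_labels are ignored
        if model_type == multi_label:
            labels.append(label)  # keep every valid label
        elif model_type == single_label:
            labels = [label]  # overwrite, so the last valid label wins
    return labels


# One made-up line of the JSONLines dataset file.
line = (
    '{"image_path": "img/0001.jpg", "classification_annotations": '
    '[{"annotation_label": "cat"}, {"annotation_label": "bird"}, '
    '{"annotation_label": "dog"}]}'
)
record = json.loads(line)
annotations = record["classification_annotations"]

single = select_labels(annotations, ["cat", "dog"], single_label)  # ["dog"]
multi = select_labels(annotations, ["cat", "dog"], multi_label)  # ["cat", "dog"]
```

Since the chosen label depends on annotation order in the file, reordering annotations can change the training target for single-label models.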