Skip to content

Commit 42077fc

Browse files
authored
DOCS-2661: Updates to some custom training script functions (#3240)
1 parent 6731760 commit 42077fc

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

docs/services/ml/upload-training-script.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,16 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
8585
# The data_json variable will contain the metadata for the dataset
8686
# that you should use to train the model.
8787
def parse_args():
88-
"""Dataset file and model output directory are required parameters. These
89-
must be parsed as command line arguments and then used as the model input
90-
and output, respectively.
88+
"""Returns dataset file, model output directory, and num_epochs if present. These
89+
must be parsed as command line arguments; the first two are then used as the model input and output, respectively.
90+
The number of epochs can be used to optionally override the default.
9191
"""
9292
parser = argparse.ArgumentParser()
9393
parser.add_argument("--dataset_file", dest="data_json", type=str)
9494
parser.add_argument("--model_output_directory", dest="model_dir", type=str)
95+
parser.add_argument("--num_epochs", dest="num_epochs", type=int)
9596
args = parser.parse_args()
96-
return args.data_json, args.model_dir
97+
return args.data_json, args.model_dir, args.num_epochs
9798

9899
# This is used for parsing the dataset file (produced and stored in Viam),
99100
# in order to get the label annotations
@@ -102,6 +103,7 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
102103
filename: str, all_labels: ty.List[str], model_type: str
103104
) -> ty.Tuple[ty.List[str], ty.List[str]]:
104105
"""Load and parse JSON file to return image filenames and corresponding labels.
106+
The JSON file contains lines, where each line has the keys "image_path" and "classification_annotations".
105107
Args:
106108
filename: JSONLines file containing filenames and labels
107109
all_labels: list of all N_LABELS
@@ -121,6 +123,8 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
121123
if model_type == multi_label:
122124
if annotation["annotation_label"] in all_labels:
123125
labels.append(annotation["annotation_label"])
126+
# For single label model, we want at most one label.
127+
# If multiple valid labels are present, we arbitrarily select the last one.
124128
if model_type == single_label:
125129
if annotation["annotation_label"] in all_labels:
126130
labels = [annotation["annotation_label"]]

0 commit comments

Comments
 (0)