@@ -85,15 +85,16 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
85
85
# The data_json variable will contain the metadata for the dataset
86
86
# that you should use to train the model.
87
87
def parse_args ():
88
- """ Dataset file and model output directory are required parameters . These
89
- must be parsed as command line arguments and then used as the model input
90
- and output, respectively .
88
+ """ Returns dataset file, model output directory, and num_epochs if present . These
89
+ must be parsed as command line arguments and then used as the model input and output, respectively.
90
+ The number of epochs can be used to optionally override the default .
91
91
"""
92
92
parser = argparse.ArgumentParser()
93
93
parser.add_argument(" --dataset_file" , dest = " data_json" , type = str )
94
94
parser.add_argument(" --model_output_directory" , dest = " model_dir" , type = str )
95
+ parser.add_argument(" --num_epochs" , dest = " num_epochs" , type = int )
95
96
args = parser.parse_args()
96
- return args.data_json, args.model_dir
97
+ return args.data_json, args.model_dir, args.num_epochs
97
98
98
99
# This is used for parsing the dataset file (produced and stored in Viam),
99
100
# parse it to get the label annotations
@@ -102,6 +103,7 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
102
103
filename : str , all_labels : ty.List[str ], model_type : str
103
104
) -> ty.Tuple[ty.List[str ], ty.List[str ]]:
104
105
""" Load and parse JSON file to return image filenames and corresponding labels.
106
+ The JSON file contains lines, where each line has the key "image_path" and "classification_annotations".
105
107
Args:
106
108
filename: JSONLines file containing filenames and labels
107
109
all_labels: list of all N_LABELS
@@ -121,6 +123,8 @@ Follow this guide to create, upload, and submit a Python script that loads a tra
121
123
if model_type == multi_label:
122
124
if annotation[" annotation_label" ] in all_labels:
123
125
labels.append(annotation[" annotation_label" ])
126
+ # For single label model, we want at most one label.
127
+ # If multiple valid labels are present, we arbitrarily select the last one.
124
128
if model_type == single_label:
125
129
if annotation[" annotation_label" ] in all_labels:
126
130
labels = [annotation[" annotation_label" ]]
0 commit comments