diff --git a/coco_data_split.py b/coco_data_split.py index 845fa83..5926483 100644 --- a/coco_data_split.py +++ b/coco_data_split.py @@ -37,7 +37,7 @@ def plot_label_frequencies(coco, data_path, title, ax, labels_common=None): - if labels_common is not None: + if len(labels_common) > 0: label_freqs = {l: 0 for l in labels_common} else: label_freqs = {l: 0 for l in CROP_ENCODING.values()} @@ -162,6 +162,22 @@ def create_dataframe(data_path, tiles, years, common_labels=None): common_lbls = common_labels(train_tiles | test_tiles) + # If common_lbls is an empty set, it is expected to equal None later + if common_lbls is not None and len(common_lbls) == 0: + common_lbls = None + + # If args.tiles is 'all', then train_tiles and test_tiles is also expected + # to be 'all' + if args.tiles == 'all': + train_tiles = 'all' + test_tiles = 'all' + + # If args.years is 'all', then train_years and test_years is also expected + # to be 'all' + if args.years == 'all': + train_years = 'all' + test_years = 'all' + # Define prefix if args.prefix is None: # No prefix given, use current timestamp