diff --git a/config/inference/default_binary.yaml b/config/inference/default_binary.yaml index 624970ff..f3d07efa 100644 --- a/config/inference/default_binary.yaml +++ b/config/inference/default_binary.yaml @@ -11,6 +11,7 @@ inference: heatmap_threshold: 0.3 flip: False rotate: True + num_classes: 2 # GPU parameters gpu: ${training.num_gpus} diff --git a/config/inference/default_multiclass.yaml b/config/inference/default_multiclass.yaml index 69946c04..060bf61b 100644 --- a/config/inference/default_multiclass.yaml +++ b/config/inference/default_multiclass.yaml @@ -11,6 +11,7 @@ inference: heatmap_threshold: 0.3 flip: False rotate: True + num_classes: 5 # GPU parameters gpu: ${training.num_gpus} diff --git a/inference_segmentation.py b/inference_segmentation.py index cec07a09..00b05ff7 100644 --- a/inference_segmentation.py +++ b/inference_segmentation.py @@ -64,6 +64,7 @@ def main(params:Union[DictConfig, Dict]): validate_path_exists=True) input_stac_item = get_key_def('input_stac_item', params['inference'], expected_type=str, to_path=True, validate_path_exists=True) + num_classes = get_key_def('num_classes', params['inference'], expected_type=int, default=5) vectorize = get_key_def('ras2vec', params['inference'], expected_type=bool, default=False) transform_flip = get_key_def('flip', params['inference'], expected_type=bool, default=False) transform_rotate = get_key_def('rotate', params['inference'], expected_type=bool, default=False) @@ -108,6 +109,7 @@ def main(params:Union[DictConfig, Dict]): mask_to_vec=vectorize, device=device_str, gpu_id=gpu_index, + num_classes=num_classes, prediction_threshold=prediction_threshold, transformers=transforms, transformer_flip=transform_flip,