Skip to content

Commit

Permalink
add num_classes param
Browse files Browse the repository at this point in the history
  • Loading branch information
mpelchat04 committed Oct 7, 2024
1 parent e8c9e0c commit dddeafb
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/inference/default_binary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ inference:
heatmap_threshold: 0.3
flip: False
rotate: True
num_classes: 2

# GPU parameters
gpu: ${training.num_gpus}
Expand Down
1 change: 1 addition & 0 deletions config/inference/default_multiclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ inference:
heatmap_threshold: 0.3
flip: False
rotate: True
num_classes: 5

# GPU parameters
gpu: ${training.num_gpus}
Expand Down
2 changes: 2 additions & 0 deletions inference_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def main(params:Union[DictConfig, Dict]):
validate_path_exists=True)
input_stac_item = get_key_def('input_stac_item', params['inference'], expected_type=str, to_path=True,
validate_path_exists=True)
num_classes = get_key_def('num_classes', params['inference'], expected_type=int, default=5)
vectorize = get_key_def('ras2vec', params['inference'], expected_type=bool, default=False)
transform_flip = get_key_def('flip', params['inference'], expected_type=bool, default=False)
transform_rotate = get_key_def('rotate', params['inference'], expected_type=bool, default=False)
Expand Down Expand Up @@ -108,6 +109,7 @@ def main(params:Union[DictConfig, Dict]):
mask_to_vec=vectorize,
device=device_str,
gpu_id=gpu_index,
num_classes=num_classes,
prediction_threshold=prediction_threshold,
transformers=transforms,
transformer_flip=transform_flip,
Expand Down

0 comments on commit dddeafb

Please sign in to comment.