Skip to content

Commit

Permalink
Change sold2 detector config to dataclass (kornia#2880)
Browse files Browse the repository at this point in the history
* Introduce dataclasses for discussion

* Add dict_to_dataclass function

* Add dataclass_to_dict function

* Update SOLD2_detector to use newly introduced dataclasses

* Remove default_detector_cfg dict

* Remove comment

* Extend LineDetectorCfg to all configs used in LineSegmentDetectionModule

* Update LineSegmentDetectionModule init to use LineDetectorCfg dataclass

* Update LineSegmentDetectionModule docstring

* Update SOLD2_detector docstring

* Update LineSegmentDetectionModule call to use its dataclass

* Remove dict_to_dataclass since it's not used

* Add DeprecationWarning for dict as config in favour of dataclass config

Co-authored-by: João Gustavo A. Amorim <[email protected]>

* Fix DeprecationWarning

* Fix downstream errors of dataclass changes and typos in LineDetectorCfg

* Add typing to dict_to_dataclass and dataclass_to_dict functions

* Fix type checking

* Update dataclass typing to Any

Co-authored-by: João Gustavo A. Amorim <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dict_to_dataclass to TypeVar typing

Co-authored-by: João Gustavo A. Amorim <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move dataclass_to_dict and dict_to_dataclass to kornia/utils/helpers.py

* Fix type checking errors in dict_to_dataclass by bounding TypeVar to dataclass type

* Remove any from dict to dataclass

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: João Gustavo A. Amorim <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored May 7, 2024
1 parent f0ed53e commit e34ed3c
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 113 deletions.
23 changes: 3 additions & 20 deletions kornia/feature/sold2/sold2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from kornia.utils import map_location_to_cpu

from .backbones import SOLD2Net
from .sold2_detector import LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions
from .sold2_detector import LineDetectorCfg, LineSegmentDetectionModule, line_map_to_segments, prob_to_junctions

urls: Dict[str, str] = {}
urls["wireframe"] = "http://cmp.felk.cvut.cz/~mishkdmy/models/sold2_wireframe.pth"
Expand All @@ -22,23 +22,7 @@
"keep_border_valid": True,
"detection_thresh": 0.0153846, # = 1/65: threshold of junction detection
"max_num_junctions": 500, # maximum number of junctions per image
"line_detector_cfg": {
"detect_thresh": 0.5,
"num_samples": 64,
"inlier_thresh": 0.99,
"use_candidate_suppression": True,
"nms_dist_tolerance": 3.0,
"use_heatmap_refinement": True,
"heatmap_refine_cfg": {
"mode": "local",
"ratio": 0.2,
"valid_thresh": 0.001,
"num_blocks": 20,
"overlap_ratio": 0.5,
},
"use_junction_refinement": True,
"junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25},
},
"line_detector_cfg": LineDetectorCfg(),
"line_matcher_cfg": {
"cross_check": True,
"num_samples": 5,
Expand Down Expand Up @@ -92,8 +76,7 @@ def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = N
self.eval()

# Initialize the line detector
self.line_detector_cfg = self.config["line_detector_cfg"]
self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"])
self.line_detector = LineSegmentDetectionModule(LineDetectorCfg())

# Initialize the line matcher
self.line_matcher = WunschLineMatcher(**self.config["line_matcher_cfg"])
Expand Down
220 changes: 128 additions & 92 deletions kornia/feature/sold2/sold2_detector.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,71 @@
import math
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import torch

from kornia.core import Module, Tensor, concatenate, sin, stack, tensor, where, zeros
from kornia.core.check import KORNIA_CHECK_SHAPE
from kornia.geometry.bbox import nms
from kornia.utils import map_location_to_cpu, torch_meshgrid
from kornia.utils import dataclass_to_dict, dict_to_dataclass, map_location_to_cpu, torch_meshgrid

from .backbones import SOLD2Net

urls: Dict[str, str] = {}
urls["wireframe"] = "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download"


default_detector_cfg = {
"backbone_cfg": {"input_channel": 1, "depth": 4, "num_stacks": 2, "num_blocks": 1, "num_classes": 5},
"use_descriptor": False,
"grid_size": 8,
"keep_border_valid": True,
"detection_thresh": 0.0153846, # = 1/65: threshold of junction detection
"max_num_junctions": 500, # maximum number of junctions per image
"line_detector_cfg": {
"detect_thresh": 0.5,
"num_samples": 64,
"inlier_thresh": 0.99,
"use_candidate_suppression": True,
"nms_dist_tolerance": 3.0,
"use_heatmap_refinement": True,
"heatmap_refine_cfg": {
"mode": "local",
"ratio": 0.2,
"valid_thresh": 0.001,
"num_blocks": 20,
"overlap_ratio": 0.5,
},
"use_junction_refinement": True,
"junction_refine_cfg": {"num_perturbs": 9, "perturb_interval": 0.25},
},
}
# Settings for the optional heatmap refinement step of the line detector.
# Read by LineSegmentDetectionModule.detect when `use_heatmap_refinement` is enabled.
@dataclass
class HeatMapRefineCfg:
# Refinement variant. "global" calls refine_heatmap(heatmap, ratio, valid_thresh);
# "local" calls refine_heatmap_local(heatmap, num_blocks, overlap_ratio, ratio, valid_thresh).
# NOTE(review): no other value is handled in the visible code, so it appears to skip refinement - confirm.
mode: str = "local"
# Forwarded as `ratio` to both refinement routines (exact meaning not visible here - TODO confirm).
ratio: float = 0.2
# Forwarded as `valid_thresh` to both refinement routines (exact meaning not visible here - TODO confirm).
valid_thresh: float = 0.001
# Only forwarded in "local" mode.
num_blocks: int = 20
# Only forwarded in "local" mode.
overlap_ratio: float = 0.5


# Settings for the junction (line endpoint) refinement step.
# Read by LineSegmentDetectionModule.refine_junction_perturb, which raises TypeError
# when the stored config is not an instance of this class.
@dataclass
class JunctionRefineCfg:
# Total perturbation count; the refinement derives (num_perturbs - 1) // 2 perturbations per side.
num_perturbs: int = 9
# Presumably the spacing between consecutive perturbations - verify against refine_junction_perturb.
perturb_interval: float = 0.25


# Configuration consumed by LineSegmentDetectionModule; each field is copied onto the
# module in its __init__.
@dataclass
class LineDetectorCfg:
# Probability threshold for mean activation (0. ~ 1.).
detect_thresh: float = 0.5
# Number of sampling locations along the line segments (also sizes the linspace sampler).
num_samples: int = 64
# Minimum inlier ratio to satisfy (0. ~ 1.); 0. means no threshold.
inlier_thresh: float = 0.99
# Apply candidate suppression to break long segments into sub-segments.
use_candidate_suppression: bool = True
# Distance tolerance for NMS; decides whether the junctions are on the line.
nms_dist_tolerance: float = 3.0
# Lowest threshold for a pixel to be considered as a candidate in junction recovery.
heatmap_low_thresh: float = 0.15
# Higher threshold for NMS in junction recovery.
heatmap_high_thresh: float = 0.2
# Maximum patch to be considered in local maximum search.
# NOTE(review): annotated as float but the default is the int literal 3.
max_local_patch_radius: float = 3
# Lambda factor in the linear local maximum search formulation.
lambda_radius: float = 2.0
# Whether to run heatmap refinement; its settings live in `heatmap_refine_cfg`.
use_heatmap_refinement: bool = True
# default_factory gives every LineDetectorCfg its own HeatMapRefineCfg instance.
heatmap_refine_cfg: HeatMapRefineCfg = field(default_factory=HeatMapRefineCfg)
# Whether to run junction refinement; its settings live in `junction_refine_cfg`.
use_junction_refinement: bool = True
# default_factory gives every LineDetectorCfg its own JunctionRefineCfg instance.
junction_refine_cfg: JunctionRefineCfg = field(default_factory=JunctionRefineCfg)


# Backbone section of DetectorCfg. It is not read directly in the visible code: the
# enclosing DetectorCfg is converted with dataclass_to_dict and handed to SOLD2Net.
# NOTE(review): field meanings below are inferred from their names - confirm against SOLD2Net.
@dataclass
class BackboneCfg:
# Presumably the number of input image channels (1 = grayscale) - TODO confirm.
input_channel: int = 1
depth: int = 4
num_stacks: int = 2
num_blocks: int = 1
num_classes: int = 5


# Top-level configuration of SOLD2_detector. The detector reads grid_size,
# detection_thresh, max_num_junctions and line_detector_cfg directly, and passes the
# whole config (converted with dataclass_to_dict) to SOLD2Net.
@dataclass
class DetectorCfg:
# default_factory gives every DetectorCfg its own BackboneCfg instance.
backbone_cfg: BackboneCfg = field(default_factory=BackboneCfg)
# Not read in the visible detector code; presumably consumed by SOLD2Net - verify.
use_descriptor: bool = False
# Stored by the detector as `self.grid_size`.
grid_size: int = 8
# Not read in the visible detector code; presumably consumed downstream - verify.
keep_border_valid: bool = True
detection_thresh: float = 0.0153846 # = 1/65: threshold of junction detection
max_num_junctions: int = 500 # maximum number of junctions per image
# Passed to LineSegmentDetectionModule when the detector builds its line detector.
line_detector_cfg: LineDetectorCfg = field(default_factory=LineDetectorCfg)


class SOLD2_detector(Module):
Expand All @@ -48,9 +75,10 @@ class SOLD2_detector(Module):
Occlusion-aware Line Detector and Descriptor". See :cite:`SOLD22021` for more details.
Args:
config: Dict specifying parameters. None will load the default parameters,
which are tuned for images in the range 400~800 px.
pretrained: If True, download and set pretrained weights to the model.
config (DetectorCfg): Configuration object containing all parameters. None will load the default parameters,
which are tuned for images in the range 400~800 px. Using a dataclass ensures type safety and clearer
parameter management.
pretrained (bool): If True, download and set pretrained weights to the model.
Returns:
The raw junction and line heatmaps, as well as the list of detected line segments (ij coordinates convention).
Expand All @@ -61,25 +89,34 @@ class SOLD2_detector(Module):
>>> line_segments = sold2_detector(img)["line_segments"]
"""

def __init__(self, pretrained: bool = True, config: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, pretrained: bool = True, config: Optional[DetectorCfg] = None) -> None:
# Backward compatibility: callers may still pass the legacy plain-dict config.
# Warn them, then convert it to a DetectorCfg so the remainder of __init__ only
# has to deal with attribute access on the dataclass.
if isinstance(config, dict):
    warnings.warn(
        # Adjacent literals are concatenated verbatim, so every continuation
        # piece carries its own leading space.
        "Usage of config as a plain dictionary is deprecated in favor of"
        " `kornia.feature.sold2.sold2_detector.DetectorCfg`. The support of plain dictionaries"
        " as config will be removed in kornia v0.8.0 (December 2024).",
        category=DeprecationWarning,
        # stacklevel=2 attributes the warning to the code constructing the detector.
        stacklevel=2,
    )
    config = dict_to_dataclass(config, DetectorCfg)
super().__init__()
# Initialize some parameters
self.config = default_detector_cfg if config is None else config
self.grid_size = self.config["grid_size"]
self.junc_detect_thresh = self.config.get("detection_thresh", 1 / 65)
self.max_num_junctions = self.config.get("max_num_junctions", 500)
self.config = config if config is not None else DetectorCfg()
self.grid_size = self.config.grid_size
self.junc_detect_thresh = self.config.detection_thresh
self.max_num_junctions = self.config.max_num_junctions

# Load the pre-trained model
self.model = SOLD2Net(self.config)
self.model = SOLD2Net(dataclass_to_dict(self.config))

if pretrained:
pretrained_dict = torch.hub.load_state_dict_from_url(urls["wireframe"], map_location=map_location_to_cpu)
state_dict = self.adapt_state_dict(pretrained_dict["model_state_dict"])
self.model.load_state_dict(state_dict)
self.eval()

# Initialize the line detector
self.line_detector_cfg = self.config["line_detector_cfg"]
self.line_detector = LineSegmentDetectionModule(**self.config["line_detector_cfg"])
# Initialize the line detector with a configuration from the dataclass
self.line_detector = LineSegmentDetectionModule(self.config.line_detector_cfg)

def adapt_state_dict(self, state_dict: Dict[str, Any]) -> Dict[str, Any]:
del state_dict["w_junc"]
Expand Down Expand Up @@ -127,66 +164,62 @@ class LineSegmentDetectionModule:
r"""Module extracting line segments from junctions and line heatmaps.
Args:
detect_thresh: The probability threshold for mean activation (0. ~ 1.)
num_samples: Number of sampling locations along the line segments.
inlier_thresh: The min inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
heatmap_low_thresh: The lowest threshold for the pixel to be considered as candidate in junction recovery.
heatmap_high_thresh: The higher threshold for NMS in junction recovery.
max_local_patch_radius: The max patch to be considered in local maximum search.
lambda_radius: The lambda factor in linear local maximum search formulation
use_candidate_suppression: Apply candidate suppression to break long segments into short sub-segments.
nms_dist_tolerance: The distance tolerance for nms. Decide whether the junctions are on the line.
use_heatmap_refinement: Use heatmap refinement method or not.
heatmap_refine_cfg: The configs for heatmap refinement methods.
use_junction_refinement: Use junction refinement method or not.
junction_refine_cfg: The configs for junction refinement methods.
config (LineDetectorCfg): Configuration dataclass containing all settings required for line segment detection.
- detect_thresh (float): Probability threshold for mean activation (0. ~ 1.).
- num_samples (int): Number of sampling locations along the line segments.
- inlier_thresh (float): Minimum inlier ratio to satisfy (0. ~ 1.) => 0. means no threshold.
- heatmap_low_thresh (float): Lowest threshold for pixel considered as a candidate in junction recovery.
- heatmap_high_thresh (float): Higher threshold for NMS in junction recovery.
- max_local_patch_radius (float): Maximum patch to be considered in local maximum search.
- lambda_radius (float): Lambda factor in linear local maximum search formulation.
- use_candidate_suppression (bool): Apply candidate suppression to break long segments into sub-segments.
- nms_dist_tolerance (float): Distance tolerance for NMS. Decides whether the junctions are on the line.
- use_heatmap_refinement (bool): Whether to use heatmap refinement methods.
- heatmap_refine_cfg: Configuration for heatmap refinement methods.
- use_junction_refinement (bool): Whether to use junction refinement methods.
- junction_refine_cfg: Configuration for junction refinement methods.
Example:
>>> config = LineDetectorCfg(detect_thresh=0.6, use_heatmap_refinement=True)
>>> module = LineSegmentDetectionModule(config)
>>> junctions, heatmap = torch.rand(10, 2), torch.rand(256, 256)
>>> line_map, junctions, _ = module.detect(junctions, heatmap)
"""

def __init__(
self,
detect_thresh: float,
num_samples: int = 64,
inlier_thresh: float = 0.0,
heatmap_low_thresh: float = 0.15,
heatmap_high_thresh: float = 0.2,
max_local_patch_radius: float = 3,
lambda_radius: float = 2.0,
use_candidate_suppression: bool = False,
nms_dist_tolerance: float = 3.0,
use_heatmap_refinement: bool = False,
heatmap_refine_cfg: Optional[Dict[str, Any]] = None,
use_junction_refinement: bool = False,
junction_refine_cfg: Optional[Dict[str, Any]] = None,
) -> None:
def __init__(self, config: LineDetectorCfg = LineDetectorCfg()) -> None:
# Load LineDetectorCfg
self.config = config

# Line detection parameters
self.detect_thresh = detect_thresh
self.detect_thresh = self.config.detect_thresh
# self.detect_thresh = detect_thresh

# Line sampling parameters
self.num_samples = num_samples
self.inlier_thresh = inlier_thresh
self.local_patch_radius = max_local_patch_radius
self.lambda_radius = lambda_radius
self.num_samples = self.config.num_samples
self.inlier_thresh = self.config.inlier_thresh
self.local_patch_radius = self.config.max_local_patch_radius
self.lambda_radius = self.config.lambda_radius

# Detecting junctions on the boundary parameters
self.low_thresh = heatmap_low_thresh
self.high_thresh = heatmap_high_thresh
self.low_thresh = self.config.heatmap_low_thresh
self.high_thresh = self.config.heatmap_high_thresh

# Pre-compute the linspace sampler
self.torch_sampler = torch.linspace(0, 1, self.num_samples)

# Long line segment suppression configuration
self.use_candidate_suppression = use_candidate_suppression
self.nms_dist_tolerance = nms_dist_tolerance
self.use_candidate_suppression = self.config.use_candidate_suppression
self.nms_dist_tolerance = self.config.nms_dist_tolerance

# Heatmap refinement configuration
self.use_heatmap_refinement = use_heatmap_refinement
self.heatmap_refine_cfg = heatmap_refine_cfg
self.use_heatmap_refinement = self.config.use_heatmap_refinement
self.heatmap_refine_cfg = self.config.heatmap_refine_cfg
if self.use_heatmap_refinement and self.heatmap_refine_cfg is None:
raise ValueError("[Error] Missing heatmap refinement config.")

# Junction refinement configuration
self.use_junction_refinement = use_junction_refinement
self.junction_refine_cfg = junction_refine_cfg
self.use_junction_refinement = self.config.use_junction_refinement
self.junction_refine_cfg = self.config.junction_refine_cfg
if self.use_junction_refinement and self.junction_refine_cfg is None:
raise ValueError("[Error] Missing junction refinement config.")

Expand All @@ -197,18 +230,18 @@ def detect(self, junctions: Tensor, heatmap: Tensor) -> Tuple[Tensor, Tensor, Te
device = junctions.device

# Perform the heatmap refinement
if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, dict):
if self.heatmap_refine_cfg["mode"] == "global":
if self.use_heatmap_refinement and isinstance(self.heatmap_refine_cfg, HeatMapRefineCfg):
if self.heatmap_refine_cfg.mode == "global":
heatmap = self.refine_heatmap(
heatmap, self.heatmap_refine_cfg["ratio"], self.heatmap_refine_cfg["valid_thresh"]
heatmap, self.heatmap_refine_cfg.ratio, self.heatmap_refine_cfg.valid_thresh
)
elif self.heatmap_refine_cfg["mode"] == "local":
elif self.heatmap_refine_cfg.mode == "local":
heatmap = self.refine_heatmap_local(
heatmap,
self.heatmap_refine_cfg["num_blocks"],
self.heatmap_refine_cfg["overlap_ratio"],
self.heatmap_refine_cfg["ratio"],
self.heatmap_refine_cfg["valid_thresh"],
self.heatmap_refine_cfg.num_blocks,
self.heatmap_refine_cfg.overlap_ratio,
self.heatmap_refine_cfg.ratio,
self.heatmap_refine_cfg.valid_thresh,
)

# Initialize empty line map
Expand Down Expand Up @@ -393,10 +426,13 @@ def refine_junction_perturb(
) -> Tuple[Tensor, Tensor]:
"""Refine the line endpoints in a similar way as in LSD."""
# Fetch refinement parameters
if not isinstance(self.junction_refine_cfg, dict):
raise TypeError(f"Expected to have a dict of config for junction. Gotcha {type(self.junction_refine_cfg)}")
num_perturbs = self.junction_refine_cfg["num_perturbs"]
perturb_interval = self.junction_refine_cfg["perturb_interval"]
# Guard: the attribute reads that follow require the dataclass form of the
# junction refinement config, so fail early with a readable message otherwise.
if not isinstance(self.junction_refine_cfg, JunctionRefineCfg):
    raise TypeError(
        # Trailing space keeps the two concatenated literals from running together.
        "Expected to have dataclass of type JunctionRefineCfg for junction. "
        f"Gotcha {type(self.junction_refine_cfg)}"
    )
num_perturbs = self.junction_refine_cfg.num_perturbs
perturb_interval = self.junction_refine_cfg.perturb_interval
side_perturbs = (num_perturbs - 1) // 2

# Fetch the 2D perturb mat
Expand Down
4 changes: 4 additions & 0 deletions kornia/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
from .grid import create_meshgrid, create_meshgrid3d
from .helpers import (
_extract_device_dtype,
dataclass_to_dict,
deprecated,
dict_to_dataclass,
get_cuda_device_if_available,
get_cuda_or_mps_device_if_available,
get_mps_device_if_available,
Expand Down Expand Up @@ -58,4 +60,6 @@
"print_image",
"xla_is_available",
"is_mps_tensor_safe",
"dataclass_to_dict",
"dict_to_dataclass",
]
Loading

0 comments on commit e34ed3c

Please sign in to comment.