Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Merge pull request #1531 from jmancewicz/fix-banding-in-segmentation-…
Browse files Browse the repository at this point in the history
…visualization

Clamp distance values from segementation boundaries before begin conv…
  • Loading branch information
jmancewicz authored Mar 24, 2017
2 parents d8ec85a + 0214e85 commit 5dccb29
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions digits/extensions/view/imageSegmentation/view.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
"""Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved."""
from __future__ import absolute_import

import json
Expand All @@ -25,11 +25,16 @@

@subclass
class Visualization(VisualizationInterface):
"""
A visualization extension to display the network output as an image
"""
"""A visualization extension to display the network output as an image."""

def __init__(self, dataset, **kwargs):
"""Constructor for Visualization class.
:param dataset:
:type dataset:
:param kwargs:
:type kwargs:
"""
# memorize view template for later use
extension_dir = os.path.dirname(os.path.abspath(__file__))
self.view_template = open(
Expand Down Expand Up @@ -64,11 +69,16 @@ def __init__(self, dataset, **kwargs):

@staticmethod
def get_config_form():
"""Utility function.
returns: ConfigForm().
"""
return ConfigForm()

@staticmethod
def get_config_template(form):
"""
"""Get the template and context.
parameters:
- form: form returned by get_config_form(). This may be populated
with values if the job was cloned
Expand All @@ -84,8 +94,8 @@ def get_config_template(form):
return (template, {'form': form})

def get_legend_for(self, found_classes, skip_classes=[]):
"""
Return the legend color image squares and text for each class
"""Return the legend color image squares and text for each class.
:param found_classes: list of class indices
:param skip_classes: list of class indices to skip
:return: list of dicts of text hex_color for each class
Expand All @@ -111,9 +121,7 @@ def get_legend_for(self, found_classes, skip_classes=[]):

@override
def get_header_template(self):
"""
Implements get_header_template() method from view extension interface
"""
"""Implement get_header_template method from view extension interface."""
extension_dir = os.path.dirname(os.path.abspath(__file__))
template = open(
os.path.join(extension_dir, HEADER_TEMPLATE), "r").read()
Expand All @@ -122,29 +130,31 @@ def get_header_template(self):

@override
def get_ng_templates(self):
"""
Implements get_ng_templates() method from view extension interface
"""
"""Implement get_ng_templates method from view extension interface."""
extension_dir = os.path.dirname(os.path.abspath(__file__))
header = open(os.path.join(extension_dir, APP_BEGIN_TEMPLATE), "r").read()
footer = open(os.path.join(extension_dir, APP_END_TEMPLATE), "r").read()
return header, footer

@staticmethod
def get_id():
"""returns: id string that identifies the extension."""
return 'image-segmentation'

@staticmethod
def get_title():
"""returns: name string to display in html."""
return 'Image Segmentation'

@staticmethod
def get_dirname():
"""returns: extension dir name to locate static dir."""
return 'imageSegmentation'

@override
def get_view_template(self, data):
"""
"""Get the view template.
returns:
- (template, context) tuple
- template is a Jinja template to use for rendering config options
Expand All @@ -165,9 +175,7 @@ def get_view_template(self, data):

@override
def process_data(self, input_id, input_data, output_data):
"""
Process one inference and return data to visualize
"""
"""Process one inference and return data to visualize."""
# assume the only output is a CHW image where C is the number
# of classes, H and W are the height and width of the image
class_data = output_data[output_data.keys()[0]].astype('float32')
Expand Down Expand Up @@ -226,6 +234,8 @@ def normalize(array):
max_distance = np.maximum(max_distance, distance + 128)

line_data[:, :, 3] = line_mask * 255
max_distance = np.maximum(max_distance, np.zeros(max_distance.shape, dtype=float))
max_distance = np.minimum(max_distance, np.zeros(max_distance.shape, dtype=float) + 255)
seg_data[:, :, 3] = max_distance

# Input image with outlines
Expand Down

0 comments on commit 5dccb29

Please sign in to comment.