From 179fcb54550ec1d9fd6fda3830777c2cb802cd75 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Fri, 30 Sep 2016 16:29:02 -0700 Subject: [PATCH 1/3] Loading/classifying/inferencing Notify users on encoding for higher depth images. Fix _save_means. --- digits/tools/create_db.py | 51 +++++++++++++++++++++++----- digits/utils/image.py | 71 +++++++++++++++++++++++++++++++++------ 2 files changed, 103 insertions(+), 19 deletions(-) diff --git a/digits/tools/create_db.py b/digits/tools/create_db.py index 7b593f232..39eeb8061 100755 --- a/digits/tools/create_db.py +++ b/digits/tools/create_db.py @@ -269,6 +269,9 @@ def create_db(input_file, output_dir, write_queue = Queue.Queue(2*batch_size) summary_queue = Queue.Queue() + # Init helper function for notification between threads + _notification(set_flag=False) + for _ in xrange(num_threads): p = threading.Thread(target=_load_thread, args=(load_queue, write_queue, summary_queue, @@ -331,6 +334,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir, processed_something = False + if _notification(): + break + if not summary_queue.empty(): result_count, result_sum = summary_queue.get() images_loaded += result_count @@ -360,6 +366,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir, _write_batch_lmdb(db, batch, images_written) images_written += len(batch) + if _notification(): + raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') + if images_loaded == 0: raise LoadError('no images loaded from input file') logger.debug('%s images loaded' % images_loaded) @@ -412,6 +421,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, processed_something = False + if _notification(): + break + if not summary_queue.empty(): result_count, result_sum = summary_queue.get() images_loaded += result_count @@ -442,6 +454,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, assert images_written == writer.count() + if _notification(): + raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') + if images_loaded == 0: raise LoadError('no images loaded from input file') logger.debug('%s images loaded' % images_loaded) @@ -498,6 +513,14 @@ def _fill_load_queue(filename, queue, shuffle): return valid_lines +def _notification(set_flag=None): + if set_flag is None: + return _notification.flag + elif set_flag: + _notification.flag = True + else: + _notification.flag = False + def _parse_line(line, distribution): """ Parse a line in the input file into (path, label) @@ -579,12 +602,15 @@ def _load_thread(load_queue, write_queue, summary_queue, if compute_mean: image_sum += image - if backend == 'lmdb': - datum = _array_to_datum(image, label, encoding) - write_queue.put(datum) - else: - write_queue.put((image, label)) - + try: + if backend == 'lmdb': + datum = _array_to_datum(image, label, encoding) + write_queue.put(datum) + else: + write_queue.put((image, label)) + except IOError: # try to save 16-bit images with PNG/JPG encoding + _notification(True) + break images_added += 1 summary_queue.put((images_added, image_sum)) @@ -616,6 +642,8 @@ def _array_to_datum(image, label, encoding): image = image[np.newaxis,:,:] else: raise Exception('Image has unrecognized shape: "%s"' % image.shape) + if np.issubdtype(image.dtype, float): + image = image.astype(float) datum = caffe.io.array_to_datum(image, label) else: datum = caffe_pb2.Datum() @@ -667,7 +695,11 @@ def _save_means(image_sum, image_count, mean_files): """ Save mean[s] to file """ - mean 
= np.around(image_sum / image_count).astype(np.uint8) + mean = np.around(image_sum / image_count) + if mean.max()>255: + mean = mean.astype(np.float) + else: + mean = mean.astype(np.uint8) for mean_file in mean_files: if mean_file.lower().endswith('.npy'): np.save(mean_file, mean) @@ -693,7 +725,10 @@ def _save_means(image_sum, image_count, mean_files): with open(mean_file, 'wb') as outfile: outfile.write(blob.SerializeToString()) elif mean_file.lower().endswith(('.jpg', '.jpeg', '.png')): - image = PIL.Image.fromarray(mean) + if np.issubdtype(mean.dtype, np.float): + image = PIL.Image.fromarray(mean*255/mean.max()).convert('L') + else: + image = PIL.Image.fromarray(mean) image.save(mean_file) else: logger.warning('Unrecognized file extension for mean file: "%s"' % mean_file) diff --git a/digits/utils/image.py b/digits/utils/image.py index 88decefe3..f9052d476 100644 --- a/digits/utils/image.py +++ b/digits/utils/image.py @@ -11,6 +11,20 @@ except ImportError: from StringIO import StringIO +try: + import dicom + dicom_extension = ('.dcm', '.dicom') +except ImportError: + dicom = None + dicom_extension = None + +try: + import nifti + nifti_extension = ('.nii',) +except ImportError: + nifti = None + nifti_extension = None + import numpy as np import PIL.Image import scipy.misc @@ -35,6 +49,27 @@ # List of supported file extensions # Use like "if filename.endswith(SUPPORTED_EXTENSIONS)" SUPPORTED_EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.ppm') +if dicom is not None: + SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + dicom_extension +if nifti is not None: + SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + nifti_extension + +def load_image_ex(path): + """ + Handles images not recognized by load_image + Reads a file from `path` and returns a PIL.Image with mode 'F' (float32) + Raises LoadImageError + + Arguments: + path -- file system path to the image + """ + try: + dm = dicom.read_file(path) + except dicom.errors.InvalidDicomError: + raise errors.LoadImageError, 'Invalid Dicom file' + pixels = dm.pixel_array + image = PIL.Image.fromarray(pixels.astype(np.float)) + return image def load_image(path): """ @@ -62,11 +97,14 @@ def load_image(path): image = PIL.Image.open(path) image.load() except IOError as e: - raise errors.LoadImageError, 'IOError: %s' % e.message + if dicom is not None: + image = load_image_ex(path) + else: + raise errors.LoadImageError, e.message else: raise errors.LoadImageError, '"%s" not found' % path - if image.mode in ['L', 'RGB']: + if image.mode in ['L', 'RGB', 'F']: # No conversion necessary return image elif image.mode in ['1']: @@ -103,7 +141,7 @@ def upscale(image, ratio): width = int(math.floor(image.shape[1] * ratio)) height = int(math.floor(image.shape[0] * ratio)) channels = image.shape[2] - out = np.ndarray((height, width, channels),dtype=np.uint8) + out = np.ndarray((height, width, channels),dtype=image.dtype) for x, y in np.ndindex((width,height)): out[y,x] = image[int(math.floor(y/ratio)), int(math.floor(x/ratio))] return out @@ -128,7 +166,7 @@ def image_to_array(image, # Convert image mode (channels) if channels is None: image_mode = image.mode - if image_mode == 'L': + if image_mode == 'L' or image_mode == 'F': channels = 1 elif image_mode == 'RGB': channels = 3 @@ -136,7 +174,7 @@ def image_to_array(image, raise ValueError('unknown image mode "%s"' % image_mode) elif channels == 1: # 8-bit pixels, black and white - image_mode = 'L' + image_mode = image.mode if image.mode == 'F' else 'L' elif channels == 3: # 3x8-bit pixels, true color image_mode = 'RGB' @@ 
-208,8 +246,9 @@ def resize_image(image, height, width, width_ratio = float(image.shape[1]) / width height_ratio = float(image.shape[0]) / height + image_data_format = 'F' if np.issubdtype(image.dtype, float) else None if resize_mode == 'squash' or width_ratio == height_ratio: - return scipy.misc.imresize(image, (height, width), interp=interp) + return scipy.misc.imresize(image, (height, width), mode=image_data_format, interp=interp) elif resize_mode == 'crop': # resize to smallest of ratios (relatively larger image), keeping aspect ratio if width_ratio > height_ratio: @@ -218,7 +257,7 @@ def resize_image(image, height, width, else: resize_width = width resize_height = int(round(image.shape[0] / width_ratio)) - image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) + image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) # chop off ends of dimension that is still too long if width_ratio > height_ratio: @@ -240,7 +279,7 @@ def resize_image(image, height, width, resize_width = int(round(image.shape[1] / height_ratio)) if (width - resize_width) % 2 == 1: resize_width += 1 - image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) + image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) elif resize_mode == 'half_crop': # resize to average ratio keeping aspect ratio new_ratio = (width_ratio + height_ratio) / 2.0 @@ -250,7 +289,7 @@ def resize_image(image, height, width, resize_height += 1 elif width_ratio < height_ratio and (width - resize_width) % 2 == 1: resize_width += 1 - image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) + image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) # chop off ends of dimension that is still too long if width_ratio > height_ratio: start = int(round((resize_width-width)/2.0)) @@ -267,14 +306,20 @@ def resize_image(image, height, width, noise_size = (padding, width) if channels > 1: noise_size += (channels,) - noise = np.random.randint(0, 255, noise_size).astype('uint8') + if image_data_format == 'F': + noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') + else: + noise = np.random.randint(0, 255, noise_size).astype('uint8') image = np.concatenate((noise, image, noise), axis=0) else: padding = (width - resize_width)/2 noise_size = (height, padding) if channels > 1: noise_size += (channels,) - noise = np.random.randint(0, 255, noise_size).astype('uint8') + if image_data_format == 'F': + noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') + else: + noise = np.random.randint(0, 255, noise_size).astype('uint8') image = np.concatenate((noise, image, noise), axis=1) return image @@ -305,6 +350,10 @@ def embed_image_html(image): fmt = fmt.lower() string_buf = StringIO() + if image.mode == 'F': + min_pv, max_pv = image.getextrema() + tmp_array = 255.0*(np.asarray(image)-min_pv)/(max_pv-min_pv) + image = PIL.Image.fromarray(tmp_array.astype(np.uint8)) image.save(string_buf, format=fmt) data = string_buf.getvalue().encode('base64').replace('\n', '') return 'data:image/%s;base64,%s' % (fmt, data) From 3120bca295df1aca7050c1ff4905799725f806a2 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Wed, 5 Oct 2016 16:45:17 -0700 Subject: [PATCH 2/3] UI to select 8 or 32 bit-depth Make _notification more general Passing 8/32 bit-depth among jobs, create_db --- 
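Note on the generalized _notification() helper below: it keeps a list of error messages as an attribute on the function object itself, so the loader threads can record problems (for example, trying to PNG/JPG-encode data deeper than 8 bits) and the writer loop can raise a single WriteError containing all of them. A minimal standalone sketch of that pattern, with a purely illustrative worker function:

    import threading

    def _notification(reset=False, message=None):
        # reset: clear the stored messages; message: record an error string;
        # no arguments: return the collected messages, or False if there are none
        if reset:
            _notification.messages = []
        elif message is not None:
            _notification.messages.append(message)
        else:
            return _notification.messages if _notification.messages else False

    def _worker(index):
        if index % 2:  # pretend the odd workers hit an unsupported-encoding error
            _notification(message='thread %d: encoding not supported' % index)

    _notification(reset=True)
    threads = [threading.Thread(target=_worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if _notification():
        print('. '.join(_notification()))
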
digits/dataset/images/classification/job.py | 1 + digits/dataset/images/classification/views.py | 9 ++- digits/dataset/images/forms.py | 5 ++ digits/dataset/images/job.py | 1 + digits/dataset/tasks/create_db.py | 2 + .../datasets/images/classification/new.html | 7 ++ .../datasets/images/classification/show.html | 2 + .../images/classification/summary.html | 2 + .../datasets/images/generic/summary.html | 1 + digits/tools/create_db.py | 67 ++++++++++++------- digits/utils/image.py | 20 ++++-- 11 files changed, 83 insertions(+), 34 deletions(-) diff --git a/digits/dataset/images/classification/job.py b/digits/dataset/images/classification/job.py index 0696273c1..1ee700063 100644 --- a/digits/dataset/images/classification/job.py +++ b/digits/dataset/images/classification/job.py @@ -149,6 +149,7 @@ def json_dict(self, verbose=False): "image_width": t.image_dims[0], "image_height": t.image_dims[1], "image_channels": t.image_dims[2], + "image_bpp": t.resize_bpp, "backend": t.backend, "encoding": t.encoding, "compression": t.compression, diff --git a/digits/dataset/images/classification/views.py b/digits/dataset/images/classification/views.py index db118e7f5..947704ea0 100644 --- a/digits/dataset/images/classification/views.py +++ b/digits/dataset/images/classification/views.py @@ -109,6 +109,7 @@ def from_folders(job, form): backend = backend, image_dims = job.image_dims, resize_mode = job.resize_mode, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, mean_file = utils.constants.MEAN_FILE_CAFFE, @@ -126,6 +127,7 @@ def from_folders(job, form): backend = backend, image_dims = job.image_dims, resize_mode = job.resize_mode, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, labels_file = job.labels_file, @@ -142,6 +144,7 @@ def from_folders(job, form): backend = backend, image_dims = job.image_dims, resize_mode = job.resize_mode, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, labels_file = job.labels_file, @@ -188,6 +191,7 @@ def from_files(job, form): image_dims = job.image_dims, image_folder= image_folder, resize_mode = job.resize_mode, + resize_bpp = job.resize.bpp, encoding = encoding, compression = compression, mean_file = utils.constants.MEAN_FILE_CAFFE, @@ -220,6 +224,7 @@ def from_files(job, form): image_dims = job.image_dims, image_folder= image_folder, resize_mode = job.resize_mode, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, labels_file = job.labels_file, @@ -251,6 +256,7 @@ def from_files(job, form): image_dims = job.image_dims, image_folder= image_folder, resize_mode = job.resize_mode, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, labels_file = job.labels_file, @@ -303,7 +309,8 @@ def create(): int(form.resize_width.data), int(form.resize_channels.data), ), - resize_mode = form.resize_mode.data + resize_mode = form.resize_mode.data, + resize_bpp = int(form.resize_bpp.data) ) if form.method.data == 'folder': diff --git a/digits/dataset/images/forms.py b/digits/dataset/images/forms.py index 171f5d078..09dbedbcf 100644 --- a/digits/dataset/images/forms.py +++ b/digits/dataset/images/forms.py @@ -44,3 +44,8 @@ class ImageDatasetForm(DatasetForm): choices=ImageDatasetJob.resize_mode_choices(), tooltip = "Options for dealing with aspect ratio changes during resize. See examples below." 
) + resize_bpp = utils.forms.SelectField(u'Bits per pixel', + default='8', + choices=[('8', '8-bit (color or grayscale)'), ('32', '32-bit floating point (grayscale only)')], + tooltip="Storing 32-bit floating point for certain medical images." + ) diff --git a/digits/dataset/images/job.py b/digits/dataset/images/job.py index e3764cd13..bf9fdb6a5 100644 --- a/digits/dataset/images/job.py +++ b/digits/dataset/images/job.py @@ -19,6 +19,7 @@ def __init__(self, **kwargs): """ self.image_dims = kwargs.pop('image_dims', None) self.resize_mode = kwargs.pop('resize_mode', None) + self.resize_bpp = kwargs.pop('resize_bpp', None) super(ImageDatasetJob, self).__init__(**kwargs) self.pickver_job_dataset_image = PICKLE_VERSION diff --git a/digits/dataset/tasks/create_db.py b/digits/dataset/tasks/create_db.py index c716ef371..cfb8de65d 100644 --- a/digits/dataset/tasks/create_db.py +++ b/digits/dataset/tasks/create_db.py @@ -40,6 +40,7 @@ def __init__(self, input_file, db_name, backend, image_dims, **kwargs): self.shuffle = kwargs.pop('shuffle', True) self.resize_mode = kwargs.pop('resize_mode' , None) self.encoding = kwargs.pop('encoding', None) + self.resize_bpp = kwargs.pop('resize_bpp', None) self.compression = kwargs.pop('compression', None) self.mean_file = kwargs.pop('mean_file', None) self.labels_file = kwargs.pop('labels_file', None) @@ -147,6 +148,7 @@ def task_arguments(self, resources, env): '--backend=%s' % self.backend, '--channels=%s' % self.image_dims[2], '--resize_mode=%s' % self.resize_mode, + '--resize_bpp=%s' % self.resize_bpp ] if self.mean_file is not None: diff --git a/digits/templates/datasets/images/classification/new.html b/digits/templates/datasets/images/classification/new.html index cce9ed204..a5f35a2df 100644 --- a/digits/templates/datasets/images/classification/new.html +++ b/digits/templates/datasets/images/classification/new.html @@ -56,6 +56,13 @@

New Image Classification Dataset

{{ form.resize_mode(class='form-control') }} +
+
+ {{ form.resize_bpp.label }} + {{ form.resize_bpp.tooltip }} + {{ form.resize_bpp(class='form-control') }} +
+
See example
diff --git a/digits/templates/datasets/images/classification/show.html b/digits/templates/datasets/images/classification/show.html index cd099aa59..fe40b3851 100644 --- a/digits/templates/datasets/images/classification/show.html +++ b/digits/templates/datasets/images/classification/show.html @@ -19,6 +19,8 @@

Job Information

{{'Color' if job.image_dims[2] == 3 else 'Grayscale'}}
Resize Transformation
{{ job.resize_mode_name() }}
+
Image bit-depth
+
{{ job.resize_bpp }}
DB Backend
{{job.get_backend()}}
Image Encoding
diff --git a/digits/templates/datasets/images/classification/summary.html b/digits/templates/datasets/images/classification/summary.html index 9f6cab95a..998291104 100644 --- a/digits/templates/datasets/images/classification/summary.html +++ b/digits/templates/datasets/images/classification/summary.html @@ -24,6 +24,8 @@

GRAYSCALE {% endif %} +
Image bit-depth
+
{{dataset.resize_bpp}}
DB backend
{{dataset.get_backend()}}
{% for task in dataset.create_db_tasks() %} diff --git a/digits/templates/datasets/images/generic/summary.html b/digits/templates/datasets/images/generic/summary.html index e3754146c..f149a49a8 100644 --- a/digits/templates/datasets/images/generic/summary.html +++ b/digits/templates/datasets/images/generic/summary.html @@ -21,6 +21,7 @@

  • Image Count - {{task.image_count}}
  • Image Dimensions - {{task.image_width}}x{{task.image_height}}x{{task.image_channels}}
  • +
  • Image bit-depth
  • {% endfor %} diff --git a/digits/tools/create_db.py b/digits/tools/create_db.py index 39eeb8061..b81c80902 100755 --- a/digits/tools/create_db.py +++ b/digits/tools/create_db.py @@ -202,6 +202,7 @@ def create_db(input_file, output_dir, image_width, image_height, image_channels, backend, resize_mode = None, + resize_bpp = None, image_folder = None, shuffle = True, mean_files = None, @@ -241,6 +242,8 @@ def create_db(input_file, output_dir, raise ValueError('invalid number of channels') if resize_mode not in [None, 'crop', 'squash', 'fill', 'half_crop']: raise ValueError('invalid resize_mode') + if resize_bpp not in [None, '8', '32']: + raise ValueError('invalid resize_bpp') if image_folder is not None and not os.path.exists(image_folder): raise ValueError('image_folder does not exist') if mean_files: @@ -270,13 +273,13 @@ def create_db(input_file, output_dir, summary_queue = Queue.Queue() # Init helper function for notification between threads - _notification(set_flag=False) + _notification(reset=True) for _ in xrange(num_threads): p = threading.Thread(target=_load_thread, args=(load_queue, write_queue, summary_queue, image_width, image_height, image_channels, - resize_mode, image_folder, compute_mean), + resize_mode, image_folder, compute_mean, resize_bpp), kwargs={'backend': backend, 'encoding': kwargs.get('encoding', None)}, ) @@ -367,7 +370,7 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir, images_written += len(batch) if _notification(): - raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') + raise WriteError('. '.join(_notification())) if images_loaded == 0: raise LoadError('no images loaded from input file') @@ -421,9 +424,6 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, processed_something = False - if _notification(): - break - if not summary_queue.empty(): result_count, result_sum = summary_queue.get() images_loaded += result_count @@ -454,9 +454,6 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, assert images_written == writer.count() - if _notification(): - raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') - if images_loaded == 0: raise LoadError('no images loaded from input file') logger.debug('%s images loaded' % images_loaded) @@ -513,13 +510,27 @@ def _fill_load_queue(filename, queue, shuffle): return valid_lines -def _notification(set_flag=None): - if set_flag is None: - return _notification.flag - elif set_flag: - _notification.flag = True +def _notification(reset=False, message=None): + """ + + Args: + reset: clear the message list if True + message: the error message + + Returns: + False: if no message stored and not reset + The messages (a list): if some messages stored + """ + if not reset: + if message is None: + if len(_notification.messages) == 0: + return False + else: + return _notification.messages + else: + _notification.messages.append(message) else: - _notification.flag = False + _notification.messages = list() def _parse_line(line, distribution): """ @@ -564,7 +575,7 @@ def _calculate_num_threads(batch_size, shuffle): def _load_thread(load_queue, write_queue, summary_queue, image_width, image_height, image_channels, - resize_mode, image_folder, compute_mean, + resize_mode, image_folder, compute_mean, resize_bpp, backend=None, encoding=None): """ Consumes items in load_queue @@ -597,6 +608,7 @@ def _load_thread(load_queue, write_queue, summary_queue, image_height, image_width, channels = image_channels, 
resize_mode = resize_mode, + resize_bpp = resize_bpp ) if compute_mean: @@ -608,8 +620,8 @@ def _load_thread(load_queue, write_queue, summary_queue, write_queue.put(datum) else: write_queue.put((image, label)) - except IOError: # try to save 16-bit images with PNG/JPG encoding - _notification(True) + except IOError as e: # report error to user (possibly save 16-bit image to PNG/JPG) + _notification(message=e.message) break images_added += 1 @@ -695,11 +707,7 @@ def _save_means(image_sum, image_count, mean_files): """ Save mean[s] to file """ - mean = np.around(image_sum / image_count) - if mean.max()>255: - mean = mean.astype(np.float) - else: - mean = mean.astype(np.uint8) + mean = np.around(image_sum / image_count).astype(np.float) for mean_file in mean_files: if mean_file.lower().endswith('.npy'): np.save(mean_file, mean) @@ -725,10 +733,13 @@ def _save_means(image_sum, image_count, mean_files): with open(mean_file, 'wb') as outfile: outfile.write(blob.SerializeToString()) elif mean_file.lower().endswith(('.jpg', '.jpeg', '.png')): - if np.issubdtype(mean.dtype, np.float): + #ensure pixel range is within supported format + if mean.max() < 256: # works for three formats + image = PIL.Image.fromarray(mean.astype(np.uint8)) + elif mean_file.lower().endswith('.png'): # png supports higher color depth + image = PIL.Image.fromarray(mean).convert('I') + else: # reduce color depth for jpg or jpeg image = PIL.Image.fromarray(mean*255/mean.max()).convert('L') - else: - image = PIL.Image.fromarray(mean) image.save(mean_file) else: logger.warning('Unrecognized file extension for mean file: "%s"' % mean_file) @@ -765,6 +776,9 @@ def _save_means(image_sum, image_count, mean_files): parser.add_argument('-r', '--resize_mode', help='resize mode for images (must be "crop", "squash" [default], "fill" or "half_crop")' ) + parser.add_argument('--resize_bpp', + help='bit per pixel for resized images (must be 8 (color/grayscale) or 32 (grayscale only)")' + ) parser.add_argument('-m', '--mean_file', action='append', help="location to output the image mean (doesn't save mean if not specified)") parser.add_argument('-f', '--image_folder', @@ -801,6 +815,7 @@ def _save_means(image_sum, image_count, mean_files): args['width'], args['height'], args['channels'], args['backend'], resize_mode = args['resize_mode'], + resize_bpp = args['resize_bpp'], image_folder = args['image_folder'], shuffle = args['shuffle'], mean_files = args['mean_file'], diff --git a/digits/utils/image.py b/digits/utils/image.py index f9052d476..a88051bb9 100644 --- a/digits/utils/image.py +++ b/digits/utils/image.py @@ -215,6 +215,7 @@ def image_to_array(image, def resize_image(image, height, width, channels=None, resize_mode=None, + resize_bpp='8' ): """ Resizes an image and returns it as a np.array @@ -227,19 +228,24 @@ def resize_image(image, height, width, Keyword Arguments: channels -- channels of new image (stays unchanged if not specified) resize_mode -- can be crop, squash, fill or half_crop + resize_bpp -- bits per pixel (per channel), either 8 or 32. 
""" if resize_mode is None: resize_mode = 'squash' if resize_mode not in ['crop', 'squash', 'fill', 'half_crop']: raise ValueError('resize_mode "%s" not supported' % resize_mode) + if resize_bpp not in ['8', '32']: + raise ValueError('resize_bpp "%s" not supported' % resize_bpp) + + target_dtype = np.uint8 if resize_bpp == '8' else np.float # convert to array image = image_to_array(image, channels) # No need to resize if image.shape[0] == height and image.shape[1] == width: - return image + return image.astype(target_dtype) ### Resize interp = 'bilinear' @@ -306,20 +312,20 @@ def resize_image(image, height, width, noise_size = (padding, width) if channels > 1: noise_size += (channels,) - if image_data_format == 'F': - noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') + if target_dtype == np.float: + noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype(target_dtype) else: - noise = np.random.randint(0, 255, noise_size).astype('uint8') + noise = np.random.randint(0, 255, noise_size).astype(target_dtype) image = np.concatenate((noise, image, noise), axis=0) else: padding = (width - resize_width)/2 noise_size = (height, padding) if channels > 1: noise_size += (channels,) - if image_data_format == 'F': - noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') + if target_dtype == np.float: + noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype(target_dtype) else: - noise = np.random.randint(0, 255, noise_size).astype('uint8') + noise = np.random.randint(0, 255, noise_size).astype(target_dtype) image = np.concatenate((noise, image, noise), axis=1) return image From b20261bc22adc7e06e8b9bdac88f65f14cf7fd94 Mon Sep 17 00:00:00 2001 From: Isaac Yang Date: Thu, 6 Oct 2016 13:18:46 -0700 Subject: [PATCH 3/3] Add PyDicom in requirements.txt Remove import error handling on dicom package Use model's train data to limit inference data Conditional display bit-depth of dataset Save to 16-bit png for Torch inferencing --- digits/dataset/images/classification/views.py | 2 +- digits/model/tasks/torch_train.py | 4 ++ .../datasets/images/classification/show.html | 2 + .../images/classification/summary.html | 2 + .../datasets/images/generic/summary.html | 1 - digits/tools/create_db.py | 15 ++++--- digits/tools/inference.py | 9 ++++- digits/utils/image.py | 40 +++++-------------- requirements.txt | 1 + 9 files changed, 38 insertions(+), 38 deletions(-) diff --git a/digits/dataset/images/classification/views.py b/digits/dataset/images/classification/views.py index 947704ea0..566f822be 100644 --- a/digits/dataset/images/classification/views.py +++ b/digits/dataset/images/classification/views.py @@ -191,7 +191,7 @@ def from_files(job, form): image_dims = job.image_dims, image_folder= image_folder, resize_mode = job.resize_mode, - resize_bpp = job.resize.bpp, + resize_bpp = job.resize_bpp, encoding = encoding, compression = compression, mean_file = utils.constants.MEAN_FILE_CAFFE, diff --git a/digits/model/tasks/torch_train.py b/digits/model/tasks/torch_train.py index e2ce559e5..bf4ee7d0c 100644 --- a/digits/model/tasks/torch_train.py +++ b/digits/model/tasks/torch_train.py @@ -515,6 +515,8 @@ def infer_one_image(self, image, snapshot_epoch=None, layers=None, gpu=None): temp_image_handle, temp_image_path = tempfile.mkstemp(suffix='.png') os.close(temp_image_handle) image = PIL.Image.fromarray(image) + if image.mode == 'F': + image = image.convert('L') try: image.save(temp_image_path, 
format='png') except KeyError: @@ -810,6 +812,8 @@ def infer_many_images(self, images, snapshot_epoch=None, gpu=None): temp_image_handle, temp_image_path = tempfile.mkstemp( dir=temp_dir_path, suffix='.png') image = PIL.Image.fromarray(image) + if image.mode == 'F': + image = image.convert('L') try: image.save(temp_image_path, format='png') except KeyError: diff --git a/digits/templates/datasets/images/classification/show.html b/digits/templates/datasets/images/classification/show.html index fe40b3851..6bdd6ac70 100644 --- a/digits/templates/datasets/images/classification/show.html +++ b/digits/templates/datasets/images/classification/show.html @@ -19,8 +19,10 @@

    Job Information

    {{'Color' if job.image_dims[2] == 3 else 'Grayscale'}}
    Resize Transformation
    {{ job.resize_mode_name() }}
    + {% if job.resize_bpp is defined %}
    Image bit-depth
    {{ job.resize_bpp }}
    + {% endif %}
    DB Backend
    {{job.get_backend()}}
    Image Encoding
    diff --git a/digits/templates/datasets/images/classification/summary.html b/digits/templates/datasets/images/classification/summary.html index 998291104..8767849f8 100644 --- a/digits/templates/datasets/images/classification/summary.html +++ b/digits/templates/datasets/images/classification/summary.html @@ -24,8 +24,10 @@

    GRAYSCALE {% endif %} + {% if dataset.resize_bpp is defined %}
    Image bit-depth
    {{dataset.resize_bpp}}
    + {% endif %}
    DB backend
    {{dataset.get_backend()}}
    {% for task in dataset.create_db_tasks() %} diff --git a/digits/templates/datasets/images/generic/summary.html b/digits/templates/datasets/images/generic/summary.html index f149a49a8..e3754146c 100644 --- a/digits/templates/datasets/images/generic/summary.html +++ b/digits/templates/datasets/images/generic/summary.html @@ -21,7 +21,6 @@

  • Image Count - {{task.image_count}}
  • Image Dimensions - {{task.image_width}}x{{task.image_height}}x{{task.image_channels}}
  • -
  • Image bit-depth
  • {% endfor %} diff --git a/digits/tools/create_db.py b/digits/tools/create_db.py index b81c80902..d4e4af82b 100755 --- a/digits/tools/create_db.py +++ b/digits/tools/create_db.py @@ -202,7 +202,7 @@ def create_db(input_file, output_dir, image_width, image_height, image_channels, backend, resize_mode = None, - resize_bpp = None, + resize_bpp = '8', image_folder = None, shuffle = True, mean_files = None, @@ -221,6 +221,7 @@ def create_db(input_file, output_dir, Keyword arguments: resize_mode -- passed to utils.image.resize_image() + resize_bpp -- bit-depth of image on storage shuffle -- if True, shuffle the images in the list before creating mean_files -- a list of mean files to save """ @@ -242,7 +243,7 @@ def create_db(input_file, output_dir, raise ValueError('invalid number of channels') if resize_mode not in [None, 'crop', 'squash', 'fill', 'half_crop']: raise ValueError('invalid resize_mode') - if resize_bpp not in [None, '8', '32']: + if resize_bpp not in ['8', '32']: raise ValueError('invalid resize_bpp') if image_folder is not None and not os.path.exists(image_folder): raise ValueError('image_folder does not exist') @@ -616,7 +617,7 @@ def _load_thread(load_queue, write_queue, summary_queue, try: if backend == 'lmdb': - datum = _array_to_datum(image, label, encoding) + datum = _array_to_datum(image, label, encoding, bpp=resize_bpp) write_queue.put(datum) else: write_queue.put((image, label)) @@ -636,7 +637,7 @@ def _initial_image_sum(width, height, channels): else: return np.zeros((height, width, channels), np.float64) -def _array_to_datum(image, label, encoding): +def _array_to_datum(image, label, encoding, bpp='8'): """ Create a caffe Datum from a numpy.ndarray """ @@ -654,10 +655,10 @@ def _array_to_datum(image, label, encoding): image = image[np.newaxis,:,:] else: raise Exception('Image has unrecognized shape: "%s"' % image.shape) - if np.issubdtype(image.dtype, float): + if bpp == '32': image = image.astype(float) datum = caffe.io.array_to_datum(image, label) - else: + elif bpp == '8': datum = caffe_pb2.Datum() if image.ndim == 3: datum.channels = image.shape[2] @@ -676,6 +677,8 @@ def _array_to_datum(image, label, encoding): raise ValueError('Invalid encoding type') datum.data = s.getvalue() datum.encoded = True + else: + raise ValueError('32 bit-depth can not encoded to PNG/JPG') return datum def _write_batch_lmdb(db, batch, image_count): diff --git a/digits/tools/inference.py b/digits/tools/inference.py index aa1a71697..f5faed029 100755 --- a/digits/tools/inference.py +++ b/digits/tools/inference.py @@ -83,6 +83,7 @@ def infer(input_list, width = image_dims[1] channels = image_dims[2] resize_mode = dataset.resize_mode if hasattr(dataset, 'resize_mode') else 'squash' + resize_bpp = str(dataset.resize_bpp) if hasattr(dataset, 'resize_bpp') else '8' # model's train data bit-depth n_input_samples = 0 # number of samples we were able to load input_ids = [] # indices of samples within file list @@ -126,13 +127,19 @@ def infer(input_list, path = path.strip() try: image = utils.image.load_image(path.strip()) + # model trained with 8-bit image, but we gives high + # bit-depth image. 
+ # The reverse case (a model trained on higher bit-depth data given an 8-bit image) is not optimal, but may still work + if image.mode == 'F' and resize_bpp == '8': + raise InferenceError('Model was trained with 8-bit images and cannot handle higher bit-depth input') if resize: image = utils.image.resize_image( image, height, width, channels=channels, - resize_mode=resize_mode) + resize_mode=resize_mode, + resize_bpp=resize_bpp) else: image = utils.image.image_to_array( image, diff --git a/digits/utils/image.py index a88051bb9..8b3bdfe0e 100644 --- a/digits/utils/image.py +++ b/digits/utils/image.py @@ -11,19 +11,7 @@ except ImportError: from StringIO import StringIO -try: - import dicom - dicom_extension = ('.dcm', '.dicom') -except ImportError: - dicom = None - dicom_extension = None - -try: - import nifti - nifti_extension = ('.nii',) -except ImportError: - nifti = None - nifti_extension = None +import dicom import numpy as np import PIL.Image @@ -48,11 +36,7 @@ # List of supported file extensions # Use like "if filename.endswith(SUPPORTED_EXTENSIONS)" -SUPPORTED_EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.ppm') -if dicom is not None: - SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + dicom_extension -if nifti is not None: - SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + nifti_extension +SUPPORTED_EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.ppm', '.dcm', '.dicom') def load_image_ex(path): """ @@ -65,9 +49,11 @@ def load_image_ex(path): """ try: dm = dicom.read_file(path) - except dicom.errors.InvalidDicomError: - raise errors.LoadImageError, 'Invalid Dicom file' + except Exception: + raise errors.LoadImageError, 'Unable to load DICOM file' pixels = dm.pixel_array + if len(pixels.shape) != 2: + raise errors.LoadImageError, 'Only 2-D DICOM images are currently supported' image = PIL.Image.fromarray(pixels.astype(np.float)) return image @@ -96,11 +82,8 @@ def load_image(path): try: image = PIL.Image.open(path) image.load() - except IOError as e: - if dicom is not None: - image = load_image_ex(path) - else: - raise errors.LoadImageError, e.message + except IOError: image = load_image_ex(path) else: raise errors.LoadImageError, '"%s" not found' % path @@ -215,6 +198,7 @@ def image_to_array(image, def resize_image(image, height, width, channels=None, resize_mode=None, - resize_bpp='8' + resize_bpp=None ): """ Resizes an image and returns it as a np.array @@ -235,9 +218,8 @@ def resize_image(image, height, width, resize_mode = 'squash' if resize_mode not in ['crop', 'squash', 'fill', 'half_crop']: raise ValueError('resize_mode "%s" not supported' % resize_mode) - if resize_bpp not in ['8', '32']: - raise ValueError('resize_bpp "%s" not supported' % resize_bpp) - + if resize_bpp is None: resize_bpp = '8' target_dtype = np.uint8 if resize_bpp == '8' else np.float # convert to array diff --git a/requirements.txt index 013c51241..dd64b1314 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,3 +16,4 @@ h5py>=2.2.1,<=2.6.0 pydot>=1.0.28,<=1.0.29 psutil>=1.2.1,<=3.4.2 matplotlib>=1.3.1,<=1.5.1 +pydicom>=0.9.7
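A quick way to exercise the new high bit-depth path by hand, as a sketch only: it assumes this patch series is applied, pydicom is installed, and 'scan.dcm' is a placeholder path to a 2-D DICOM file on disk.

    from digits.utils import image as image_utils

    # PIL cannot open the DICOM file, so load_image() falls back to load_image_ex(),
    # which returns a PIL image in mode 'F' (32-bit float, grayscale)
    img = image_utils.load_image('scan.dcm')
    print(img.mode)

    # resize_bpp='32' keeps the pixels floating point instead of casting to uint8
    arr = image_utils.resize_image(img, 256, 256,
                                   channels=1,
                                   resize_mode='squash',
                                   resize_bpp='32')
    print(arr.dtype)
    print(arr.shape)

With the default resize_bpp='8', an ordinary 8-bit input should come back as a uint8 array, matching the previous behaviour.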