This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Supporting DICOM via pydicom #1136

Open · wants to merge 3 commits into master
Changes from 1 commit
51 changes: 43 additions & 8 deletions digits/tools/create_db.py
@@ -269,6 +269,9 @@ def create_db(input_file, output_dir,
write_queue = Queue.Queue(2*batch_size)
summary_queue = Queue.Queue()

# Init helper function for notification between threads
_notification(set_flag=False)

Contributor

can this be more explicitly named or be made more generic? We could make the flag a bit field, and support a number of reasons for terminating threads (wrong encoding being just one of those reasons).

Contributor Author

This name is too vague, I agree. I thought Python didn't natively support bit fields, so I will make this a dict with meaningful key names. Hopefully that will make it a more useful helper.
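
For illustration only, the dict-based helper described above might look roughly like this (the name _notify and the reason keys are hypothetical, not part of this patch):

    _notify_flags = {}

    def _notify(reason=None, message=True):
        """Record a reason for terminating worker threads, or return all reasons set so far."""
        if reason is None:
            return dict(_notify_flags)      # e.g. {} when nothing went wrong
        _notify_flags[reason] = message     # e.g. reason='bad_encoding'

    # a loader thread could raise a flag with
    #     _notify('bad_encoding', 'Encoding should be None for >8-bit images')
    # and the writer loop could stop with
    #     if _notify(): break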


for _ in xrange(num_threads):
p = threading.Thread(target=_load_thread,
args=(load_queue, write_queue, summary_queue,
@@ -331,6 +334,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir,

processed_something = False

if _notification():
break

if not summary_queue.empty():
result_count, result_sum = summary_queue.get()
images_loaded += result_count
@@ -360,6 +366,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir,
_write_batch_lmdb(db, batch, images_written)
images_written += len(batch)

if _notification():
raise WriteError('Encoding should be None for images with color depth higher than 8 bits.')

if images_loaded == 0:
raise LoadError('no images loaded from input file')
logger.debug('%s images loaded' % images_loaded)
@@ -412,6 +421,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir,

processed_something = False

if _notification():
break

if not summary_queue.empty():
result_count, result_sum = summary_queue.get()
images_loaded += result_count
@@ -442,6 +454,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir,

assert images_written == writer.count()

if _notification():
raise WriteError('Encoding should be None for images with color depth higher than 8 bits.')

Contributor

the notion of "encoding" is irrelevant for HDF5 databases

Contributor Author

My bad. It clearly states at the beginning of the Hdf5Writer function that DTYPE = 'float32'.


if images_loaded == 0:
raise LoadError('no images loaded from input file')
logger.debug('%s images loaded' % images_loaded)
@@ -498,6 +513,14 @@ def _fill_load_queue(filename, queue, shuffle):

return valid_lines

def _notification(set_flag=None):
if set_flag is None:
return _notification.flag
elif set_flag:
_notification.flag = True
else:
_notification.flag = False

def _parse_line(line, distribution):
"""
Parse a line in the input file into (path, label)
@@ -579,12 +602,15 @@ def _load_thread(load_queue, write_queue, summary_queue,
if compute_mean:
image_sum += image

if backend == 'lmdb':
datum = _array_to_datum(image, label, encoding)
write_queue.put(datum)
else:
write_queue.put((image, label))

try:
if backend == 'lmdb':
datum = _array_to_datum(image, label, encoding)
write_queue.put(datum)
else:
write_queue.put((image, label))
except IOError: # try to save 16-bit images with PNG/JPG encoding

Contributor

Can IOError occur for other reasons, besides attempting to store 16-bit data with PNG/JPG encoding?

Contributor Author

Let me catch the error message, too, and set _notification with that error message.
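
A minimal sketch of what that change might look like in this hunk; passing the caught message along is an assumption, since the current _notification helper only takes a boolean:

    except IOError as e:  # e.g. PIL refusing to encode 16-bit data as PNG/JPG
        _notification(True)                              # raise the stop flag
        logger.error('failed to encode image: %s' % e)   # keep the message for the user
        break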

_notification(True)
break
images_added += 1

summary_queue.put((images_added, image_sum))
@@ -616,6 +642,8 @@ def _array_to_datum(image, label, encoding):
image = image[np.newaxis,:,:]
else:
raise Exception('Image has unrecognized shape: "%s"' % image.shape)
if np.issubdtype(image.dtype, float):
image = image.astype(float)
datum = caffe.io.array_to_datum(image, label)
else:
datum = caffe_pb2.Datum()
@@ -667,7 +695,11 @@ def _save_means(image_sum, image_count, mean_files):
"""
Save mean[s] to file
"""
mean = np.around(image_sum / image_count).astype(np.uint8)
mean = np.around(image_sum / image_count)
if mean.max()>255:

Contributor

Why not always use float for the mean? We don't use encoding for the mean.
If you want to keep the float vs. uint8 dtype distinction, I think you also need to check the lower bound.

Contributor Author

@gheinrich, thanks for the comments. My understanding is that image_sum is always positive, so the variable mean is also positive. The only reason astype(np.uint8) doesn't convert properly is that some pixels in mean are greater than 255, so I kept the original code to handle the original case.
However, it will be easier to understand if we make mean a float. Let me try that.
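
A minimal sketch of the always-float alternative suggested above, assuming numpy is imported as np and image_sum / image_count come from the surrounding function; conversion to uint8 would then only happen when writing a .jpg/.png visualization:

    mean = image_sum / float(image_count)       # keep full precision, no dtype switch
    # .npy and .binaryproto outputs can store the float array directly;
    # only for a .jpg/.png visualization, rescale into the 8-bit range:
    span = max(mean.max() - mean.min(), 1e-12)  # avoid division by zero
    vis = (255.0 * (mean - mean.min()) / span).astype(np.uint8)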

mean = mean.astype(np.float)
else:
mean = mean.astype(np.uint8)
for mean_file in mean_files:
if mean_file.lower().endswith('.npy'):
np.save(mean_file, mean)
@@ -693,7 +725,10 @@ def _save_means(image_sum, image_count, mean_files):
with open(mean_file, 'wb') as outfile:
outfile.write(blob.SerializeToString())
elif mean_file.lower().endswith(('.jpg', '.jpeg', '.png')):
image = PIL.Image.fromarray(mean)
if np.issubdtype(mean.dtype, np.float):
image = PIL.Image.fromarray(mean*255/mean.max()).convert('L')

Contributor

can't we resort to 16-bit png in this case?

Contributor Author

You are right. I don't think any pixel in mean can be greater than 65535. We should be able to do that for the PNG case.

Contributor Author

@gheinrich, it looks like if the max value is greater than 255 but much smaller than 65536 (for example, around 800) and we save it to a 16-bit PNG file, that image is almost black. The reason, I believe, is that the dynamic range of 16-bit PNG is much larger than the mean pixel values. We can either scale up to the full dynamic range of 16-bit PNG, or scale down to 8-bit PNG/JPG; both look the same on screen.
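
For concreteness, a small illustration of the two options described here (synthetic data, not part of the patch):

    import numpy as np

    mean = np.random.uniform(0, 800, (256, 256))                # max well below 65535
    up_16 = (mean * 65535.0 / mean.max()).astype(np.uint16)     # stretch to full 16-bit range
    down_8 = (mean * 255.0 / mean.max()).astype(np.uint8)       # scale down to 8-bit
    # Saving mean.astype(np.uint16) without stretching would look almost black,
    # since values around 800 use only ~1% of the 16-bit dynamic range.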

Contributor

Indeed, it's impossible to know how best to visualize the data on the server side. This is something we discussed with @jmancewicz. Joe thinks the best way to deal with this is to have the user choose the amount of contrast, etc. through sliders on the client side. I agree with him.

Contributor Author

That sounds like a good idea.

else:
image = PIL.Image.fromarray(mean)
image.save(mean_file)
else:
logger.warning('Unrecognized file extension for mean file: "%s"' % mean_file)

71 changes: 60 additions & 11 deletions digits/utils/image.py
@@ -11,6 +11,20 @@
except ImportError:
from StringIO import StringIO

try:

Contributor (@gheinrich, Oct 4, 2016)

don't you want to add pydicom to requirements.txt? If the package is not installed, DIGITS will silently refuse to read DICOM files, which could be confusing to the user.

Member

If it's required, we're going to need a deb package for it. Does this package work?
http://packages.ubuntu.com/trusty/python/python-dicom

Contributor Author

Let me check if 0.9.7 works. Pip installed 0.9.9 and it's the one I tested.

Contributor Author

I tested 0.9.7 and it worked without problems. We should be able to use the apt python-dicom package (0.9.7).

Contributor (@gheinrich, Oct 6, 2016)

so you can add pydicom to requirements.txt and import dicom without a try/except then?

Contributor Author

Sure.
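
For reference, the change discussed here would presumably be a single line in requirements.txt (the exact version pin is an assumption):

    pydicom>=0.9.7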

import dicom
dicom_extension = ('.dcm', '.dicom')
except ImportError:
dicom = None
dicom_extension = None

try:
import nifti
nifti_extension = ('.nii',)
except ImportError:
nifti = None
nifti_extension = None

import numpy as np
import PIL.Image
import scipy.misc
@@ -35,6 +49,27 @@
# List of supported file extensions
# Use like "if filename.endswith(SUPPORTED_EXTENSIONS)"
SUPPORTED_EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.ppm')
if dicom is not None:

Contributor

if we always import dicom maybe you don't need this if?

Contributor Author (@IsaacYangSLA, Oct 7, 2016)

That's correct. We can remove that if statement once the pydicom package is required.

SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + dicom_extension
if nifti is not None:
SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + nifti_extension

def load_image_ex(path):
"""
Handles images not recognized by load_image
Reads a file from `path` and returns a PIL.Image with mode 'F' (float32)
Raises LoadImageError

Arguments:
path -- file system path to the image
"""
try:
dm = dicom.read_file(path)
except dicom.errors.InvalidDicomError:
raise errors.LoadImageError, 'Invalid Dicom file'
pixels = dm.pixel_array
image = PIL.Image.fromarray(pixels.astype(np.float))

Contributor

how do you deal with 3D/4D data (which seem to be the most common use of DICOM)?

Contributor Author

You are right. This definitely breaks when pixels.shape has more than two dimensions.

Contributor Author

I think we can support that in the future. Adding 3D/4D data support from a single DICOM file would involve many modifications, though maybe that's just my limited understanding of the current code.

Contributor

Sure, that can be a future improvement. If pixels is a 3-D array, how do you determine whether it's an RGB image (OK in DIGITS) or a volumetric grayscale image (not OK in DIGITS unless there are exactly 3 slices)?

Contributor Author

Each slice should have its own metadata, so we should be able to tell whether it's a multi-slice DICOM or a multi-channel DICOM. I need to dig out that information, but I remember seeing it before.
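
A rough sketch of the metadata check described above; SamplesPerPixel and NumberOfFrames are standard DICOM attributes, but treat the exact pydicom accessors here as an assumption:

    def describe_pixel_layout(dm):
        """Classify a pydicom dataset as colour, volumetric, or plain 2-D."""
        samples = int(getattr(dm, 'SamplesPerPixel', 1))  # 3 -> colour (e.g. RGB)
        frames = int(getattr(dm, 'NumberOfFrames', 1))    # >1 -> multi-slice volume
        if samples == 3 and frames == 1:
            return 'multi-channel (colour) image'
        if samples == 1 and frames > 1:
            return 'volumetric grayscale image'
        return 'single-frame grayscale image'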

return image

def load_image(path):
"""
@@ -62,11 +97,14 @@ def load_image(path):
image = PIL.Image.open(path)
image.load()
except IOError as e:
raise errors.LoadImageError, 'IOError: %s' % e.message
if dicom is not None:
image = load_image_ex(path)
else:
raise errors.LoadImageError, e.message
else:
raise errors.LoadImageError, '"%s" not found' % path

if image.mode in ['L', 'RGB']:
if image.mode in ['L', 'RGB', 'F']:
# No conversion necessary
return image
elif image.mode in ['1']:
@@ -103,7 +141,7 @@ def upscale(image, ratio):
width = int(math.floor(image.shape[1] * ratio))
height = int(math.floor(image.shape[0] * ratio))
channels = image.shape[2]
out = np.ndarray((height, width, channels),dtype=np.uint8)
out = np.ndarray((height, width, channels),dtype=image.dtype)
for x, y in np.ndindex((width,height)):
out[y,x] = image[int(math.floor(y/ratio)), int(math.floor(x/ratio))]
return out
@@ -128,15 +166,15 @@ def image_to_array(image,
# Convert image mode (channels)
if channels is None:
image_mode = image.mode
if image_mode == 'L':
if image_mode == 'L' or image_mode == 'F':
channels = 1
elif image_mode == 'RGB':
channels = 3
else:
raise ValueError('unknown image mode "%s"' % image_mode)
elif channels == 1:
# 8-bit pixels, black and white
image_mode = 'L'
image_mode = image.mode if image.mode == 'F' else 'L'
elif channels == 3:
# 3x8-bit pixels, true color
image_mode = 'RGB'
@@ -208,8 +246,9 @@ def resize_image(image, height, width,

width_ratio = float(image.shape[1]) / width
height_ratio = float(image.shape[0]) / height
image_data_format = 'F' if np.issubdtype(image.dtype, float) else None
if resize_mode == 'squash' or width_ratio == height_ratio:
return scipy.misc.imresize(image, (height, width), interp=interp)
return scipy.misc.imresize(image, (height, width), mode=image_data_format, interp=interp)
elif resize_mode == 'crop':
# resize to smallest of ratios (relatively larger image), keeping aspect ratio
if width_ratio > height_ratio:
@@ -218,7 +257,7 @@
else:
resize_width = width
resize_height = int(round(image.shape[0] / width_ratio))
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp)
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp)

# chop off ends of dimension that is still too long
if width_ratio > height_ratio:
@@ -240,7 +279,7 @@
resize_width = int(round(image.shape[1] / height_ratio))
if (width - resize_width) % 2 == 1:
resize_width += 1
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp)
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp)
elif resize_mode == 'half_crop':
# resize to average ratio keeping aspect ratio
new_ratio = (width_ratio + height_ratio) / 2.0
@@ -250,7 +289,7 @@
resize_height += 1
elif width_ratio < height_ratio and (width - resize_width) % 2 == 1:
resize_width += 1
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp)
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp)
# chop off ends of dimension that is still too long
if width_ratio > height_ratio:
start = int(round((resize_width-width)/2.0))
@@ -267,14 +306,20 @@
noise_size = (padding, width)
if channels > 1:
noise_size += (channels,)
noise = np.random.randint(0, 255, noise_size).astype('uint8')
if image_data_format == 'F':
noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float')
else:
noise = np.random.randint(0, 255, noise_size).astype('uint8')
image = np.concatenate((noise, image, noise), axis=0)
else:
padding = (width - resize_width)/2
noise_size = (height, padding)
if channels > 1:
noise_size += (channels,)
noise = np.random.randint(0, 255, noise_size).astype('uint8')
if image_data_format == 'F':
noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float')
else:
noise = np.random.randint(0, 255, noise_size).astype('uint8')
image = np.concatenate((noise, image, noise), axis=1)

return image
@@ -305,6 +350,10 @@ def embed_image_html(image):
fmt = fmt.lower()

string_buf = StringIO()
if image.mode == 'F':
min_pv, max_pv = image.getextrema()
tmp_array = 255.0*(np.asarray(image)-min_pv)/(max_pv-min_pv)
image = PIL.Image.fromarray(tmp_array.astype(np.uint8))
image.save(string_buf, format=fmt)
data = string_buf.getvalue().encode('base64').replace('\n', '')
return 'data:image/%s;base64,%s' % (fmt, data)