-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Supporting DICOM via pydicom #1136
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -269,6 +269,9 @@ def create_db(input_file, output_dir, | |
write_queue = Queue.Queue(2*batch_size) | ||
summary_queue = Queue.Queue() | ||
|
||
# Init helper function for notification between threads | ||
_notification(set_flag=False) | ||
|
||
for _ in xrange(num_threads): | ||
p = threading.Thread(target=_load_thread, | ||
args=(load_queue, write_queue, summary_queue, | ||
|
@@ -331,6 +334,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir, | |
|
||
processed_something = False | ||
|
||
if _notification(): | ||
break | ||
|
||
if not summary_queue.empty(): | ||
result_count, result_sum = summary_queue.get() | ||
images_loaded += result_count | ||
|
@@ -360,6 +366,9 @@ def _create_lmdb(image_count, write_queue, batch_size, output_dir, | |
_write_batch_lmdb(db, batch, images_written) | ||
images_written += len(batch) | ||
|
||
if _notification(): | ||
raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') | ||
|
||
if images_loaded == 0: | ||
raise LoadError('no images loaded from input file') | ||
logger.debug('%s images loaded' % images_loaded) | ||
|
@@ -412,6 +421,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, | |
|
||
processed_something = False | ||
|
||
if _notification(): | ||
break | ||
|
||
if not summary_queue.empty(): | ||
result_count, result_sum = summary_queue.get() | ||
images_loaded += result_count | ||
|
@@ -442,6 +454,9 @@ def _create_hdf5(image_count, write_queue, batch_size, output_dir, | |
|
||
assert images_written == writer.count() | ||
|
||
if _notification(): | ||
raise WriteError('Encoding should be None for images with color depth higher than 8 bits.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the notion of "encoding" is irrelevant for HDF5 databases There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My bad. It clearly states on the beginning of Hdf5Writer function DTYPE='float32.' |
||
|
||
if images_loaded == 0: | ||
raise LoadError('no images loaded from input file') | ||
logger.debug('%s images loaded' % images_loaded) | ||
|
@@ -498,6 +513,14 @@ def _fill_load_queue(filename, queue, shuffle): | |
|
||
return valid_lines | ||
|
||
def _notification(set_flag=None): | ||
if set_flag is None: | ||
return _notification.flag | ||
elif set_flag: | ||
_notification.flag = True | ||
else: | ||
_notification.flag = False | ||
|
||
def _parse_line(line, distribution): | ||
""" | ||
Parse a line in the input file into (path, label) | ||
|
@@ -579,12 +602,15 @@ def _load_thread(load_queue, write_queue, summary_queue, | |
if compute_mean: | ||
image_sum += image | ||
|
||
if backend == 'lmdb': | ||
datum = _array_to_datum(image, label, encoding) | ||
write_queue.put(datum) | ||
else: | ||
write_queue.put((image, label)) | ||
|
||
try: | ||
if backend == 'lmdb': | ||
datum = _array_to_datum(image, label, encoding) | ||
write_queue.put(datum) | ||
else: | ||
write_queue.put((image, label)) | ||
except IOError: # try to save 16-bit images with PNG/JPG encoding | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me catch the error message, too, and set _notification with that error message. |
||
_notification(True) | ||
break | ||
images_added += 1 | ||
|
||
summary_queue.put((images_added, image_sum)) | ||
|
@@ -616,6 +642,8 @@ def _array_to_datum(image, label, encoding): | |
image = image[np.newaxis,:,:] | ||
else: | ||
raise Exception('Image has unrecognized shape: "%s"' % image.shape) | ||
if np.issubdtype(image.dtype, float): | ||
image = image.astype(float) | ||
datum = caffe.io.array_to_datum(image, label) | ||
else: | ||
datum = caffe_pb2.Datum() | ||
|
@@ -667,7 +695,11 @@ def _save_means(image_sum, image_count, mean_files): | |
""" | ||
Save mean[s] to file | ||
""" | ||
mean = np.around(image_sum / image_count).astype(np.uint8) | ||
mean = np.around(image_sum / image_count) | ||
if mean.max()>255: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not always use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gheinrich , thanks on the comments. My understanding is |
||
mean = mean.astype(np.float) | ||
else: | ||
mean = mean.astype(np.uint8) | ||
for mean_file in mean_files: | ||
if mean_file.lower().endswith('.npy'): | ||
np.save(mean_file, mean) | ||
|
@@ -693,7 +725,10 @@ def _save_means(image_sum, image_count, mean_files): | |
with open(mean_file, 'wb') as outfile: | ||
outfile.write(blob.SerializeToString()) | ||
elif mean_file.lower().endswith(('.jpg', '.jpeg', '.png')): | ||
image = PIL.Image.fromarray(mean) | ||
if np.issubdtype(mean.dtype, np.float): | ||
image = PIL.Image.fromarray(mean*255/mean.max()).convert('L') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't we resort to 16-bit png in this case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. I don't think any pixel in mean can be greater than 65535. We should be able to do that for PNG case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @gheinrich , it looks like if the max value is greater than 255, but much smaller than 65536 (for example, around 800), and we save it to png file. That image is almost black. The reason, I believe, is the dynamic range of 16-bit PNG is much higher than mean pixel value. We can either scale it up to full dynamic range of 16-bit PNG, or scale it down to 8-bit png/jpg. Both look the same on screen. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Indeed it's impossible to know how to best visualize data on server side. This is something we discussed with @jmancewicz. Joe thinks the best way to deal with this is to have the user choose the amount of contrast, etc. though sliders on client side. I agree with him. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That sounds a good idea. |
||
else: | ||
image = PIL.Image.fromarray(mean) | ||
image.save(mean_file) | ||
else: | ||
logger.warning('Unrecognized file extension for mean file: "%s"' % mean_file) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,20 @@ | |
except ImportError: | ||
from StringIO import StringIO | ||
|
||
try: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't you want to add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's required, we're going to need a deb package for it. Does this package work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let me check if 0.9.7 works. Pip installed 0.9.9 and it's the one I tested. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested 0.9.7 and it worked without problem. We should be able to use apt python-dicom package of 0.9.7. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so you can add There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. |
||
import dicom | ||
dicom_extension = ('.dcm', '.dicom') | ||
except ImportError: | ||
dicom = None | ||
dicom_extension = None | ||
|
||
try: | ||
import nifti | ||
nifti_extension = ('.nii',) | ||
except ImportError: | ||
nifti = None | ||
nifti_extension = None | ||
|
||
import numpy as np | ||
import PIL.Image | ||
import scipy.misc | ||
|
@@ -35,6 +49,27 @@ | |
# List of supported file extensions | ||
# Use like "if filename.endswith(SUPPORTED_EXTENSIONS)" | ||
SUPPORTED_EXTENSIONS = ('.png','.jpg','.jpeg','.bmp','.ppm') | ||
if dicom is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if we always import There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's correct. We can remove that |
||
SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + dicom_extension | ||
if nifti is not None: | ||
SUPPORTED_EXTENSIONS = SUPPORTED_EXTENSIONS + nifti_extension | ||
|
||
def load_image_ex(path): | ||
""" | ||
Handles images not recognized by load_image | ||
Reads a file from `path` and returns a PIL.Image with mode 'F' (float32) | ||
Raises LoadImageError | ||
|
||
Arguments: | ||
path -- file system path to the image | ||
""" | ||
try: | ||
dm = dicom.read_file(path) | ||
except dicom.errors.InvalidDicomError: | ||
raise errors.LoadImageError, 'Invalid Dicom file' | ||
pixels = dm.pixel_array | ||
image = PIL.Image.fromarray(pixels.astype(np.float)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. how do you deal with 3D/4D data (which seem to be the most common use of DICOM)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You are right. This definitely breaks when pixels.shape is higher than 2D. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can support that in the future. Adding 3D/4D data support from one single DICOM file would involve many modifications. Maybe that's because my limited understanding on the current codes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure that can be a future improvement. If There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Each slice should have its meta data. Therefore, we should be able to tell if it's a multi-slice DICOM, or a multi-channel DICOM. I need to dig out that information, but I remember I saw that before. |
||
return image | ||
|
||
def load_image(path): | ||
""" | ||
|
@@ -62,11 +97,14 @@ def load_image(path): | |
image = PIL.Image.open(path) | ||
image.load() | ||
except IOError as e: | ||
raise errors.LoadImageError, 'IOError: %s' % e.message | ||
if dicom is not None: | ||
image = load_image_ex(path) | ||
else: | ||
raise errors.LoadImageError, e.message | ||
else: | ||
raise errors.LoadImageError, '"%s" not found' % path | ||
|
||
if image.mode in ['L', 'RGB']: | ||
if image.mode in ['L', 'RGB', 'F']: | ||
# No conversion necessary | ||
return image | ||
elif image.mode in ['1']: | ||
|
@@ -103,7 +141,7 @@ def upscale(image, ratio): | |
width = int(math.floor(image.shape[1] * ratio)) | ||
height = int(math.floor(image.shape[0] * ratio)) | ||
channels = image.shape[2] | ||
out = np.ndarray((height, width, channels),dtype=np.uint8) | ||
out = np.ndarray((height, width, channels),dtype=image.dtype) | ||
for x, y in np.ndindex((width,height)): | ||
out[y,x] = image[int(math.floor(y/ratio)), int(math.floor(x/ratio))] | ||
return out | ||
|
@@ -128,15 +166,15 @@ def image_to_array(image, | |
# Convert image mode (channels) | ||
if channels is None: | ||
image_mode = image.mode | ||
if image_mode == 'L': | ||
if image_mode == 'L' or image_mode == 'F': | ||
channels = 1 | ||
elif image_mode == 'RGB': | ||
channels = 3 | ||
else: | ||
raise ValueError('unknown image mode "%s"' % image_mode) | ||
elif channels == 1: | ||
# 8-bit pixels, black and white | ||
image_mode = 'L' | ||
image_mode = image.mode if image.mode == 'F' else 'L' | ||
elif channels == 3: | ||
# 3x8-bit pixels, true color | ||
image_mode = 'RGB' | ||
|
@@ -208,8 +246,9 @@ def resize_image(image, height, width, | |
|
||
width_ratio = float(image.shape[1]) / width | ||
height_ratio = float(image.shape[0]) / height | ||
image_data_format = 'F' if np.issubdtype(image.dtype, float) else None | ||
if resize_mode == 'squash' or width_ratio == height_ratio: | ||
return scipy.misc.imresize(image, (height, width), interp=interp) | ||
return scipy.misc.imresize(image, (height, width), mode=image_data_format, interp=interp) | ||
elif resize_mode == 'crop': | ||
# resize to smallest of ratios (relatively larger image), keeping aspect ratio | ||
if width_ratio > height_ratio: | ||
|
@@ -218,7 +257,7 @@ def resize_image(image, height, width, | |
else: | ||
resize_width = width | ||
resize_height = int(round(image.shape[0] / width_ratio)) | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) | ||
|
||
# chop off ends of dimension that is still too long | ||
if width_ratio > height_ratio: | ||
|
@@ -240,7 +279,7 @@ def resize_image(image, height, width, | |
resize_width = int(round(image.shape[1] / height_ratio)) | ||
if (width - resize_width) % 2 == 1: | ||
resize_width += 1 | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) | ||
elif resize_mode == 'half_crop': | ||
# resize to average ratio keeping aspect ratio | ||
new_ratio = (width_ratio + height_ratio) / 2.0 | ||
|
@@ -250,7 +289,7 @@ def resize_image(image, height, width, | |
resize_height += 1 | ||
elif width_ratio < height_ratio and (width - resize_width) % 2 == 1: | ||
resize_width += 1 | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), interp=interp) | ||
image = scipy.misc.imresize(image, (resize_height, resize_width), mode=image_data_format, interp=interp) | ||
# chop off ends of dimension that is still too long | ||
if width_ratio > height_ratio: | ||
start = int(round((resize_width-width)/2.0)) | ||
|
@@ -267,14 +306,20 @@ def resize_image(image, height, width, | |
noise_size = (padding, width) | ||
if channels > 1: | ||
noise_size += (channels,) | ||
noise = np.random.randint(0, 255, noise_size).astype('uint8') | ||
if image_data_format == 'F': | ||
noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') | ||
else: | ||
noise = np.random.randint(0, 255, noise_size).astype('uint8') | ||
image = np.concatenate((noise, image, noise), axis=0) | ||
else: | ||
padding = (width - resize_width)/2 | ||
noise_size = (height, padding) | ||
if channels > 1: | ||
noise_size += (channels,) | ||
noise = np.random.randint(0, 255, noise_size).astype('uint8') | ||
if image_data_format == 'F': | ||
noise = np.random.randint(int(image.min()), int(image.max()), noise_size).astype('float') | ||
else: | ||
noise = np.random.randint(0, 255, noise_size).astype('uint8') | ||
image = np.concatenate((noise, image, noise), axis=1) | ||
|
||
return image | ||
|
@@ -305,6 +350,10 @@ def embed_image_html(image): | |
fmt = fmt.lower() | ||
|
||
string_buf = StringIO() | ||
if image.mode == 'F': | ||
min_pv, max_pv = image.getextrema() | ||
tmp_array = 255.0*(np.asarray(image)-min_pv)/(max_pv-min_pv) | ||
image = PIL.Image.fromarray(tmp_array.astype(np.uint8)) | ||
image.save(string_buf, format=fmt) | ||
data = string_buf.getvalue().encode('base64').replace('\n', '') | ||
return 'data:image/%s;base64,%s' % (fmt, data) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can this be more explicitly named or be made more generic? We could make the flag a bit field, and support a number of reasons for terminating threads (wrong encoding being just one of those reasons).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This name is too vague, I agreed. I thought Python didn't natively support bit field. I will make this a dict with meaningful key names. Hopefully, it can be a better helpful function.