Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP "Fake" ambiant occlusion for MI-Brain #1030

Merged
merged 4 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 80 additions & 5 deletions scilpy/viz/color.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# -*- coding: utf-8 -*-

from fury.colormap import distinguishable_colormap
from dipy.io.stateful_tractogram import StatefulTractogram
from fury import colormap
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib import colors as mcolors
import numpy as np
from scipy.spatial import KDTree

from scilpy.viz.backends.vtk import get_color_by_name, lut_from_colors

Expand Down Expand Up @@ -38,7 +40,7 @@ def convert_color_names_to_rgb(names):
"Brown"])


def generate_n_colors(n, generator=distinguishable_colormap,
def generate_n_colors(n, generator=colormap.distinguishable_colormap,
pick_from_base10=True, shuffle=False):
"""
Generate a set of N colors. When using the default parameters, colors will
Expand Down Expand Up @@ -98,8 +100,8 @@ def get_lookup_table(name):

if '-' in name:
name_list = name.split('-')
colors_list = [colors.to_rgba(color)[0:3] for color in name_list]
cmap = colors.LinearSegmentedColormap.from_list('CustomCmap',
colors_list = [mcolors.to_rgba(color)[0:3] for color in name_list]
cmap = mcolors.LinearSegmentedColormap.from_list('CustomCmap',
colors_list)
return cmap

Expand Down Expand Up @@ -279,3 +281,76 @@ def prepare_colorbar_figure(cmap, lbound, ubound, nb_values=255, nb_ticks=10,
ax.set_xticklabels(ticks_labels)
ax.set_yticks([])
return fig


def ambiant_occlusion(sft, colors, factor=4):
"""
Apply ambiant occlusion to a set of colors based on point density
around each points.

Parameters
----------
sft : StatefulTractogram
The streamlines.
colors : np.ndarray
The original colors to modify.
factor : float
The factor of occlusion (how density will affect the saturation).

Returns
-------
np.ndarray
The modified colors.
"""

pts = sft.streamlines._data
hsv = mcolors.rgb_to_hsv(colors)

tree = KDTree(pts)
nb_neighbor = np.array(tree.query_ball_point(pts, 1,
return_length=True),
dtype=float)
nb_neighbor /= np.max(nb_neighbor)
occlusion_w = np.exp(-factor * nb_neighbor)

hsv[:, 1] = np.clip(hsv[:, 1], max(1 / factor, np.min(hsv[:, 1])),
min(1 - 1 / factor, np.max(hsv[:, 1])))
hsv[:, 1] -= (occlusion_w / factor)

occlusion_w = np.clip(occlusion_w, 0.5 + (1 / factor), 1)
hsv[:, 2] *= occlusion_w
hsv[:, 0:2] = np.clip(hsv[:, 0:2], 0, 1)
hsv[:, 2] = np.clip(hsv[:, 2], 0, 255)

return mcolors.hsv_to_rgb(hsv)

def generate_local_coloring(sft):
"""
Generate a coloring based on the local orientation of the streamlines.

Parameters
----------
sft : StatefulTractogram / ArraySequence / List
The tractogram / streamlines to generate the coloring from.

Returns
-------
np.ndarray
The generated colors.
"""
if isinstance(sft, StatefulTractogram):
streamlines = sft.streamlines
else:
streamlines = sft

# Compute segment orientation
diff = [np.diff(list(s), axis=0) for s in streamlines]
# Repeat first segment so that the number of segments matches
# the number of points
diff = [[d[0]] + list(d) for d in diff]
# Flatten the list of segments
orientations = np.asarray([o for d in diff for o in d])
# Turn the segments into colors
color = colormap.orient2rgb(orientations)

return color
35 changes: 28 additions & 7 deletions scripts/scil_tractogram_assign_custom_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import logging

from dipy.io.streamline import save_tractogram
from fury import colormap
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
Expand All @@ -65,8 +66,9 @@
from scilpy.tractograms.dps_and_dpp_management import add_data_as_color_dpp
from scilpy.tractograms.streamline_operations import (
get_streamlines_as_linspaces, get_angles)
from scilpy.viz.color import get_lookup_table
from scilpy.viz.color import prepare_colorbar_figure
from scilpy.viz.color import (
get_lookup_table, prepare_colorbar_figure, ambiant_occlusion,
generate_local_coloring)


def _build_arg_parser():
Expand Down Expand Up @@ -105,12 +107,19 @@ def _build_arg_parser():
p1.add_argument('--along_profile', action='store_true',
help='Color streamlines according to each point position'
'along its length.')
p1.add_argument('--local_orientation', action='store_true',
help="Color streamlines according to the angle between "
"each segment (in degree). \nAngles at first and "
"last points are set to 0.")
p1.add_argument('--local_angle', action='store_true',
help="Color streamlines according to the angle between "
"each segment (in degree). \nAngles at first and "
"last points are set to 0.")

g2 = p.add_argument_group(title='Coloring options')
g2.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int,
help='Impact factor of the ambiant occlusion '
'approximation. [%(default)s]')
g2.add_argument('--colormap', default='jet',
help='Select the colormap for colored trk (dps/dpp) '
'[%(default)s].\nUse two Matplotlib named color separeted '
Expand All @@ -128,8 +137,7 @@ def _build_arg_parser():
g2.add_argument('--max_cmap', type=float,
help='Set the maximum value of the colormap.')
g2.add_argument('--log', action='store_true',
help='Apply a base 10 logarithm for colored trk (dps/dpp).'
)
help='Apply a base 10 logarithm for colored trk (dps/dpp).')
g2.add_argument('--LUT', metavar='FILE',
help='If the dps/dpp or anatomy contain integer labels, '
'the value will be substituted.\nIf the LUT has 20 '
Expand All @@ -148,6 +156,9 @@ def main():
args = parser.parse_args()
logging.getLogger().setLevel(logging.getLevelName(args.verbose))

if args.local_orientation and args.out_colorbar:
parser.error("Cannot save a colorbar with local orientation coloring.")

# Verifications
assert_inputs_exist(parser, args.in_tractogram, args.reference)
assert_outputs_exist(parser, args, args.out_tractogram,
Expand Down Expand Up @@ -206,16 +217,26 @@ def main():
elif args.along_profile:
data = get_streamlines_as_linspaces(sft)
data = np.hstack(data)
elif args.local_orientation:
data = generate_local_coloring(sft)
else: # args.local_angle:
data = get_angles(sft, add_zeros=True)
data = np.hstack(data)

# Processing
sft, lbound, ubound = add_data_as_color_dpp(
sft, cmap, data, args.clip_outliers, args.min_range, args.max_range,
args.min_cmap, args.max_cmap, args.log, LUT)
if not args.local_orientation:
sft, lbound, ubound = add_data_as_color_dpp(
sft, cmap, data, args.clip_outliers, args.min_range, args.max_range,
args.min_cmap, args.max_cmap, args.log, LUT)
else:
sft.data_per_point['color'] = sft.streamlines.copy()
data *= 255
sft.data_per_point['color']._data = data.astype(np.uint8)

# Saving
if args.ambiant_occlusion:
sft.data_per_point['color']._data = ambiant_occlusion(
sft, sft.data_per_point['color']._data, args.ambiant_occlusion)
save_tractogram(sft, args.out_tractogram)

if args.out_colorbar:
Expand Down
21 changes: 15 additions & 6 deletions scripts/scil_tractogram_assign_uniform_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
add_overwrite_arg,
add_verbose_arg,
add_reference_arg, assert_headers_compatible)
from scilpy.viz.color import format_hexadecimal_color_to_rgb
from scilpy.viz.color import format_hexadecimal_color_to_rgb, ambiant_occlusion


def _build_arg_parser():
Expand All @@ -41,6 +41,11 @@ def _build_arg_parser():
p.add_argument('in_tractograms', nargs='+',
help='Input tractograms (.trk or .tck).')

p.add_argument('--ambiant_occlusion', nargs='?', const=4, type=int,
help='Impact factor of the ambiant occlusion '
'approximation.\n Use factor or 2. Decrease for '
'lighter and increase for darker [%(default)s].')

g1 = p.add_argument_group(title='Coloring Methods')
p1 = g1.add_mutually_exclusive_group(required=True)
p1.add_argument('--fill_color', metavar='str',
Expand Down Expand Up @@ -116,6 +121,10 @@ def main():

sft = load_tractogram_with_reference(parser, args, filename)

sft.data_per_point['color'] = sft.streamlines.copy()
sft.data_per_point['color']._data = np.zeros(
(len(sft.streamlines._data), 3), dtype=np.uint8)

if args.dict_colors:
base, ext = os.path.splitext(filename)
pos = base.index('__') if '__' in base else -2
Expand All @@ -132,11 +141,11 @@ def main():

red, green, blue = format_hexadecimal_color_to_rgb(color)

tmp = [np.tile([red, green, blue], (len(i), 1))
for i in sft.streamlines]
sft.data_per_point['color'] = tmp

# Saving
colors = np.tile([red, green, blue], (len(sft.streamlines._data), 1))
if args.ambiant_occlusion:
colors = ambiant_occlusion(sft, colors,
factor=args.ambiant_occlusion)
sft.data_per_point['color']._data = colors
save_tractogram(sft, out_filenames[i])


Expand Down
13 changes: 3 additions & 10 deletions scripts/scil_viz_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@
import random

from dipy.tracking.streamline import set_number_of_points
from fury import window, actor, colormap
from fury import window, actor

from scilpy.io.utils import (assert_inputs_exist,
add_verbose_arg,
parser_color_type)
from scilpy.viz.color import generate_local_coloring


streamline_actor = {'tube': actor.streamtube,
Expand Down Expand Up @@ -185,15 +186,7 @@ def subsample(list_obj):
elif args.uniform_coloring: # Assign uniform coloring to streamlines
color = tuple(np.asarray(args.uniform_coloring) / 255)
elif args.local_coloring: # Compute coloring from local orientations
# Compute segment orientation
diff = [np.diff(list(s), axis=0) for s in streamlines]
# Repeat first segment so that the number of segments matches
# the number of points
diff = [[d[0]] + list(d) for d in diff]
# Flatten the list of segments
orientations = np.asarray([o for d in diff for o in d])
# Turn the segments into colors
color = colormap.orient2rgb(orientations)
color = generate_local_coloring(streamlines)
else: # Streamline color will depend on the streamlines' endpoints.
color = None
# TODO: Coloring from a volume of local orientations
Expand Down