Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix adjusting contrast limits for rgb data #276

Merged
merged 5 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 44 additions & 31 deletions src/napari_spatialdata/_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,37 +238,7 @@ def __init__(self, napari_viewer: Viewer, model: DataModel | None = None) -> Non
self.var_widget = AListWidget(self.viewer, self.model, attr="var")
self.var_widget.setAdataLayer("X")

def channel_changed(event: Event) -> None:
layer = self.model.layer
is_image = isinstance(layer, Image)

has_sdata = layer is not None and layer.metadata.get("sdata") is not None
has_adata = layer is not None and layer.metadata.get("adata") is not None

# has_adata is added so we see the channels in the view widget under vars
if layer is not None and is_image and has_sdata and has_adata:
c_channel = event.value[0]

# TODO remove once contrast limits in napari are fixed
if isinstance(layer.data, MultiScaleData):
# just compute lowest resolution
image = layer.data[-1][c_channel, :, :].compute()
min_value = image.min().data
max_value = image.max().data
else:
image = layer.data[c_channel, :, :].compute()
min_value = image.min()
max_value = image.max()
if min_value == max_value:
min_value = np.iinfo(image.data.dtype).min
max_value = np.iinfo(image.data.dtype).max
layer.contrast_limits = [min_value, max_value]

item = self.var_widget.item(c_channel)
index = self.var_widget.indexFromItem(item)
self.var_widget.setCurrentIndex(index)

self.viewer.dims.events.current_step.connect(channel_changed)
self.viewer.dims.events.current_step.connect(self._channel_changed)

# layers
adata_layer_label = QLabel("Layers:")
Expand Down Expand Up @@ -322,6 +292,49 @@ def channel_changed(event: Event) -> None:
self.model.events.adata.connect(self._on_layer_update)
self.model.events.color_by.connect(self._change_color_by)

def _channel_changed(self, event: Event) -> None:
layer = self.model.layer
is_image = isinstance(layer, Image)

has_sdata = layer is not None and layer.metadata.get("sdata") is not None
has_adata = layer is not None and layer.metadata.get("adata") is not None

# has_adata is added so we see the channels in the view widget under vars
if layer is None or not is_image or not has_sdata or not has_adata or layer.rgb:
return

current_point = list(event.value)
displayed = self._viewer.dims.displayed

for i, (lo_size, hi_size, cord) in enumerate(zip(layer.data[-1].shape, layer.data[0].shape, displayed)):
if i in displayed:
current_point[i] = slice(None)
else:
current_point[i] = int(cord * lo_size / hi_size)
Comment on lines +306 to +313
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is logic for any dimensional data.

If you are 100% sure that data will always be 2D, then this logic may be obsolete.
If contrast limits needs to be calculated for the whole channel, not only the visible portion, then it is also overcomplicated. But then, for 3D data, should be no recalculation when change Z slice.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future we will not be sure about this. SpatialData already does have 3D models. Where bugs occur for 3D napari-spatialdata should be adjusted.


# TODO remove once contrast limits in napari are fixed
if isinstance(layer.data, MultiScaleData):
# just compute lowest resolution
image = layer.data[-1][tuple(current_point)].compute()
min_value = image.min().data
max_value = image.max().data
else:
image = layer.data[tuple(current_point)].compute()
min_value = image.min()
max_value = image.max()
if min_value == max_value:
min_value = np.iinfo(image.data.dtype).min
max_value = np.iinfo(image.data.dtype).max
layer.contrast_limits = [min_value, max_value]
try:
channel_num = next(x for x in current_point if not isinstance(x, slice))
except StopIteration:
return

item = self.var_widget.item(channel_num)
index = self.var_widget.indexFromItem(item)
self.var_widget.setCurrentIndex(index)

def _on_layer_update(self, event: Any | None = None) -> None:
"""When the model updates the selected layer, update the relevant widgets."""
logger.info("Updating layer.")
Expand Down
3 changes: 2 additions & 1 deletion src/napari_spatialdata/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ def add_sdata_image(self, sdata: SpatialData, key: str, selected_cs: str, multi:
affine = _get_transform(sdata.images[original_name], selected_cs)
rgb_image, rgb = _adjust_channels_order(element=sdata.images[original_name])

channels = get_channels(sdata.images[original_name])
channels = ("RGB(A)",) if rgb else get_channels(sdata.images[original_name])

adata = AnnData(shape=(0, len(channels)), var=pd.DataFrame(index=channels))

# TODO: type check
Expand Down
Loading