Skip to content

Commit 87790f5

Browse files
authored
Use Image trace and add low-res slices (#3)
* wip low-res data * use Image trace instead of layout images. * tweaks * cleanup
1 parent 56b173e commit 87790f5

File tree

2 files changed

+38
-32
lines changed

2 files changed

+38
-32
lines changed

dash_3d_viewer/slicer.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from plotly.graph_objects import Figure
2+
from plotly.graph_objects import Figure, Image
33
from dash import Dash
44
from dash.dependencies import Input, Output, State
55
from dash_core_components import Graph, Slider, Store
@@ -29,44 +29,39 @@ def __init__(self, app, volume, axis=0, id=None):
2929
self._id = id
3030

3131
# Get the slice size (width, height), and max index
32-
arr_shape = list(volume.shape)
33-
arr_shape.pop(self._axis)
34-
slice_size = list(reversed(arr_shape))
32+
# arr_shape = list(volume.shape)
33+
# arr_shape.pop(self._axis)
34+
# slice_size = list(reversed(arr_shape))
3535
self._max_index = self._volume.shape[self._axis] - 1
3636

37+
# Prep low-res slices
38+
thumbnails = [
39+
img_array_to_uri(self._slice(i), (32, 32))
40+
for i in range(self._max_index + 1)
41+
]
42+
43+
# Create a placeholder trace
44+
# todo: can add "%{z[0]}", but that would be the scaled value ...
45+
trace = Image(source="", hovertemplate="(%{x}, %{y})<extra></extra>")
3746
# Create the figure object
38-
fig = Figure()
47+
fig = Figure(data=[trace])
3948
fig.update_layout(
4049
template=None,
4150
margin=dict(l=0, r=0, b=0, t=0, pad=4),
4251
)
4352
fig.update_xaxes(
53+
# range=(0, slice_size[0]),
4454
showgrid=False,
45-
range=(0, slice_size[0]),
4655
showticklabels=False,
4756
zeroline=False,
4857
)
4958
fig.update_yaxes(
59+
# range=(slice_size[1], 0), # todo: allow flipping x or y
5060
showgrid=False,
5161
scaleanchor="x",
52-
range=(slice_size[1], 0), # todo: allow flipping x or y
5362
showticklabels=False,
5463
zeroline=False,
5564
)
56-
# Add an empty layout image that we can populate from JS.
57-
fig.add_layout_image(
58-
dict(
59-
source="",
60-
xref="x",
61-
yref="y",
62-
x=0,
63-
y=0,
64-
sizex=slice_size[0],
65-
sizey=slice_size[1],
66-
sizing="contain",
67-
layer="below",
68-
)
69-
)
7065
# Wrap the figure in a graph
7166
# todo: or should the user provide this?
7267
self.graph = Graph(
@@ -88,6 +83,7 @@ def __init__(self, app, volume, axis=0, id=None):
8883
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
8984
Store(id=self._subid("_requested-slice-index"), data=0),
9085
Store(id=self._subid("_slice-data"), data=""),
86+
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
9187
]
9288

9389
self._create_server_callbacks(app)
@@ -101,7 +97,8 @@ def _slice(self, index):
10197
"""Sample a slice from the volume."""
10298
indices = [slice(None), slice(None), slice(None)]
10399
indices[self._axis] = index
104-
return self._volume[tuple(indices)]
100+
im = self._volume[tuple(indices)]
101+
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)
105102

106103
def _create_server_callbacks(self, app):
107104
"""Create the callbacks that run server-side."""
@@ -112,7 +109,6 @@ def _create_server_callbacks(self, app):
112109
)
113110
def upload_requested_slice(slice_index):
114111
slice = self._slice(slice_index)
115-
slice = (slice.astype(np.float32) * (255 / slice.max())).astype(np.uint8)
116112
return [slice_index, img_array_to_uri(slice)]
117113

118114
def _create_client_callbacks(self, app):
@@ -158,7 +154,7 @@ def _create_client_callbacks(self, app):
158154

159155
app.clientside_callback(
160156
"""
161-
function handle_incoming_slice(index, index_and_data, ori_figure) {
157+
function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
162158
let new_index = index_and_data[0];
163159
let new_data = index_and_data[1];
164160
// Store data in cache
@@ -167,17 +163,18 @@ def _create_client_callbacks(self, app):
167163
slice_cache[new_index] = new_data;
168164
// Get the data we need *now*
169165
let data = slice_cache[index];
166+
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
170167
// Maybe we do not need an update
171168
if (!data) {
172-
return window.dash_clientside.no_update;
169+
data = lowres[index];
173170
}
174-
if (data == ori_figure.layout.images[0].source) {
171+
if (data == ori_figure.data[0].source) {
175172
return window.dash_clientside.no_update;
176173
}
177174
// Otherwise, perform update
178175
console.log("updating figure");
179176
let figure = {...ori_figure};
180-
figure.layout.images[0].source = data;
177+
figure.data[0].source = data;
181178
return figure;
182179
}
183180
""".replace(
@@ -188,5 +185,8 @@ def _create_client_callbacks(self, app):
188185
Input(self._subid("slice-index"), "data"),
189186
Input(self._subid("_slice-data"), "data"),
190187
],
191-
[State(self._subid("graph"), "figure")],
188+
[
189+
State(self._subid("graph"), "figure"),
190+
State(self._subid("_slice-data-lowres"), "data"),
191+
],
192192
)

dash_3d_viewer/utils.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
1+
import io
12
import random
3+
import base64
24

35
import PIL.Image
46
import skimage
5-
from plotly.utils import ImageUriValidator
67

78

89
def gen_random_id(n=6):
910
return "".join(random.choice("abcdefghijklmnopqrtsuvwxyz") for i in range(n))
1011

1112

12-
def img_array_to_uri(img_array):
13+
def img_array_to_uri(img_array, new_size=None):
1314
img_array = skimage.util.img_as_ubyte(img_array)
1415
# todo: leverage this Plotly util once it becomes part of the public API (also drops the Pillow dependency)
1516
# from plotly.express._imshow import _array_to_b64str
1617
# return _array_to_b64str(img_array)
1718
img_pil = PIL.Image.fromarray(img_array)
18-
uri = ImageUriValidator.pil_image_to_uri(img_pil)
19-
return uri
19+
if new_size:
20+
img_pil.thumbnail(new_size)
21+
# The below was taken from plotly.utils.ImageUriValidator.pil_image_to_uri()
22+
f = io.BytesIO()
23+
img_pil.save(f, format="PNG")
24+
base64_str = base64.b64encode(f.getvalue()).decode()
25+
return "data:image/png;base64," + base64_str

0 commit comments

Comments
 (0)