Skip to content

Commit c283fa2

Browse files
committed
Implement indicators
1 parent 87790f5 commit c283fa2

File tree

4 files changed

+163
-36
lines changed

4 files changed

+163
-36
lines changed

dash_3d_viewer/slicer.py

Lines changed: 149 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,47 @@
11
import numpy as np
2-
from plotly.graph_objects import Figure, Image
2+
from plotly.graph_objects import Figure, Image, Scatter
33
from dash import Dash
4-
from dash.dependencies import Input, Output, State
4+
from dash.dependencies import Input, Output, State, ALL
55
from dash_core_components import Graph, Slider, Store
66

7-
from .utils import gen_random_id, img_array_to_uri
7+
from .utils import img_array_to_uri, get_thumbnail_size_from_shape
88

99

1010
class DashVolumeSlicer:
11-
"""A slicer to show 3D image data in Dash."""
11+
"""A slicer to show 3D image data in Dash.
1212
13-
def __init__(self, app, volume, axis=0, id=None):
13+
Parameters:
14+
app (dash.Dash): the Dash application instance.
15+
volume (ndarray): the 3D numpy array to slice through.
16+
axis (int): the dimension to slice in. Default 0.
17+
volume_id (str): the id to use for the volume. By default this is a
18+
hash of ``id(volume)``. Slicers that have the same volume-id show
19+
each-other's positions with line indicators.
20+
21+
This is a placeholder object, not a Dash component. The components
22+
that make up the slicer can be accessed as attributes:
23+
24+
* ``graph``: the Graph object.
25+
* ``slider``: the Slider object.
26+
* ``stores``: a list of Store objects. Some are "public" values, others
27+
used internally. Make sure to put them somewhere in the layout.
28+
29+
Each component is given a dict-id with the following keys:
30+
31+
* "context": a unique string id for this slicer instance.
32+
* "volume": the volume_id.
33+
* "axis": the int axis.
34+
* "name": the name of the component.
35+
36+
TODO: iron out these details, list the stores that are public
37+
"""
38+
39+
_global_slicer_counter = 0
40+
41+
def __init__(self, app, volume, axis=0, volume_id=None):
1442
if not isinstance(app, Dash):
1543
raise TypeError("Expect first arg to be a Dash app.")
44+
self._app = app
1645
# Check and store volume
1746
if not (isinstance(volume, np.ndarray) and volume.ndim == 3):
1847
raise TypeError("Expected volume to be a 3D numpy array")
@@ -22,29 +51,36 @@ def __init__(self, app, volume, axis=0, id=None):
2251
raise ValueError("The given axis must be 0, 1, or 2.")
2352
self._axis = int(axis)
2453
# Check and store id
25-
if id is None:
26-
id = gen_random_id()
27-
elif not isinstance(id, str):
28-
raise TypeError("Id must be a string")
29-
self._id = id
54+
if volume_id is None:
55+
volume_id = hex(id(volume))
56+
elif not isinstance(volume_id, str):
57+
raise TypeError("volume_id must be a string")
58+
self.volume_id = volume_id
59+
# Get unique id scoped to this slicer object
60+
DashVolumeSlicer._global_slicer_counter += 1
61+
self.context_id = "slicer" + str(DashVolumeSlicer._global_slicer_counter)
3062

3163
# Get the slice size (width, height), and max index
32-
# arr_shape = list(volume.shape)
33-
# arr_shape.pop(self._axis)
34-
# slice_size = list(reversed(arr_shape))
64+
arr_shape = list(volume.shape)
65+
arr_shape.pop(self._axis)
66+
self._slice_size = tuple(reversed(arr_shape))
3567
self._max_index = self._volume.shape[self._axis] - 1
3668

3769
# Prep low-res slices
70+
thumbnail_size = get_thumbnail_size_from_shape(arr_shape, 32)
3871
thumbnails = [
39-
img_array_to_uri(self._slice(i), (32, 32))
72+
img_array_to_uri(self._slice(i), thumbnail_size)
4073
for i in range(self._max_index + 1)
4174
]
4275

4376
# Create a placeholder trace
4477
# todo: can add "%{z[0]}", but that would be the scaled value ...
45-
trace = Image(source="", hovertemplate="(%{x}, %{y})<extra></extra>")
78+
image_trace = Image(
79+
source="", dx=1, dy=1, hovertemplate="(%{x}, %{y})<extra></extra>"
80+
)
81+
scatter_trace = Scatter(x=[], y=[]) # placeholder
4682
# Create the figure object
47-
fig = Figure(data=[trace])
83+
self._fig = fig = Figure(data=[image_trace, scatter_trace])
4884
fig.update_layout(
4985
template=None,
5086
margin=dict(l=0, r=0, b=0, t=0, pad=4),
@@ -70,6 +106,7 @@ def __init__(self, app, volume, axis=0, id=None):
70106
config={"scrollZoom": True},
71107
)
72108
# Create a slider object that the user can put in the layout (or not)
109+
# todo: use tooltip to show current value?
73110
self.slider = Slider(
74111
id=self._subid("slider"),
75112
min=0,
@@ -80,18 +117,29 @@ def __init__(self, app, volume, axis=0, id=None):
80117
)
81118
# Create the stores that we need (these must be present in the layout)
82119
self.stores = [
83-
Store(id=self._subid("slice-index"), data=volume.shape[self._axis] // 2),
120+
Store(
121+
id=self._subid("_slice-size"), data=self._slice_size + thumbnail_size
122+
),
123+
Store(id=self._subid("index"), data=volume.shape[self._axis] // 2),
84124
Store(id=self._subid("_requested-slice-index"), data=0),
85125
Store(id=self._subid("_slice-data"), data=""),
86126
Store(id=self._subid("_slice-data-lowres"), data=thumbnails),
127+
Store(id=self._subid("_indicators"), data=[]),
87128
]
88129

89-
self._create_server_callbacks(app)
90-
self._create_client_callbacks(app)
130+
self._create_server_callbacks()
131+
self._create_client_callbacks()
91132

92-
def _subid(self, subid):
133+
def _subid(self, name):
93134
"""Given a subid, get the full id including the slicer's prefix."""
94-
return self._id + "-" + subid
135+
# return self.context_id + "-" + name
136+
# todo: is there a penalty for using a dict-id vs a string-id?
137+
return {
138+
"context": self.context_id,
139+
"volume-id": self.volume_id,
140+
"axis": self._axis,
141+
"name": name,
142+
}
95143

96144
def _slice(self, index):
97145
"""Sample a slice from the volume."""
@@ -100,8 +148,9 @@ def _slice(self, index):
100148
im = self._volume[tuple(indices)]
101149
return (im.astype(np.float32) * (255 / im.max())).astype(np.uint8)
102150

103-
def _create_server_callbacks(self, app):
151+
def _create_server_callbacks(self):
104152
"""Create the callbacks that run server-side."""
153+
app = self._app
105154

106155
@app.callback(
107156
Output(self._subid("_slice-data"), "data"),
@@ -111,16 +160,17 @@ def upload_requested_slice(slice_index):
111160
slice = self._slice(slice_index)
112161
return [slice_index, img_array_to_uri(slice)]
113162

114-
def _create_client_callbacks(self, app):
163+
def _create_client_callbacks(self):
115164
"""Create the callbacks that run client-side."""
165+
app = self._app
116166

117167
app.clientside_callback(
118168
"""
119169
function handle_slider_move(index) {
120170
return index;
121171
}
122172
""",
123-
Output(self._subid("slice-index"), "data"),
173+
Output(self._subid("index"), "data"),
124174
[Input(self._subid("slider"), "value")],
125175
)
126176

@@ -137,24 +187,24 @@ def _create_client_callbacks(self, app):
137187
}
138188
}
139189
""".replace(
140-
"{{ID}}", self._id
190+
"{{ID}}", self.context_id
141191
),
142192
Output(self._subid("_requested-slice-index"), "data"),
143-
[Input(self._subid("slice-index"), "data")],
193+
[Input(self._subid("index"), "data")],
144194
)
145195

146196
# app.clientside_callback("""
147197
# function update_slider_pos(index) {
148198
# return index;
149199
# }
150200
# """,
151-
# [Output("slice-index", "data")],
201+
# [Output("index", "data")],
152202
# [State("slider", "value")],
153203
# )
154204

155205
app.clientside_callback(
156206
"""
157-
function handle_incoming_slice(index, index_and_data, ori_figure, lowres) {
207+
function handle_incoming_slice(index, index_and_data, indicators, ori_figure, lowres, slice_size) {
158208
let new_index = index_and_data[0];
159209
let new_data = index_and_data[1];
160210
// Store data in cache
@@ -163,30 +213,98 @@ def _create_client_callbacks(self, app):
163213
slice_cache[new_index] = new_data;
164214
// Get the data we need *now*
165215
let data = slice_cache[index];
216+
let x0 = 0, y0 = 0, dx = 1, dy = 1;
166217
//slice_cache[new_index] = undefined; // todo: disabled cache for now!
167218
// Maybe we do not need an update
219+
console.log(slice_size)
168220
if (!data) {
169221
data = lowres[index];
222+
// Scale the image to take the exact same space as the full-res
223+
// version. It's not correct, but it looks better ...
224+
// slice_size = full_w, full_h, low_w, low_h
225+
dx = slice_size[0] / slice_size[2];
226+
dy = slice_size[1] / slice_size[3];
227+
x0 = 0.5 * dx - 0.5;
228+
y0 = 0.5 * dy - 0.5;
170229
}
171-
if (data == ori_figure.data[0].source) {
230+
if (data == ori_figure.data[0].source && indicators.version == ori_figure.data[1].version) {
172231
return window.dash_clientside.no_update;
173232
}
174233
// Otherwise, perform update
175234
console.log("updating figure");
176235
let figure = {...ori_figure};
177236
figure.data[0].source = data;
237+
figure.data[0].x0 = x0;
238+
figure.data[0].y0 = y0;
239+
figure.data[0].dx = dx;
240+
figure.data[0].dy = dy;
241+
figure.data[1] = indicators;
178242
return figure;
179243
}
180244
""".replace(
181-
"{{ID}}", self._id
245+
"{{ID}}", self.context_id
182246
),
183247
Output(self._subid("graph"), "figure"),
184248
[
185-
Input(self._subid("slice-index"), "data"),
249+
Input(self._subid("index"), "data"),
186250
Input(self._subid("_slice-data"), "data"),
251+
Input(self._subid("_indicators"), "data"),
187252
],
188253
[
189254
State(self._subid("graph"), "figure"),
190255
State(self._subid("_slice-data-lowres"), "data"),
256+
State(self._subid("_slice-size"), "data"),
257+
],
258+
)
259+
260+
# Select the *other* axii
261+
axii = [0, 1, 2]
262+
axii.pop(self._axis)
263+
264+
# Create a callback to create a trace representing all slice-indices that:
265+
# * corresponding to the same volume data
266+
# * match any of the selected axii
267+
app.clientside_callback(
268+
"""
269+
function handle_indicator(indices1, indices2, slice_size, current) {
270+
let w = slice_size[0], h = slice_size[1];
271+
let dx = w / 20, dy = h / 20;
272+
let version = (current.version || 0) + 1;
273+
let x = [], y = [];
274+
for (let index of indices1) {
275+
x.push(...[-dx, -1, null, w, w + dx, null]);
276+
y.push(...[index, index, index, index, index, index]);
277+
}
278+
for (let index of indices2) {
279+
x.push(...[index, index, index, index, index, index]);
280+
y.push(...[-dy, -1, null, h, h + dy, null]);
281+
}
282+
return {
283+
type: 'scatter',
284+
mode: 'lines',
285+
line: {color: '#ff00aa'},
286+
x: x,
287+
y: y,
288+
hoverinfo: 'skip',
289+
version: version
290+
}
291+
}
292+
""",
293+
Output(self._subid("_indicators"), "data"),
294+
[
295+
Input(
296+
{
297+
"volume-id": self.volume_id,
298+
"context": ALL,
299+
"name": "index",
300+
"axis": axis,
301+
},
302+
"data",
303+
)
304+
for axis in axii
305+
],
306+
[
307+
State(self._subid("_slice-size"), "data"),
308+
State(self._subid("_indicators"), "data"),
191309
],
192310
)

dash_3d_viewer/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
import base64
44

5+
import numpy as np
56
import PIL.Image
67
import skimage
78

@@ -23,3 +24,11 @@ def img_array_to_uri(img_array, new_size=None):
2324
img_pil.save(f, format="PNG")
2425
base64_str = base64.b64encode(f.getvalue()).decode()
2526
return "data:image/png;base64," + base64_str
27+
28+
29+
def get_thumbnail_size_from_shape(shape, base_size):
30+
base_size = int(base_size)
31+
img_array = np.zeros(shape, np.uint8)
32+
img_pil = PIL.Image.fromarray(img_array)
33+
img_pil.thumbnail((base_size, base_size))
34+
return img_pil.size

examples/slicer_with_2_views.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
app = dash.Dash(__name__)
1212

1313
vol = imageio.volread("imageio:stent.npz")
14-
slicer1 = DashVolumeSlicer(app, vol, axis=1, id="slicer1")
15-
slicer2 = DashVolumeSlicer(app, vol, axis=2, id="slicer2")
14+
slicer1 = DashVolumeSlicer(app, vol, axis=1)
15+
slicer2 = DashVolumeSlicer(app, vol, axis=2)
1616

1717
app.layout = html.Div(
1818
style={

examples/slicer_with_3_views.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515

1616
# Read volumes and create slicer objects
1717
vol = imageio.volread("imageio:stent.npz")
18-
slicer1 = DashVolumeSlicer(app, vol, axis=0, id="slicer1")
19-
slicer2 = DashVolumeSlicer(app, vol, axis=1, id="slicer2")
20-
slicer3 = DashVolumeSlicer(app, vol, axis=2, id="slicer3")
18+
slicer1 = DashVolumeSlicer(app, vol, axis=0)
19+
slicer2 = DashVolumeSlicer(app, vol, axis=1)
20+
slicer3 = DashVolumeSlicer(app, vol, axis=2)
2121

2222
# Calculate isosurface and create a figure with a mesh object
2323
verts, faces, _, _ = marching_cubes(vol, 300, step_size=2)

0 commit comments

Comments
 (0)