Skip to content

Commit

Permalink
picking and resampling works, also for filaments, but ugly gui
Browse files Browse the repository at this point in the history
  • Loading branch information
brisvag committed Oct 13, 2023
1 parent 85b4485 commit 6115930
Show file tree
Hide file tree
Showing 2 changed files with 267 additions and 104 deletions.
329 changes: 234 additions & 95 deletions src/blik/widgets/main_widget.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List
from typing import List

import napari
import numpy as np
import pandas as pd
from magicgui import magic_factory, magicgui
from magicgui.widgets import Container
from morphosamplers.surface_spline import GriddedSplineSurface
from napari.layers import Image, Labels, Points, Shapes, Vectors
from napari.utils._magicgui import find_viewer_ancestor
from napari.utils.notifications import show_info
from scipy.spatial.transform import Rotation

from ..reader import construct_particle_layer_tuples
from ..utils import generate_vectors, invert_xyz, layer_tuples_to_layers

if TYPE_CHECKING:
import napari
from .picking import (
_generate_filaments_from_points_layer,
_generate_surface_grids_from_shapes_layer,
_resample_filament,
_resample_surfaces,
)


def _get_choices(wdg, condition=None):
Expand Down Expand Up @@ -151,7 +154,9 @@ def add_to_exp(layer: napari.layers.Layer):
@magicgui(
labels=False,
call_button="Create",
l_type={"choices": ["segmentation", "particles", "surface_picking"]},
l_type={
"choices": ["segmentation", "particles", "surface_picking", "filament_picking"]
},
)
def new(
l_type,
Expand Down Expand Up @@ -198,6 +203,19 @@ def new(
ndim=3,
)

return [pts]
elif l_type == "filament_picking":
for lay in layers:
if isinstance(lay, Image) and lay.metadata["experiment_id"] == exp_id:
pts = Points(
name=f"{exp_id} - filament picks",
size=100 / lay.scale[0],
scale=lay.scale,
metadata={"experiment_id": exp_id},
face_color_cycle=np.random.rand(30, 3),
ndim=3,
)

return [pts]

show_info(f"cannot create a new {l_type}")
Expand All @@ -208,120 +226,235 @@ def new(
labels=True,
call_button="Generate",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
output={"choices": ["surface", "particles"]},
inside_points={"nullable": True},
)
def surface(
surface_shapes: napari.layers.Shapes,
inside_points: napari.layers.Points,
spacing_A=100,
closed=False,
output="surface",
) -> napari.types.LayerDataTuple:
"""create a new surface representation from picked surface points."""
spacing_A /= surface_shapes.scale[0]
pos = []
ori = []
surface_grids, colors = _generate_surface_grids_from_shapes_layer(
surface_shapes, spacing_A
)

meshes = []
colors = []
exp_id = surface_shapes.metadata["experiment_id"]
data_array = np.array(surface_shapes.data, dtype=object) # helps with indexing
for i, (_, surf) in enumerate(surface_shapes.features.groupby("surface_id")):
lines = data_array[surf.index]
# sort so lines can be added in between at a later point
# also move to xyz world so math is the same as reader code
lines = [
invert_xyz(line).astype(float)
for line in sorted(lines, key=lambda x: x[0, 0])
]

if inside_points is not None:
inside_pt = invert_xyz(inside_points.data[i])
else:
inside_pt = None

try:
surface_grid = GriddedSplineSurface(
points=lines,
separation=spacing_A,
order=3,
closed=closed,
inside_point=inside_pt,
)
except ValueError:
continue
for surf in surface_grids:
meshes.append(surf.mesh())

offset = 0
vert = []
faces = []
ids = []
for surf_id, (v, f) in enumerate(meshes):
f += offset
offset += len(v)
vert.append(v)
faces.append(f)
ids.append(np.full(len(v), surf_id))
vert = np.concatenate(vert)
faces = np.concatenate(faces)
uniq_colors, idx = np.unique(colors, axis=0, return_index=True)
colormap = uniq_colors[np.argsort(idx)]
values = np.concatenate(ids) / len(colormap)
# special case for colormap with 1 color because blacks get autoadded at index 0
if colormap.shape[0] == 1:
values += 1

surface_layer_tuple = (
(invert_xyz(vert), faces, values),
{
"name": f"{exp_id} - surface",
"metadata": {
"experiment_id": exp_id,
"surface_grids": surface_grids,
"surface_colors": colors,
},
"scale": surface_shapes.scale,
"shading": "smooth",
"colormap": colormap,
},
"surface",
)
return [surface_layer_tuple]


@magicgui(
labels=True,
call_button="Generate",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
)
def surface_particles(
surface: napari.layers.Surface,
spacing_A=50,
) -> napari.types.LayerDataTuple:
surface_grids = surface.metadata.get("surface_grids", None)
if surface_grids is None:
raise ValueError("This surface layer contains no surface grid object.")
colors = surface.metadata.get("surface_colors")

colors.append(surface_shapes.edge_color[surf.index])
exp_id = surface.metadata["experiment_id"]
spacing = spacing_A / surface.scale[0]

if output == "particles":
pos.append(surface_grid.sample())
ori.append(surface_grid.sample_orientations())
if output == "surface":
meshes.append(surface_grid.mesh())
pos = []
ori = []
for surf in surface_grids:
if not np.isclose(surf.separation, spacing):
surf.separation = spacing
pos.append(surf.sample())
ori.append(surf.sample_orientations())

if not colors:
raise RuntimeError("could not generate surfaces for some reason")
pos = np.concatenate(pos)
features = pd.DataFrame({"orientation": np.asarray(Rotation.concatenate(ori))})

colors = np.concatenate(colors)
vec_layer_tuple, pos_layer_tuple = construct_particle_layer_tuples(
coords=invert_xyz(pos),
features=features,
scale=surface.scale[0],
exp_id=exp_id,
face_color_cycle=colors,
)
return [vec_layer_tuple, pos_layer_tuple]

if output == "particles":
pos = np.concatenate(pos)
features = pd.DataFrame({"orientation": np.asarray(Rotation.concatenate(ori))})

vec_layer_tuple, pos_layer_tuple = construct_particle_layer_tuples(
coords=invert_xyz(pos),
features=features,
scale=surface_shapes.scale[0],
exp_id=exp_id,
face_color_cycle=colors,
@magicgui(
labels=True,
call_button="Resample",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
thickness_A={"widget_type": "Slider", "min": 1, "max": 500},
)
def resample_surface(
surface: napari.layers.Surface,
volume: napari.layers.Image,
spacing_A=5,
thickness_A=200,
masked=False,
) -> napari.types.LayerDataTuple:
surface_grids = surface.metadata.get("surface_grids", None)
if surface_grids is None:
raise ValueError("This surface layer contains no surface grid object.")

exp_id = surface.metadata["experiment_id"]
spacing = spacing_A / surface.scale[0]
thickness = int(np.round(thickness_A / surface.scale[0]))
for surf in surface_grids:
if not np.isclose(surf.separation, spacing):
surf.separation = spacing

vols = _resample_surfaces(volume, surface_grids, spacing, thickness, masked)

v = napari.Viewer()
for i, vol in enumerate(vols):
v.add_image(
vol,
name=f"{exp_id} - surface_{i} resampled",
metadata={"experiment_id": exp_id},
scale=surface.scale,
)
return [vec_layer_tuple, pos_layer_tuple]

if output == "surface":
offset = 0
vert = []
faces = []
ids = []
for surf_id, (v, f) in enumerate(meshes):
f += offset
offset += len(v)
vert.append(v)
faces.append(f)
ids.append(np.full(len(v), surf_id))
vert = np.concatenate(vert)
faces = np.concatenate(faces)
uniq_colors, idx = np.unique(colors, axis=0, return_index=True)
colormap = uniq_colors[np.argsort(idx)]
values = np.concatenate(ids) / len(colormap)
# special case for colormap with 1 color because blacks get autoadded at index 0
if colormap.shape[0] == 1:
values += 1

surface_layer_tuple = (
(invert_xyz(vert), faces, values),
{
"name": f"{exp_id} - surface",
"metadata": {"experiment_id": exp_id},
"scale": surface_shapes.scale,
"shading": "smooth",
"colormap": colormap,


@magicgui(
labels=True,
call_button="Generate",
)
def filament(
points: napari.layers.Points,
) -> napari.types.LayerDataTuple:
filament = _generate_filaments_from_points_layer(points)

exp_id = points.metadata["experiment_id"]

path = filament.sample(n_samples=len(points.data) * 50)
shapes_layer_tuple = (
[invert_xyz(path)],
{
"name": f"{exp_id} - filament",
"metadata": {
"experiment_id": exp_id,
"helical_filament": filament,
},
"surface",
)
return [surface_layer_tuple]
"scale": points.scale,
"shape_type": "path",
},
"shapes",
)
return [shapes_layer_tuple]

return []

@magicgui(
labels=True,
call_button="Generate",
rise_A={"widget_type": "Slider", "min": 1, "max": 500},
radius_A={"widget_type": "Slider", "min": 0, "max": 500},
twist_deg={"widget_type": "Slider", "min": 0, "max": 360},
twist_offset={"widget_type": "Slider", "min": 0, "max": 360},
)
def filament_particles(
filament: napari.layers.Shapes,
rise_A=50,
twist_deg=0,
twist_offset=0,
radius_A=0,
cyclic_symmetry_order=1,
) -> napari.types.LayerDataTuple:
helical_filament = filament.metadata.get("helical_filament", None)
if helical_filament is None:
raise ValueError("This shapes layer contains no helical filament object.")

exp_id = filament.metadata["experiment_id"]

pos, ori = helical_filament.sample_helical(
rise=rise_A / filament.scale[0],
twist=twist_deg,
radial_offset=radius_A / filament.scale[0],
cyclic_symmetry_order=cyclic_symmetry_order,
twist_offset=twist_offset,
degrees=True,
)

features = pd.DataFrame({"orientation": np.asarray(Rotation.concatenate(ori))})

vec_layer_tuple, pos_layer_tuple = construct_particle_layer_tuples(
coords=invert_xyz(pos),
features=features,
scale=filament.scale[0],
exp_id=exp_id,
)
return [vec_layer_tuple, pos_layer_tuple]


@magicgui(
labels=False,
call_button="Add",
labels=True,
call_button="Resample",
spacing_A={"widget_type": "Slider", "min": 1, "max": 500},
thickness_A={"widget_type": "Slider", "min": 1, "max": 500},
)
def gen(layer: napari.layers.Layer):
"""add layer to the current experiment."""
layer.metadata["experiment_id"] = add_to_exp._main_widget[
"experiment"
].experiment_id.value
def resample_filament(
filament: napari.layers.Shapes,
volume: napari.layers.Image,
spacing_A=5,
thickness_A=200,
) -> napari.types.LayerDataTuple:
helical_filament = filament.metadata.get("helical_filament", None)
if helical_filament is None:
raise ValueError("This shapes layer contains no helical filament object.")

exp_id = filament.metadata["experiment_id"]
spacing = spacing_A / filament.scale[0]
thickness = int(np.round(thickness_A / filament.scale[0]))

vol = _resample_filament(volume, helical_filament, spacing, thickness)

v = napari.Viewer()
v.add_image(
vol,
name=f"{exp_id} - filament resampled",
metadata={"experiment_id": exp_id},
scale=filament.scale,
)


class MainBlikWidget(Container):
Expand All @@ -342,6 +475,12 @@ def __init__(self, *args, **kwargs):
self.append(new)
self.append(add_to_exp)
self.append(surface)
self.append(surface_particles)
self.append(resample_surface)
self.append(filament)
self.append(filament_particles)
self.append(resample_filament)
self.scrollable = True

def append(self, item):
super().append(item)
Expand Down
Loading

0 comments on commit 6115930

Please sign in to comment.